BLI: refactor IndexMask for better performance and memory usage #104629
@ -588,10 +588,11 @@ constexpr bool has_mask_segment_and_start_parameter =
|
||||
|
||||
template<typename Fn> inline void IndexMask::foreach_index(Fn &&fn) const
|
||||
{
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> indices, const int64_t start) {
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> indices,
|
||||
[[maybe_unused]] const int64_t start_mask_position) {
|
||||
if constexpr (std::is_invocable_r_v<void, Fn, int64_t, int64_t>) {
|
||||
for (const int64_t i : indices.index_range()) {
|
||||
fn(indices[i], start + i);
|
||||
fn(indices[i], start_mask_position + i);
|
||||
}
|
||||
}
|
||||
else {
|
||||
@ -607,9 +608,9 @@ inline void IndexMask::foreach_index(const GrainSize grain_size, Fn &&fn) const
|
||||
{
|
||||
threading::parallel_for(this->index_range(), grain_size.value, [&](const IndexRange range) {
|
||||
const IndexMask sub_mask = this->slice(range);
|
||||
sub_mask.foreach_index([&](const int64_t i, const int64_t i_in_mask) {
|
||||
sub_mask.foreach_index([&](const int64_t i, [[maybe_unused]] const int64_t mask_position) {
|
||||
if constexpr (std::is_invocable_r_v<void, Fn, int64_t, int64_t>) {
|
||||
fn(i, i_in_mask + range.start());
|
||||
fn(i, mask_position + range.start());
|
||||
}
|
||||
else {
|
||||
fn(i);
|
||||
@ -637,26 +638,27 @@ template<typename Fn>
|
||||
[[gnu::optimize("-funroll-loops")]] [[gnu::optimize("O3")]]
|
||||
#endif
|
||||
inline void
|
||||
foreach_index_in_range(const IndexRange range, const int64_t offset, Fn &&fn)
|
||||
foreach_index_in_range(const IndexRange range, const int64_t start_mask_position, Fn &&fn)
|
||||
{
|
||||
const int64_t start = range.start();
|
||||
const int64_t end = range.one_after_last();
|
||||
for (int64_t i = start, mask_i = offset; i < end; i++, mask_i++) {
|
||||
fn(i, mask_i);
|
||||
for (int64_t i = start, mask_position = start_mask_position; i < end; i++, mask_position++) {
|
||||
fn(i, mask_position);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Fn> inline void IndexMask::foreach_index_optimized(Fn &&fn) const
|
||||
{
|
||||
this->foreach_span_or_range([&](const auto mask_segment, const int64_t start) {
|
||||
this->foreach_span_or_range(
|
||||
[&](const auto mask_segment, [[maybe_unused]] const int64_t start_mask_position) {
|
||||
constexpr bool is_range = std::is_same_v<std::decay_t<decltype(mask_segment)>, IndexRange>;
|
||||
if constexpr (std::is_invocable_r_v<void, Fn, int64_t, int64_t>) {
|
||||
if constexpr (is_range) {
|
||||
foreach_index_in_range(mask_segment, start, fn);
|
||||
foreach_index_in_range(mask_segment, start_mask_position, fn);
|
||||
}
|
||||
else {
|
||||
for (const int64_t i : mask_segment.index_range()) {
|
||||
fn(mask_segment[i], start + i);
|
||||
fn(mask_segment[i], start_mask_position + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -680,7 +682,9 @@ inline void IndexMask::foreach_index_optimized(const GrainSize grain_size, Fn &&
|
||||
const IndexMask sub_mask = this->slice(range);
|
||||
if constexpr (std::is_invocable_r_v<void, Fn, int64_t, int64_t>) {
|
||||
sub_mask.foreach_index_optimized(
|
||||
[&](const int64_t i, const int64_t i_in_mask) { fn(i, i_in_mask + range.start()); });
|
||||
[&fn, range_start = range.start()](const int64_t i, const int64_t mask_position) {
|
||||
fn(i, mask_position + range_start);
|
||||
});
|
||||
}
|
||||
else {
|
||||
sub_mask.foreach_index_optimized(fn);
|
||||
@ -690,11 +694,12 @@ inline void IndexMask::foreach_index_optimized(const GrainSize grain_size, Fn &&
|
||||
|
||||
template<typename Fn> inline void IndexMask::foreach_span_or_range(Fn &&fn) const
|
||||
{
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> mask_segment, const int64_t start) {
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> mask_segment,
|
||||
[[maybe_unused]] const int64_t start_mask_position) {
|
||||
if (unique_sorted_indices::non_empty_is_range(mask_segment.base_span())) {
|
||||
const IndexRange range(mask_segment[0], mask_segment.size());
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(range, start);
|
||||
fn(range, start_mask_position);
|
||||
}
|
||||
else {
|
||||
fn(range);
|
||||
@ -702,7 +707,7 @@ template<typename Fn> inline void IndexMask::foreach_span_or_range(Fn &&fn) cons
|
||||
}
|
||||
else {
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(mask_segment, start);
|
||||
fn(mask_segment, start_mask_position);
|
||||
}
|
||||
else {
|
||||
fn(mask_segment);
|
||||
@ -716,9 +721,11 @@ inline void IndexMask::foreach_span_or_range(const GrainSize grain_size, Fn &&fn
|
||||
{
|
||||
threading::parallel_for(this->index_range(), grain_size.value, [&](const IndexRange range) {
|
||||
const IndexMask sub_mask = this->slice(range);
|
||||
sub_mask.foreach_span_or_range([&](const auto mask_segment, const int64_t start) {
|
||||
sub_mask.foreach_span_or_range(
|
||||
[&fn, range_start = range.start()](const auto mask_segment,
|
||||
[[maybe_unused]] const int64_t start_mask_position) {
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(mask_segment, start + range.start());
|
||||
fn(mask_segment, start_mask_position + range_start);
|
||||
}
|
||||
else {
|
||||
fn(mask_segment);
|
||||
@ -729,12 +736,12 @@ inline void IndexMask::foreach_span_or_range(const GrainSize grain_size, Fn &&fn
|
||||
|
||||
template<typename Fn> inline void IndexMask::foreach_span(Fn &&fn) const
|
||||
{
|
||||
[[maybe_unused]] int64_t counter = 0;
|
||||
[[maybe_unused]] int64_t mask_position = 0;
|
||||
for (const int64_t segment_i : IndexRange(segments_num_)) {
|
||||
const OffsetSpan<int64_t, int16_t> segment = this->segment(segment_i);
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(segment, counter);
|
||||
counter += segment.size();
|
||||
fn(segment, mask_position);
|
||||
mask_position += segment.size();
|
||||
}
|
||||
else {
|
||||
fn(segment);
|
||||
@ -748,9 +755,10 @@ inline void IndexMask::foreach_span(const GrainSize grain_size, Fn &&fn) const
|
||||
threading::parallel_for(this->index_range(), grain_size.value, [&](const IndexRange range) {
|
||||
const IndexMask sub_mask = this->slice(range);
|
||||
sub_mask.foreach_span(
|
||||
[&](const OffsetSpan<int64_t, int16_t> mask_segment, const int64_t start) {
|
||||
[&fn, range_start = range.start()](const OffsetSpan<int64_t, int16_t> mask_segment,
|
||||
[[maybe_unused]] const int64_t start_mask_position) {
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(mask_segment, start + range.start());
|
||||
fn(mask_segment, start_mask_position + range_start);
|
||||
}
|
||||
else {
|
||||
fn(mask_segment);
|
||||
@ -761,18 +769,19 @@ inline void IndexMask::foreach_span(const GrainSize grain_size, Fn &&fn) const
|
||||
|
||||
template<typename Fn> inline void IndexMask::foreach_range(Fn &&fn) const
|
||||
{
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> indices, int64_t start) {
|
||||
this->foreach_span([&](const OffsetSpan<int64_t, int16_t> indices,
|
||||
[[maybe_unused]] int64_t start_mask_position) {
|
||||
Span<int16_t> base_indices = indices.base_span();
|
||||
while (!base_indices.is_empty()) {
|
||||
const int64_t next_range_size = unique_sorted_indices::find_size_of_next_range(base_indices);
|
||||
const IndexRange range(int64_t(base_indices[0]) + indices.offset(), next_range_size);
|
||||
if constexpr (has_mask_segment_and_start_parameter<Fn>) {
|
||||
fn(range, start);
|
||||
fn(range, start_mask_position);
|
||||
}
|
||||
else {
|
||||
fn(range);
|
||||
}
|
||||
start += next_range_size;
|
||||
start_mask_position += next_range_size;
|
||||
base_indices = base_indices.drop_front(next_range_size);
|
||||
}
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user