Refactoring: Geometry Nodes: Avoid copying the last buffer into the result for the Blur Attribute node #106860

Merged
Jacques Lucke merged 7 commits from mod_moder/blender:blur_avoid_extra_copy into main 2023-04-17 08:09:15 +02:00
1 changed file with 50 additions and 68 deletions
Showing only changes of commit 5581c33d42 - Show all commits

View File

@ -246,17 +246,17 @@ static Array<Vector<int>> create_mesh_map(const Mesh &mesh,
} }
template<typename T> template<typename T>
static void blur_on_mesh_exec(const Span<float> neighbor_weights, static Span<T> blur_on_mesh_exec(const Span<float> neighbor_weights,
const Span<Vector<int>> neighbors_map, const Span<Vector<int>> neighbors_map,
const int iterations, const int iterations,
const MutableSpan<T> main_buffer, const MutableSpan<T> main_buffer,
const MutableSpan<T> tmp_buffer, const MutableSpan<T> tmp_buffer)
bool &r_tmp_buffer_is_result)
{ {
MutableSpan<T> src = main_buffer; MutableSpan<T> src = tmp_buffer;
mod_moder marked this conversation as resolved Outdated

Add a comment:

`The source is set to buffer_b even though the data is actually in buffer_a, because the loop below starts by swapping both.`
MutableSpan<T> dst = tmp_buffer; MutableSpan<T> dst = main_buffer;
for ([[maybe_unused]] const int64_t iteration : IndexRange(iterations)) { for ([[maybe_unused]] const int64_t iteration : IndexRange(iterations)) {
std::swap(src, dst);
attribute_math::DefaultMixer<T> mixer{dst, IndexMask(0)}; attribute_math::DefaultMixer<T> mixer{dst, IndexMask(0)};
threading::parallel_for(dst.index_range(), 1024, [&](const IndexRange range) { threading::parallel_for(dst.index_range(), 1024, [&](const IndexRange range) {
for (const int64_t index : range) { for (const int64_t index : range) {
@ -269,54 +269,51 @@ static void blur_on_mesh_exec(const Span<float> neighbor_weights,
} }
mixer.finalize(range); mixer.finalize(range);
}); });
std::swap(src, dst);
} }
if (dst.data() == main_buffer.data()) { return dst;
r_tmp_buffer_is_result = true;
}
} }
static void blur_on_mesh(const Mesh &mesh, static GSpan blur_on_mesh(const Mesh &mesh,
const eAttrDomain domain, const eAttrDomain domain,
const int iterations, const int iterations,
const Span<float> neighbor_weights, const Span<float> neighbor_weights,
const GMutableSpan main_buffer, const GMutableSpan main_buffer,
const GMutableSpan tmp_buffer, const GMutableSpan tmp_buffer)
bool &tmp_buffer_is_result)
{ {
Array<Vector<int>> neighbors_map = create_mesh_map(mesh, domain, neighbor_weights.index_range()); Array<Vector<int>> neighbors_map = create_mesh_map(mesh, domain, neighbor_weights.index_range());
if (neighbors_map.is_empty()) { if (neighbors_map.is_empty()) {
return; return main_buffer;
} }
GSpan result_buffer;
attribute_math::convert_to_static_type(main_buffer.type(), [&](auto dummy) { attribute_math::convert_to_static_type(main_buffer.type(), [&](auto dummy) {
using T = decltype(dummy); using T = decltype(dummy);
if constexpr (!std::is_same_v<T, bool>) { if constexpr (!std::is_same_v<T, bool>) {
blur_on_mesh_exec<T>(neighbor_weights, result_buffer = blur_on_mesh_exec<T>(neighbor_weights,
neighbors_map, neighbors_map,
iterations, iterations,
main_buffer.typed<T>(), main_buffer.typed<T>(),
tmp_buffer.typed<T>(), tmp_buffer.typed<T>());
tmp_buffer_is_result);
} }
}); });
return result_buffer;
} }
template<typename T> template<typename T>
static void blur_on_curve_exec(const bke::CurvesGeometry &curves, static Span<T> blur_on_curve_exec(const bke::CurvesGeometry &curves,
const Span<float> neighbor_weights, const Span<float> neighbor_weights,
const int iterations, const int iterations,
const MutableSpan<T> main_buffer, const MutableSpan<T> main_buffer,
const MutableSpan<T> tmp_buffer, const MutableSpan<T> tmp_buffer)
bool &r_tmp_buffer_is_result)
{ {
MutableSpan<T> src = main_buffer; MutableSpan<T> src = tmp_buffer;
MutableSpan<T> dst = tmp_buffer; MutableSpan<T> dst = main_buffer;
const OffsetIndices points_by_curve = curves.points_by_curve(); const OffsetIndices points_by_curve = curves.points_by_curve();
const VArray<bool> cyclic = curves.cyclic(); const VArray<bool> cyclic = curves.cyclic();
for ([[maybe_unused]] const int iteration : IndexRange(iterations)) { for ([[maybe_unused]] const int iteration : IndexRange(iterations)) {
std::swap(src, dst);
attribute_math::DefaultMixer<T> mixer{dst, IndexMask(0)}; attribute_math::DefaultMixer<T> mixer{dst, IndexMask(0)};
threading::parallel_for(curves.curves_range(), 256, [&](const IndexRange range) { threading::parallel_for(curves.curves_range(), 256, [&](const IndexRange range) {
for (const int curve_i : range) { for (const int curve_i : range) {
@ -355,32 +352,26 @@ static void blur_on_curve_exec(const bke::CurvesGeometry &curves,
} }
mixer.finalize(points_by_curve[range]); mixer.finalize(points_by_curve[range]);
}); });
std::swap(src, dst);
} }
if (dst.data() == main_buffer.data()) { return dst;
r_tmp_buffer_is_result = true;
}
} }
static void blur_on_curves(const bke::CurvesGeometry &curves, static GSpan blur_on_curves(const bke::CurvesGeometry &curves,
const int iterations, const int iterations,
const Span<float> neighbor_weights, const Span<float> neighbor_weights,
const GMutableSpan main_buffer, const GMutableSpan main_buffer,
const GMutableSpan tmp_buffer, const GMutableSpan tmp_buffer)
bool &r_tmp_buffer_is_result)
{ {
GSpan result_buffer;
attribute_math::convert_to_static_type(main_buffer.type(), [&](auto dummy) { attribute_math::convert_to_static_type(main_buffer.type(), [&](auto dummy) {
using T = decltype(dummy); using T = decltype(dummy);
if constexpr (!std::is_same_v<T, bool>) { if constexpr (!std::is_same_v<T, bool>) {
blur_on_curve_exec<T>(curves, result_buffer = blur_on_curve_exec<T>(
neighbor_weights, curves, neighbor_weights, iterations, main_buffer.typed<T>(), tmp_buffer.typed<T>());
iterations,
main_buffer.typed<T>(),
tmp_buffer.typed<T>(),
r_tmp_buffer_is_result);
} }
}); });
return result_buffer;
} }
class BlurAttributeFieldInput final : public bke::GeometryFieldInput { class BlurAttributeFieldInput final : public bke::GeometryFieldInput {
@ -423,31 +414,21 @@ class BlurAttributeFieldInput final : public bke::GeometryFieldInput {
VArraySpan<float> neighbor_weights = evaluator.get_evaluated<float>(1); VArraySpan<float> neighbor_weights = evaluator.get_evaluated<float>(1);
GArray<> tmp_buffer(*type_, domain_size); GArray<> tmp_buffer(*type_, domain_size);
bool tmp_buffer_is_result = false; GSpan result_buffer;
switch (context.type()) { switch (context.type()) {
case GEO_COMPONENT_TYPE_MESH: case GEO_COMPONENT_TYPE_MESH:
if (ELEM(context.domain(), ATTR_DOMAIN_POINT, ATTR_DOMAIN_EDGE, ATTR_DOMAIN_FACE)) { if (ELEM(context.domain(), ATTR_DOMAIN_POINT, ATTR_DOMAIN_EDGE, ATTR_DOMAIN_FACE)) {
if (const Mesh *mesh = context.mesh()) { if (const Mesh *mesh = context.mesh()) {
blur_on_mesh(*mesh, result_buffer = blur_on_mesh(
context.domain(), *mesh, context.domain(), iterations_, neighbor_weights, main_buffer, tmp_buffer);
iterations_,
neighbor_weights,
main_buffer,
tmp_buffer,
tmp_buffer_is_result);
} }
} }
break; break;
case GEO_COMPONENT_TYPE_CURVE: case GEO_COMPONENT_TYPE_CURVE:
if (context.domain() == ATTR_DOMAIN_POINT) { if (context.domain() == ATTR_DOMAIN_POINT) {
if (const bke::CurvesGeometry *curves = context.curves()) { if (const bke::CurvesGeometry *curves = context.curves()) {
blur_on_curves(*curves, result_buffer = blur_on_curves(
iterations_, *curves, iterations_, neighbor_weights, main_buffer, tmp_buffer);
neighbor_weights,
main_buffer,
tmp_buffer,
tmp_buffer_is_result);
} }
} }
break; break;
@ -455,10 +436,11 @@ class BlurAttributeFieldInput final : public bke::GeometryFieldInput {
break; break;
} }
if (tmp_buffer_is_result) { BLI_assert(ELEM(result_buffer.data(), tmp_buffer.data(), main_buffer.data()));
return GVArray::ForGArray(std::move(tmp_buffer)); if (result_buffer.data() == main_buffer.data()) {
return GVArray::ForGArray(std::move(main_buffer));
} }
return GVArray::ForGArray(std::move(main_buffer)); return GVArray::ForGArray(std::move(tmp_buffer));
} }
void for_each_field_input_recursive(FunctionRef<void(const FieldInput &)> fn) const override void for_each_field_input_recursive(FunctionRef<void(const FieldInput &)> fn) const override