WIP: Functions: support dynamically added channels in lazy functions #124299

Draft
Jacques Lucke wants to merge 2 commits from JacquesLucke/blender:lazy-function-dynamic-channels into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
4 changed files with 346 additions and 139 deletions

View File

@ -126,6 +126,43 @@ struct Context {
}
};
struct ChannelID {
private:
int index_;
public:
ChannelID() : index_(-1) {}
explicit ChannelID(const int index) : index_(index)
{
BLI_assert(index >= 0);
}
static ChannelID from_main_or_dynamic_index(const int index)
{
BLI_assert(index >= -1);
ChannelID channel;
channel.index_ = index;
return channel;
}
bool is_main() const
{
return index_ < 0;
}
bool is_dynamic() const
{
return index_ >= 0;
}
int dynamic_index() const
{
BLI_assert(this->is_dynamic());
return index_;
}
};
/**
* Defines the calling convention for a lazy-function. During execution, a lazy-function retrieves
* its inputs and sets the outputs through #Params.
@ -149,13 +186,13 @@ class Params {
*
* The #LazyFunction must leave returned object in an initialized state, but can move from it.
*/
void *try_get_input_data_ptr(int index) const;
void *try_get_input_data_ptr(int index, ChannelID channel = {}) const;
/**
* Same as #try_get_input_data_ptr, but if the data is not yet available, request it. This makes
* sure that the data will be available in a future execution of the #LazyFunction.
*/
void *try_get_input_data_ptr_or_request(int index);
void *try_get_input_data_ptr_or_request(int index, ChannelID channel = {});
/**
* Get a pointer to where the output value should be stored.
@ -163,39 +200,44 @@ class Params {
* The #LazyFunction is responsible for initializing the value.
* After the output has been initialized to its final value, #output_set has to be called.
*/
void *get_output_data_ptr(int index);
void *get_output_data_ptr(int index, ChannelID channel = {});
/**
* Call this after the output value is initialized. After this is called, the value must not be
* touched anymore. It may be moved or destructed immediately.
*/
void output_set(int index);
void output_set(int index, ChannelID channel = {});
/**
* Allows the #LazyFunction to check whether an output was computed already without keeping
* track of it itself.
*/
bool output_was_set(int index) const;
bool output_was_set(int index, ChannelID channel = {}) const;
/**
* Can be used to detect which outputs have to be computed.
*/
ValueUsage get_output_usage(int index) const;
ValueUsage get_output_usage(int index, ChannelID channel = {}) const;
/**
* Tell the caller of the #LazyFunction that a specific input will definitely not be used.
* Only an input that was not #ValueUsage::Used can become unused.
*/
void set_input_unused(int index);
void set_input_unused(int index, ChannelID channel = {});
int get_input_channels_num(int index) const;
int get_output_channels_num(int index) const;
IndexRange add_output_channels(int index, int num);
/**
* Typed utility methods that wrap the methods above.
*/
template<typename T> T extract_input(int index);
template<typename T> T &get_input(int index) const;
template<typename T> T *try_get_input_data_ptr(int index) const;
template<typename T> T *try_get_input_data_ptr_or_request(int index);
template<typename T> T extract_input(int index, ChannelID channel = {});
template<typename T> T &get_input(int index, ChannelID channel = {}) const;
template<typename T> T *try_get_input_data_ptr(int index, ChannelID channel = {}) const;
template<typename T> T *try_get_input_data_ptr_or_request(int index, ChannelID channel = {});
template<typename T> void set_output(int index, T &&value);
template<typename T> void set_output(int index, ChannelID channel, T &&value);
/**
* Returns true when the lazy-function is now allowed to use multi-threading when interacting
@ -211,13 +253,18 @@ class Params {
* methods above to make it easy to insert additional debugging logic on top of the
* implementations.
*/
virtual void *try_get_input_data_ptr_impl(int index) const = 0;
virtual void *try_get_input_data_ptr_or_request_impl(int index) = 0;
virtual void *get_output_data_ptr_impl(int index) = 0;
virtual void output_set_impl(int index) = 0;
virtual bool output_was_set_impl(int index) const = 0;
virtual ValueUsage get_output_usage_impl(int index) const = 0;
virtual void set_input_unused_impl(int index) = 0;
virtual void *try_get_input_data_ptr_impl(int index, ChannelID channel) const = 0;
virtual void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) = 0;
virtual void *get_output_data_ptr_impl(int index, ChannelID channel) = 0;
virtual void output_set_impl(int index, ChannelID channel) = 0;
virtual bool output_was_set_impl(int index, ChannelID channel) const = 0;
virtual ValueUsage get_output_usage_impl(int index, ChannelID channel) const = 0;
virtual void set_input_unused_impl(int index, ChannelID channel) = 0;
virtual int get_input_channels_num_impl(int index) const = 0;
virtual int get_output_channels_num_impl(int index) const = 0;
virtual IndexRange add_output_channels_impl(int index, int num) = 0;
virtual bool try_enable_multi_threading_impl();
};
@ -386,85 +433,115 @@ inline Params::Params(const LazyFunction &fn,
{
}
inline void *Params::try_get_input_data_ptr(const int index) const
inline void *Params::try_get_input_data_ptr(const int index, const ChannelID channel) const
{
BLI_assert(index >= 0 && index < fn_.inputs().size());
return this->try_get_input_data_ptr_impl(index);
return this->try_get_input_data_ptr_impl(index, channel);
}
inline void *Params::try_get_input_data_ptr_or_request(const int index)
inline void *Params::try_get_input_data_ptr_or_request(const int index, const ChannelID channel)
{
BLI_assert(index >= 0 && index < fn_.inputs().size());
this->assert_valid_thread();
return this->try_get_input_data_ptr_or_request_impl(index);
return this->try_get_input_data_ptr_or_request_impl(index, channel);
}
inline void *Params::get_output_data_ptr(const int index)
inline void *Params::get_output_data_ptr(const int index, const ChannelID channel)
{
BLI_assert(index >= 0 && index < fn_.outputs().size());
this->assert_valid_thread();
return this->get_output_data_ptr_impl(index);
return this->get_output_data_ptr_impl(index, channel);
}
inline void Params::output_set(const int index)
inline void Params::output_set(const int index, const ChannelID channel)
{
BLI_assert(index >= 0 && index < fn_.outputs().size());
this->assert_valid_thread();
this->output_set_impl(index);
this->output_set_impl(index, channel);
}
inline bool Params::output_was_set(const int index) const
inline bool Params::output_was_set(const int index, const ChannelID channel) const
{
BLI_assert(index >= 0 && index < fn_.outputs().size());
return this->output_was_set_impl(index);
return this->output_was_set_impl(index, channel);
}
inline ValueUsage Params::get_output_usage(const int index) const
inline ValueUsage Params::get_output_usage(const int index, const ChannelID channel) const
{
BLI_assert(index >= 0 && index < fn_.outputs().size());
return this->get_output_usage_impl(index);
return this->get_output_usage_impl(index, channel);
}
inline void Params::set_input_unused(const int index)
inline void Params::set_input_unused(const int index, const ChannelID channel)
{
BLI_assert(index >= 0 && index < fn_.inputs().size());
this->assert_valid_thread();
this->set_input_unused_impl(index);
this->set_input_unused_impl(index, channel);
}
template<typename T> inline T Params::extract_input(const int index)
inline int Params::get_input_channels_num(const int index) const
{
BLI_assert(index >= 0 && index < fn_.inputs().size());
this->assert_valid_thread();
return this->get_input_channels_num_impl(index);
}
inline int Params::get_output_channels_num(const int index) const
{
BLI_assert(index >= 0 && index < fn_.outputs().size());
this->assert_valid_thread();
return this->get_output_channels_num_impl(index);
}
inline IndexRange Params::add_output_channels(const int index, const int num)
{
BLI_assert(num >= 0);
BLI_assert(index >= 0 && index < fn_.inputs().size());
this->assert_valid_thread();
return this->add_output_channels_impl(index, num);
}
template<typename T> inline T Params::extract_input(const int index, const ChannelID channel)
{
this->assert_valid_thread();
void *data = this->try_get_input_data_ptr(index);
void *data = this->try_get_input_data_ptr(index, channel);
BLI_assert(data != nullptr);
T return_value = std::move(*static_cast<T *>(data));
return return_value;
}
template<typename T> inline T &Params::get_input(const int index) const
template<typename T> inline T &Params::get_input(const int index, const ChannelID channel) const
{
void *data = this->try_get_input_data_ptr(index);
void *data = this->try_get_input_data_ptr(index, channel);
BLI_assert(data != nullptr);
return *static_cast<T *>(data);
}
template<typename T> inline T *Params::try_get_input_data_ptr(const int index) const
template<typename T>
inline T *Params::try_get_input_data_ptr(const int index, const ChannelID channel) const
{
this->assert_valid_thread();
return static_cast<T *>(this->try_get_input_data_ptr(index));
return static_cast<T *>(this->try_get_input_data_ptr(index, channel));
}
template<typename T> inline T *Params::try_get_input_data_ptr_or_request(const int index)
template<typename T>
inline T *Params::try_get_input_data_ptr_or_request(const int index, const ChannelID channel)
{
this->assert_valid_thread();
return static_cast<T *>(this->try_get_input_data_ptr_or_request(index));
return static_cast<T *>(this->try_get_input_data_ptr_or_request(index, channel));
}
template<typename T> inline void Params::set_output(const int index, T &&value)
{
this->set_output(index, ChannelID(), std::forward<T>(value));
}
template<typename T>
inline void Params::set_output(const int index, const ChannelID channel, T &&value)
{
using DecayT = std::decay_t<T>;
this->assert_valid_thread();
void *data = this->get_output_data_ptr(index);
void *data = this->get_output_data_ptr(index, channel);
new (data) DecayT(std::forward<T>(value));
this->output_set(index);
}

View File

@ -36,13 +36,18 @@ class BasicParams : public Params {
Span<ValueUsage> output_usages,
MutableSpan<bool> set_outputs);
void *try_get_input_data_ptr_impl(const int index) const override;
void *try_get_input_data_ptr_or_request_impl(const int index) override;
void *get_output_data_ptr_impl(const int index) override;
void output_set_impl(const int index) override;
bool output_was_set_impl(const int index) const override;
ValueUsage get_output_usage_impl(const int index) const override;
void set_input_unused_impl(const int index) override;
void *try_get_input_data_ptr_impl(int index, ChannelID channel) const override;
void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) override;
void *get_output_data_ptr_impl(int index, ChannelID channel) override;
void output_set_impl(int index, ChannelID channel) override;
bool output_was_set_impl(int index, ChannelID channel) const override;
ValueUsage get_output_usage_impl(int index, ChannelID channel) const override;
void set_input_unused_impl(int index, ChannelID channel) override;
int get_input_channels_num_impl(int index) const override;
int get_output_channels_num_impl(int index) const override;
IndexRange add_output_channels_impl(int index, int num) override;
bool try_enable_multi_threading_impl() override;
};
@ -64,13 +69,18 @@ class RemappedParams : public Params {
Span<int> output_map,
bool &multi_threading_enabled);
void *try_get_input_data_ptr_impl(const int index) const override;
void *try_get_input_data_ptr_or_request_impl(const int index) override;
void *get_output_data_ptr_impl(const int index) override;
void output_set_impl(const int index) override;
bool output_was_set_impl(const int index) const override;
ValueUsage get_output_usage_impl(const int index) const override;
void set_input_unused_impl(const int index) override;
void *try_get_input_data_ptr_impl(int index, ChannelID channel) const override;
void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) override;
void *get_output_data_ptr_impl(int index, ChannelID channel) override;
void output_set_impl(int index, ChannelID channel) override;
bool output_was_set_impl(int index, ChannelID channel) const override;
ValueUsage get_output_usage_impl(int index, ChannelID channel) const override;
void set_input_unused_impl(int index, ChannelID channel) override;
int get_input_channels_num_impl(int index) const override;
int get_output_channels_num_impl(int index) const override;
IndexRange add_output_channels_impl(int index, int num) override;
bool try_enable_multi_threading_impl() override;
};

View File

@ -29,12 +29,13 @@ BasicParams::BasicParams(const LazyFunction &fn,
{
}
void *BasicParams::try_get_input_data_ptr_impl(const int index) const
void *BasicParams::try_get_input_data_ptr_impl(const int index, const ChannelID /*channel*/) const
{
return inputs_[index].get();
}
void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index)
void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index,
const ChannelID /*channel*/)
{
void *value = inputs_[index].get();
if (value == nullptr) {
@ -43,31 +44,46 @@ void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index)
return value;
}
void *BasicParams::get_output_data_ptr_impl(const int index)
void *BasicParams::get_output_data_ptr_impl(const int index, const ChannelID /*channel*/)
{
return outputs_[index].get();
}
void BasicParams::output_set_impl(const int index)
void BasicParams::output_set_impl(const int index, const ChannelID /*channel*/)
{
set_outputs_[index] = true;
}
bool BasicParams::output_was_set_impl(const int index) const
bool BasicParams::output_was_set_impl(const int index, const ChannelID /*channel*/) const
{
return set_outputs_[index];
}
ValueUsage BasicParams::get_output_usage_impl(const int index) const
ValueUsage BasicParams::get_output_usage_impl(const int index, const ChannelID /*channel*/) const
{
return output_usages_[index];
}
void BasicParams::set_input_unused_impl(const int index)
void BasicParams::set_input_unused_impl(const int index, const ChannelID /*channel*/)
{
input_usages_[index] = ValueUsage::Unused;
}
int BasicParams::get_input_channels_num_impl(const int /*index*/) const
{
return 0;
}
int BasicParams::get_output_channels_num_impl(const int /*index*/) const
{
return 0;
}
IndexRange BasicParams::add_output_channels_impl(const int /*index*/, const int /*num*/)
{
return {};
}
bool BasicParams::try_enable_multi_threading_impl()
{
return true;
@ -92,41 +108,59 @@ RemappedParams::RemappedParams(const LazyFunction &fn,
{
}
void *RemappedParams::try_get_input_data_ptr_impl(const int index) const
void *RemappedParams::try_get_input_data_ptr_impl(const int index,
const ChannelID /*channel*/) const
{
return base_params_.try_get_input_data_ptr(input_map_[index]);
}
void *RemappedParams::try_get_input_data_ptr_or_request_impl(const int index)
void *RemappedParams::try_get_input_data_ptr_or_request_impl(const int index,
const ChannelID /*channel*/)
{
return base_params_.try_get_input_data_ptr_or_request(input_map_[index]);
}
void *RemappedParams::get_output_data_ptr_impl(const int index)
void *RemappedParams::get_output_data_ptr_impl(const int index, const ChannelID /*channel*/)
{
return base_params_.get_output_data_ptr(output_map_[index]);
}
void RemappedParams::output_set_impl(const int index)
void RemappedParams::output_set_impl(const int index, const ChannelID /*channel*/)
{
return base_params_.output_set(output_map_[index]);
}
bool RemappedParams::output_was_set_impl(const int index) const
bool RemappedParams::output_was_set_impl(const int index, const ChannelID /*channel*/) const
{
return base_params_.output_was_set(output_map_[index]);
}
lf::ValueUsage RemappedParams::get_output_usage_impl(const int index) const
lf::ValueUsage RemappedParams::get_output_usage_impl(const int index,
const ChannelID /*channel*/) const
{
return base_params_.get_output_usage(output_map_[index]);
}
void RemappedParams::set_input_unused_impl(const int index)
void RemappedParams::set_input_unused_impl(const int index, const ChannelID /*channel*/)
{
return base_params_.set_input_unused(input_map_[index]);
}
int RemappedParams::get_input_channels_num_impl(const int index) const
{
return base_params_.get_input_channels_num(input_map_[index]);
}
int RemappedParams::get_output_channels_num_impl(const int index) const
{
return base_params_.get_output_channels_num(output_map_[index]);
}
IndexRange RemappedParams::add_output_channels_impl(const int index, const int num)
{
return base_params_.add_output_channels(output_map_[index], num);
}
bool RemappedParams::try_enable_multi_threading_impl()
{
if (multi_threading_enabled_) {

View File

@ -47,6 +47,7 @@
#include "BLI_compute_context.hh"
#include "BLI_enumerable_thread_specific.hh"
#include "BLI_function_ref.hh"
#include "BLI_linear_allocator_chunked_list.hh"
#include "BLI_task.h"
#include "BLI_task.hh"
#include "BLI_timeit.hh"
@ -130,6 +131,14 @@ struct OutputState {
void *value = nullptr;
};
struct DynamicChannelStates {
Vector<Vector<InputState *>> input_states;
Vector<Vector<OutputState *>> output_states;
Vector<Vector<InputState *>> input_states_for_execution;
Vector<Vector<OutputState *>> output_states_for_execution;
};
struct NodeState {
/**
* Needs to be locked when any data in this state is accessed that is not explicitly marked as
@ -144,6 +153,7 @@ struct NodeState {
*/
InputState *inputs;
OutputState *outputs;
/**
* Counts the number of inputs that still have to be provided to this node, until it should run
* again. This is used as an optimization so that nodes are not scheduled unnecessarily in many
@ -184,6 +194,8 @@ struct NodeState {
* Custom storage of the node.
*/
void *storage = nullptr;
DynamicChannelStates *dynamic_channels = nullptr;
};
/**
@ -205,8 +217,8 @@ struct LockedNode {
*
* The notifications will be send right after the node is not locked anymore.
*/
Vector<const OutputSocket *> delayed_required_outputs;
Vector<const OutputSocket *> delayed_unused_outputs;
Vector<std::pair<const OutputSocket *, ChannelID>> delayed_required_outputs;
Vector<std::pair<const OutputSocket *, ChannelID>> delayed_unused_outputs;
LockedNode(const Node &node, NodeState &node_state) : node(node), node_state(node_state) {}
};
@ -486,22 +498,28 @@ class Executor {
if (params_->output_was_set(graph_output_index)) {
continue;
}
const ValueUsage output_usage = params_->get_output_usage(graph_output_index);
if (output_usage == ValueUsage::Maybe) {
continue;
const int channels_num = params_->get_output_channels_num(graph_output_index);
for (int main_or_dynamic_index = -1; main_or_dynamic_index < channels_num;
main_or_dynamic_index++)
{
const ChannelID channel = ChannelID::from_main_or_dynamic_index(main_or_dynamic_index);
const ValueUsage output_usage = params_->get_output_usage(graph_output_index, channel);
if (output_usage == ValueUsage::Maybe) {
continue;
}
const InputSocket &socket = *self_.graph_outputs_[graph_output_index];
const Node &node = socket.node();
NodeState &node_state = *node_states_[node.index_in_graph()];
this->with_locked_node(
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
if (output_usage == ValueUsage::Used) {
this->set_input_required(locked_node, socket, channel);
}
else {
this->set_input_unused(locked_node, socket, channel);
}
});
}
const InputSocket &socket = *self_.graph_outputs_[graph_output_index];
const Node &node = socket.node();
NodeState &node_state = *node_states_[node.index_in_graph()];
this->with_locked_node(
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
if (output_usage == ValueUsage::Used) {
this->set_input_required(locked_node, socket);
}
else {
this->set_input_unused(locked_node, socket);
}
});
}
}
@ -636,25 +654,38 @@ class Executor {
void forward_newly_provided_inputs(CurrentTask &current_task, const LocalData &local_data)
{
for (const int graph_input_index : self_.graph_inputs_.index_range()) {
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
if (was_loaded.load()) {
continue;
const int channels_num = params_->get_input_channels_num(graph_input_index);
for (int main_or_dynamic_index = -1; main_or_dynamic_index < channels_num;
main_or_dynamic_index++)
{
const ChannelID channel = ChannelID::from_main_or_dynamic_index(main_or_dynamic_index);
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
if (channel.is_main()) {
if (was_loaded.load()) {
continue;
}
}
void *input_data = params_->try_get_input_data_ptr(graph_input_index);
if (input_data == nullptr) {
continue;
}
if (channel.is_main()) {
/* TODO: Don't forward dynamic channel value again. */
if (was_loaded.fetch_or(1)) {
/* The value was forwarded before. */
continue;
}
}
this->forward_newly_provided_input(
current_task, local_data, graph_input_index, channel, input_data);
}
void *input_data = params_->try_get_input_data_ptr(graph_input_index);
if (input_data == nullptr) {
continue;
}
if (was_loaded.fetch_or(1)) {
/* The value was forwarded before. */
continue;
}
this->forward_newly_provided_input(current_task, local_data, graph_input_index, input_data);
}
}
void forward_newly_provided_input(CurrentTask &current_task,
const LocalData &local_data,
const int graph_input_index,
const ChannelID channel,
void *input_data)
{
const OutputSocket &socket = *self_.graph_inputs_[graph_input_index];
@ -665,31 +696,40 @@ class Executor {
}
void notify_output_required(const OutputSocket &socket,
const ChannelID channel,
CurrentTask &current_task,
const LocalData &local_data)
{
const Node &node = socket.node();
const int index_in_node = socket.index();
NodeState &node_state = *node_states_[node.index_in_graph()];
OutputState &output_state = node_state.outputs[index_in_node];
OutputState &output_state =
channel.is_main() ?
node_state.outputs[index_in_node] :
*node_state.dynamic_channels->output_states[index_in_node][channel.dynamic_index()];
/* The notified output socket might be an input of the entire graph. In this case, notify the
* caller that the input is required. */
if (node.is_interface()) {
const int graph_input_index = self_.graph_input_index_by_socket_index_[socket.index()];
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
if (was_loaded.load()) {
return;
if (channel.is_main()) {
if (was_loaded.load()) {
return;
}
}
void *input_data = params_->try_get_input_data_ptr_or_request(graph_input_index);
void *input_data = params_->try_get_input_data_ptr_or_request(graph_input_index, channel);
if (input_data == nullptr) {
return;
}
if (was_loaded.fetch_or(1)) {
/* The value was forwarded already. */
return;
if (channel.is_main()) {
if (was_loaded.fetch_or(1)) {
/* The value was forwarded already. */
return;
}
}
this->forward_newly_provided_input(current_task, local_data, graph_input_index, input_data);
this->forward_newly_provided_input(
current_task, local_data, graph_input_index, channel, input_data);
return;
}
@ -705,6 +745,7 @@ class Executor {
}
void notify_output_unused(const OutputSocket &socket,
const ChannelID channel,
CurrentTask &current_task,
const LocalData &local_data)
{
@ -788,21 +829,23 @@ class Executor {
locked_node.delayed_unused_outputs, current_task, local_data);
}
void send_output_required_notifications(const Span<const OutputSocket *> sockets,
CurrentTask &current_task,
const LocalData &local_data)
void send_output_required_notifications(
const Span<std::pair<const OutputSocket *, ChannelID>> sockets,
CurrentTask &current_task,
const LocalData &local_data)
{
for (const OutputSocket *socket : sockets) {
this->notify_output_required(*socket, current_task, local_data);
for (auto &&[socket, channel] : sockets) {
this->notify_output_required(*socket, channel, current_task, local_data);
}
}
void send_output_unused_notifications(const Span<const OutputSocket *> sockets,
CurrentTask &current_task,
const LocalData &local_data)
void send_output_unused_notifications(
const Span<std::pair<const OutputSocket *, ChannelID>> sockets,
CurrentTask &current_task,
const LocalData &local_data)
{
for (const OutputSocket *socket : sockets) {
this->notify_output_unused(*socket, current_task, local_data);
for (auto &&[socket, channel] : sockets) {
this->notify_output_unused(*socket, channel, current_task, local_data);
}
}
@ -865,7 +908,7 @@ class Executor {
if (fn_input.usage == ValueUsage::Used) {
const InputSocket &input_socket = node.input(input_index);
if (input_socket.origin() != nullptr) {
this->set_input_required(locked_node, input_socket);
this->set_input_required(locked_node, input_socket, ChannelID());
}
}
}
@ -873,6 +916,7 @@ class Executor {
node_state.always_used_inputs_requested = true;
}
/* TODO: Handle dynamic channels. */
for (const int input_index : node.inputs().index_range()) {
InputState &input_state = node_state.inputs[input_index];
if (input_state.was_ready_for_execution) {
@ -1044,11 +1088,16 @@ class Executor {
});
}
void set_input_unused(LockedNode &locked_node, const InputSocket &input_socket)
void set_input_unused(LockedNode &locked_node,
const InputSocket &input_socket,
const ChannelID channel)
{
NodeState &node_state = locked_node.node_state;
const int input_index = input_socket.index();
InputState &input_state = node_state.inputs[input_index];
InputState &input_state =
channel.is_main() ?
node_state.inputs[input_index] :
*node_state.dynamic_channels->input_states[input_index][channel.dynamic_index()];
BLI_assert(input_state.usage != ValueUsage::Used);
if (input_state.usage == ValueUsage::Unused) {
@ -1062,13 +1111,14 @@ class Executor {
}
const OutputSocket *origin = input_socket.origin();
if (origin != nullptr) {
locked_node.delayed_unused_outputs.append(origin);
locked_node.delayed_unused_outputs.append({origin, channel});
}
}
void *set_input_required_during_execution(const Node &node,
NodeState &node_state,
const int input_index,
const ChannelID channel,
CurrentTask &current_task,
const LocalData &local_data)
{
@ -1076,17 +1126,22 @@ class Executor {
void *result;
this->with_locked_node(
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
result = this->set_input_required(locked_node, input_socket);
result = this->set_input_required(locked_node, input_socket, channel);
});
return result;
}
void *set_input_required(LockedNode &locked_node, const InputSocket &input_socket)
void *set_input_required(LockedNode &locked_node,
const InputSocket &input_socket,
const ChannelID channel)
{
BLI_assert(&locked_node.node == &input_socket.node());
NodeState &node_state = locked_node.node_state;
const int input_index = input_socket.index();
InputState &input_state = node_state.inputs[input_index];
InputState &input_state =
channel.is_main() ?
node_state.inputs[input_index] :
*node_state.dynamic_channels->input_states[input_index][channel.dynamic_index()];
BLI_assert(input_state.usage != ValueUsage::Unused);
@ -1103,7 +1158,7 @@ class Executor {
const OutputSocket *origin_socket = input_socket.origin();
/* Unlinked inputs are always loaded in advance. */
BLI_assert(origin_socket != nullptr);
locked_node.delayed_required_outputs.append(origin_socket);
locked_node.delayed_required_outputs.append({origin_socket, channel});
return nullptr;
}
@ -1224,8 +1279,8 @@ class Executor {
return true;
}
#ifdef FN_LAZY_FUNCTION_DEBUG_THREADS
/* Only the current main thread is allowed to enabled multi-threading, because the executor is
* still in single-threaded mode. */
/* Only the current main thread is allowed to enabled multi-threading, because the executor
* is still in single-threaded mode. */
if (current_main_thread_ != std::this_thread::get_id()) {
BLI_assert_unreachable();
}
@ -1342,26 +1397,36 @@ class GraphExecutorLFParams final : public Params {
return executor_.get_local_data();
}
void *try_get_input_data_ptr_impl(const int index) const override
const InputState &get_channel_input_state(const int index, const ChannelID channel) const
{
const InputState &input_state = node_state_.inputs[index];
if (channel.is_main()) {
return node_state_.inputs[index];
}
BLI_assert(node_state_.dynamic_channels);
return *node_state_.dynamic_channels
->input_states_for_execution[index][channel.dynamic_index()];
}
void *try_get_input_data_ptr_impl(const int index, const ChannelID channel) const override
{
const InputState &input_state = this->get_channel_input_state(index, channel);
if (input_state.was_ready_for_execution) {
return input_state.value;
}
return nullptr;
}
void *try_get_input_data_ptr_or_request_impl(const int index) override
void *try_get_input_data_ptr_or_request_impl(const int index, const ChannelID channel) override
{
const InputState &input_state = node_state_.inputs[index];
const InputState &input_state = this->get_channel_input_state(index, channel);
if (input_state.was_ready_for_execution) {
return input_state.value;
}
return executor_.set_input_required_during_execution(
node_, node_state_, index, current_task_, this->get_local_data());
node_, node_state_, index, channel, current_task_, this->get_local_data());
}
void *get_output_data_ptr_impl(const int index) override
void *get_output_data_ptr_impl(const int index, const ChannelID /*channel*/) override
{
OutputState &output_state = node_state_.outputs[index];
BLI_assert(!output_state.has_been_computed);
@ -1373,7 +1438,7 @@ class GraphExecutorLFParams final : public Params {
return output_state.value;
}
void output_set_impl(const int index) override
void output_set_impl(const int index, const ChannelID /*channel*/) override
{
OutputState &output_state = node_state_.outputs[index];
BLI_assert(!output_state.has_been_computed);
@ -1387,24 +1452,45 @@ class GraphExecutorLFParams final : public Params {
output_state.has_been_computed = true;
}
bool output_was_set_impl(const int index) const override
bool output_was_set_impl(const int index, const ChannelID /*channel*/) const override
{
const OutputState &output_state = node_state_.outputs[index];
return output_state.has_been_computed;
}
ValueUsage get_output_usage_impl(const int index) const override
ValueUsage get_output_usage_impl(const int index, const ChannelID /*channel*/) const override
{
const OutputState &output_state = node_state_.outputs[index];
return output_state.usage_for_execution;
}
void set_input_unused_impl(const int index) override
void set_input_unused_impl(const int index, const ChannelID /*channel*/) override
{
executor_.set_input_unused_during_execution(
node_, node_state_, index, current_task_, this->get_local_data());
}
int get_input_channels_num_impl(const int index) const override
{
if (!node_state_.dynamic_channels) {
return 0;
}
return node_state_.dynamic_channels->input_states_for_execution[index].size();
}
int get_output_channels_num_impl(const int index) const override
{
if (!node_state_.dynamic_channels) {
return 0;
}
return node_state_.dynamic_channels->output_states_for_execution[index].size();
}
IndexRange add_output_channels_impl(const int /*index*/, const int /*num*/)
{
return {};
}
bool try_enable_multi_threading_impl() override
{
const bool success = executor_.try_enable_multi_threading();
@ -1434,9 +1520,9 @@ inline void Executor::execute_node(const FunctionNode &node,
self_.logger_->log_before_node_execute(node, node_params, fn_context);
}
/* This is run when the execution of the node calls `lazy_threading::send_hint` to indicate that
* the execution will take a while. In this case, other tasks waiting on this thread should be
* allowed to be picked up by another thread. */
/* This is run when the execution of the node calls `lazy_threading::send_hint` to indicate
* that the execution will take a while. In this case, other tasks waiting on this thread
* should be allowed to be picked up by another thread. */
auto blocking_hint_fn = [&]() {
if (!current_task.has_scheduled_nodes.load()) {
return;