WIP: Functions: support dynamically added channels in lazy functions #124299
@ -126,6 +126,43 @@ struct Context {
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelID {
|
||||
private:
|
||||
int index_;
|
||||
|
||||
public:
|
||||
ChannelID() : index_(-1) {}
|
||||
|
||||
explicit ChannelID(const int index) : index_(index)
|
||||
{
|
||||
BLI_assert(index >= 0);
|
||||
}
|
||||
|
||||
static ChannelID from_main_or_dynamic_index(const int index)
|
||||
{
|
||||
BLI_assert(index >= -1);
|
||||
ChannelID channel;
|
||||
channel.index_ = index;
|
||||
return channel;
|
||||
}
|
||||
|
||||
bool is_main() const
|
||||
{
|
||||
return index_ < 0;
|
||||
}
|
||||
|
||||
bool is_dynamic() const
|
||||
{
|
||||
return index_ >= 0;
|
||||
}
|
||||
|
||||
int dynamic_index() const
|
||||
{
|
||||
BLI_assert(this->is_dynamic());
|
||||
return index_;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Defines the calling convention for a lazy-function. During execution, a lazy-function retrieves
|
||||
* its inputs and sets the outputs through #Params.
|
||||
@ -149,13 +186,13 @@ class Params {
|
||||
*
|
||||
* The #LazyFunction must leave returned object in an initialized state, but can move from it.
|
||||
*/
|
||||
void *try_get_input_data_ptr(int index) const;
|
||||
void *try_get_input_data_ptr(int index, ChannelID channel = {}) const;
|
||||
|
||||
/**
|
||||
* Same as #try_get_input_data_ptr, but if the data is not yet available, request it. This makes
|
||||
* sure that the data will be available in a future execution of the #LazyFunction.
|
||||
*/
|
||||
void *try_get_input_data_ptr_or_request(int index);
|
||||
void *try_get_input_data_ptr_or_request(int index, ChannelID channel = {});
|
||||
|
||||
/**
|
||||
* Get a pointer to where the output value should be stored.
|
||||
@ -163,39 +200,44 @@ class Params {
|
||||
* The #LazyFunction is responsible for initializing the value.
|
||||
* After the output has been initialized to its final value, #output_set has to be called.
|
||||
*/
|
||||
void *get_output_data_ptr(int index);
|
||||
void *get_output_data_ptr(int index, ChannelID channel = {});
|
||||
|
||||
/**
|
||||
* Call this after the output value is initialized. After this is called, the value must not be
|
||||
* touched anymore. It may be moved or destructed immediately.
|
||||
*/
|
||||
void output_set(int index);
|
||||
void output_set(int index, ChannelID channel = {});
|
||||
|
||||
/**
|
||||
* Allows the #LazyFunction to check whether an output was computed already without keeping
|
||||
* track of it itself.
|
||||
*/
|
||||
bool output_was_set(int index) const;
|
||||
bool output_was_set(int index, ChannelID channel = {}) const;
|
||||
|
||||
/**
|
||||
* Can be used to detect which outputs have to be computed.
|
||||
*/
|
||||
ValueUsage get_output_usage(int index) const;
|
||||
ValueUsage get_output_usage(int index, ChannelID channel = {}) const;
|
||||
|
||||
/**
|
||||
* Tell the caller of the #LazyFunction that a specific input will definitely not be used.
|
||||
* Only an input that was not #ValueUsage::Used can become unused.
|
||||
*/
|
||||
void set_input_unused(int index);
|
||||
void set_input_unused(int index, ChannelID channel = {});
|
||||
|
||||
int get_input_channels_num(int index) const;
|
||||
int get_output_channels_num(int index) const;
|
||||
IndexRange add_output_channels(int index, int num);
|
||||
|
||||
/**
|
||||
* Typed utility methods that wrap the methods above.
|
||||
*/
|
||||
template<typename T> T extract_input(int index);
|
||||
template<typename T> T &get_input(int index) const;
|
||||
template<typename T> T *try_get_input_data_ptr(int index) const;
|
||||
template<typename T> T *try_get_input_data_ptr_or_request(int index);
|
||||
template<typename T> T extract_input(int index, ChannelID channel = {});
|
||||
template<typename T> T &get_input(int index, ChannelID channel = {}) const;
|
||||
template<typename T> T *try_get_input_data_ptr(int index, ChannelID channel = {}) const;
|
||||
template<typename T> T *try_get_input_data_ptr_or_request(int index, ChannelID channel = {});
|
||||
template<typename T> void set_output(int index, T &&value);
|
||||
template<typename T> void set_output(int index, ChannelID channel, T &&value);
|
||||
|
||||
/**
|
||||
* Returns true when the lazy-function is now allowed to use multi-threading when interacting
|
||||
@ -211,13 +253,18 @@ class Params {
|
||||
* methods above to make it easy to insert additional debugging logic on top of the
|
||||
* implementations.
|
||||
*/
|
||||
virtual void *try_get_input_data_ptr_impl(int index) const = 0;
|
||||
virtual void *try_get_input_data_ptr_or_request_impl(int index) = 0;
|
||||
virtual void *get_output_data_ptr_impl(int index) = 0;
|
||||
virtual void output_set_impl(int index) = 0;
|
||||
virtual bool output_was_set_impl(int index) const = 0;
|
||||
virtual ValueUsage get_output_usage_impl(int index) const = 0;
|
||||
virtual void set_input_unused_impl(int index) = 0;
|
||||
virtual void *try_get_input_data_ptr_impl(int index, ChannelID channel) const = 0;
|
||||
virtual void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) = 0;
|
||||
virtual void *get_output_data_ptr_impl(int index, ChannelID channel) = 0;
|
||||
virtual void output_set_impl(int index, ChannelID channel) = 0;
|
||||
virtual bool output_was_set_impl(int index, ChannelID channel) const = 0;
|
||||
virtual ValueUsage get_output_usage_impl(int index, ChannelID channel) const = 0;
|
||||
virtual void set_input_unused_impl(int index, ChannelID channel) = 0;
|
||||
|
||||
virtual int get_input_channels_num_impl(int index) const = 0;
|
||||
virtual int get_output_channels_num_impl(int index) const = 0;
|
||||
virtual IndexRange add_output_channels_impl(int index, int num) = 0;
|
||||
|
||||
virtual bool try_enable_multi_threading_impl();
|
||||
};
|
||||
|
||||
@ -386,85 +433,115 @@ inline Params::Params(const LazyFunction &fn,
|
||||
{
|
||||
}
|
||||
|
||||
inline void *Params::try_get_input_data_ptr(const int index) const
|
||||
inline void *Params::try_get_input_data_ptr(const int index, const ChannelID channel) const
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.inputs().size());
|
||||
return this->try_get_input_data_ptr_impl(index);
|
||||
return this->try_get_input_data_ptr_impl(index, channel);
|
||||
}
|
||||
|
||||
inline void *Params::try_get_input_data_ptr_or_request(const int index)
|
||||
inline void *Params::try_get_input_data_ptr_or_request(const int index, const ChannelID channel)
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.inputs().size());
|
||||
this->assert_valid_thread();
|
||||
return this->try_get_input_data_ptr_or_request_impl(index);
|
||||
return this->try_get_input_data_ptr_or_request_impl(index, channel);
|
||||
}
|
||||
|
||||
inline void *Params::get_output_data_ptr(const int index)
|
||||
inline void *Params::get_output_data_ptr(const int index, const ChannelID channel)
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.outputs().size());
|
||||
this->assert_valid_thread();
|
||||
return this->get_output_data_ptr_impl(index);
|
||||
return this->get_output_data_ptr_impl(index, channel);
|
||||
}
|
||||
|
||||
inline void Params::output_set(const int index)
|
||||
inline void Params::output_set(const int index, const ChannelID channel)
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.outputs().size());
|
||||
this->assert_valid_thread();
|
||||
this->output_set_impl(index);
|
||||
this->output_set_impl(index, channel);
|
||||
}
|
||||
|
||||
inline bool Params::output_was_set(const int index) const
|
||||
inline bool Params::output_was_set(const int index, const ChannelID channel) const
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.outputs().size());
|
||||
return this->output_was_set_impl(index);
|
||||
return this->output_was_set_impl(index, channel);
|
||||
}
|
||||
|
||||
inline ValueUsage Params::get_output_usage(const int index) const
|
||||
inline ValueUsage Params::get_output_usage(const int index, const ChannelID channel) const
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.outputs().size());
|
||||
return this->get_output_usage_impl(index);
|
||||
return this->get_output_usage_impl(index, channel);
|
||||
}
|
||||
|
||||
inline void Params::set_input_unused(const int index)
|
||||
inline void Params::set_input_unused(const int index, const ChannelID channel)
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.inputs().size());
|
||||
this->assert_valid_thread();
|
||||
this->set_input_unused_impl(index);
|
||||
this->set_input_unused_impl(index, channel);
|
||||
}
|
||||
|
||||
template<typename T> inline T Params::extract_input(const int index)
|
||||
inline int Params::get_input_channels_num(const int index) const
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.inputs().size());
|
||||
this->assert_valid_thread();
|
||||
return this->get_input_channels_num_impl(index);
|
||||
}
|
||||
|
||||
inline int Params::get_output_channels_num(const int index) const
|
||||
{
|
||||
BLI_assert(index >= 0 && index < fn_.outputs().size());
|
||||
this->assert_valid_thread();
|
||||
return this->get_output_channels_num_impl(index);
|
||||
}
|
||||
|
||||
/**
 * Requests `num` additional channels on the output with the given index and forwards the
 * request to #add_output_channels_impl.
 *
 * \param index: Index of the output socket of the lazy-function.
 * \param num: Number of channels to add, must not be negative.
 * \return The range returned by the implementation.
 * NOTE(review): presumably the indices of the newly added dynamic channels -- confirm.
 */
inline IndexRange Params::add_output_channels(const int index, const int num)
{
  BLI_assert(num >= 0);
  /* `index` refers to an output, so it has to be validated against the outputs (and not the
   * inputs) of the function. */
  BLI_assert(index >= 0 && index < fn_.outputs().size());
  this->assert_valid_thread();
  return this->add_output_channels_impl(index, num);
}
|
||||
|
||||
template<typename T> inline T Params::extract_input(const int index, const ChannelID channel)
|
||||
{
|
||||
this->assert_valid_thread();
|
||||
void *data = this->try_get_input_data_ptr(index);
|
||||
void *data = this->try_get_input_data_ptr(index, channel);
|
||||
BLI_assert(data != nullptr);
|
||||
T return_value = std::move(*static_cast<T *>(data));
|
||||
return return_value;
|
||||
}
|
||||
|
||||
template<typename T> inline T &Params::get_input(const int index) const
|
||||
template<typename T> inline T &Params::get_input(const int index, const ChannelID channel) const
|
||||
{
|
||||
void *data = this->try_get_input_data_ptr(index);
|
||||
void *data = this->try_get_input_data_ptr(index, channel);
|
||||
BLI_assert(data != nullptr);
|
||||
return *static_cast<T *>(data);
|
||||
}
|
||||
|
||||
template<typename T> inline T *Params::try_get_input_data_ptr(const int index) const
|
||||
template<typename T>
|
||||
inline T *Params::try_get_input_data_ptr(const int index, const ChannelID channel) const
|
||||
{
|
||||
this->assert_valid_thread();
|
||||
return static_cast<T *>(this->try_get_input_data_ptr(index));
|
||||
return static_cast<T *>(this->try_get_input_data_ptr(index, channel));
|
||||
}
|
||||
|
||||
template<typename T> inline T *Params::try_get_input_data_ptr_or_request(const int index)
|
||||
template<typename T>
|
||||
inline T *Params::try_get_input_data_ptr_or_request(const int index, const ChannelID channel)
|
||||
{
|
||||
this->assert_valid_thread();
|
||||
return static_cast<T *>(this->try_get_input_data_ptr_or_request(index));
|
||||
return static_cast<T *>(this->try_get_input_data_ptr_or_request(index, channel));
|
||||
}
|
||||
|
||||
template<typename T> inline void Params::set_output(const int index, T &&value)
|
||||
{
|
||||
this->set_output(index, ChannelID(), std::forward<T>(value));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void Params::set_output(const int index, const ChannelID channel, T &&value)
|
||||
{
|
||||
using DecayT = std::decay_t<T>;
|
||||
this->assert_valid_thread();
|
||||
void *data = this->get_output_data_ptr(index);
|
||||
void *data = this->get_output_data_ptr(index, channel);
|
||||
new (data) DecayT(std::forward<T>(value));
|
||||
this->output_set(index);
|
||||
}
|
||||
|
@ -36,13 +36,18 @@ class BasicParams : public Params {
|
||||
Span<ValueUsage> output_usages,
|
||||
MutableSpan<bool> set_outputs);
|
||||
|
||||
void *try_get_input_data_ptr_impl(const int index) const override;
|
||||
void *try_get_input_data_ptr_or_request_impl(const int index) override;
|
||||
void *get_output_data_ptr_impl(const int index) override;
|
||||
void output_set_impl(const int index) override;
|
||||
bool output_was_set_impl(const int index) const override;
|
||||
ValueUsage get_output_usage_impl(const int index) const override;
|
||||
void set_input_unused_impl(const int index) override;
|
||||
void *try_get_input_data_ptr_impl(int index, ChannelID channel) const override;
|
||||
void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) override;
|
||||
void *get_output_data_ptr_impl(int index, ChannelID channel) override;
|
||||
void output_set_impl(int index, ChannelID channel) override;
|
||||
bool output_was_set_impl(int index, ChannelID channel) const override;
|
||||
ValueUsage get_output_usage_impl(int index, ChannelID channel) const override;
|
||||
void set_input_unused_impl(int index, ChannelID channel) override;
|
||||
|
||||
int get_input_channels_num_impl(int index) const override;
|
||||
int get_output_channels_num_impl(int index) const override;
|
||||
IndexRange add_output_channels_impl(int index, int num) override;
|
||||
|
||||
bool try_enable_multi_threading_impl() override;
|
||||
};
|
||||
|
||||
@ -64,13 +69,18 @@ class RemappedParams : public Params {
|
||||
Span<int> output_map,
|
||||
bool &multi_threading_enabled);
|
||||
|
||||
void *try_get_input_data_ptr_impl(const int index) const override;
|
||||
void *try_get_input_data_ptr_or_request_impl(const int index) override;
|
||||
void *get_output_data_ptr_impl(const int index) override;
|
||||
void output_set_impl(const int index) override;
|
||||
bool output_was_set_impl(const int index) const override;
|
||||
ValueUsage get_output_usage_impl(const int index) const override;
|
||||
void set_input_unused_impl(const int index) override;
|
||||
void *try_get_input_data_ptr_impl(int index, ChannelID channel) const override;
|
||||
void *try_get_input_data_ptr_or_request_impl(int index, ChannelID channel) override;
|
||||
void *get_output_data_ptr_impl(int index, ChannelID channel) override;
|
||||
void output_set_impl(int index, ChannelID channel) override;
|
||||
bool output_was_set_impl(int index, ChannelID channel) const override;
|
||||
ValueUsage get_output_usage_impl(int index, ChannelID channel) const override;
|
||||
void set_input_unused_impl(int index, ChannelID channel) override;
|
||||
|
||||
int get_input_channels_num_impl(int index) const override;
|
||||
int get_output_channels_num_impl(int index) const override;
|
||||
IndexRange add_output_channels_impl(int index, int num) override;
|
||||
|
||||
bool try_enable_multi_threading_impl() override;
|
||||
};
|
||||
|
||||
|
@ -29,12 +29,13 @@ BasicParams::BasicParams(const LazyFunction &fn,
|
||||
{
|
||||
}
|
||||
|
||||
void *BasicParams::try_get_input_data_ptr_impl(const int index) const
|
||||
void *BasicParams::try_get_input_data_ptr_impl(const int index, const ChannelID /*channel*/) const
|
||||
{
|
||||
return inputs_[index].get();
|
||||
}
|
||||
|
||||
void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index)
|
||||
void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index,
|
||||
const ChannelID /*channel*/)
|
||||
{
|
||||
void *value = inputs_[index].get();
|
||||
if (value == nullptr) {
|
||||
@ -43,31 +44,46 @@ void *BasicParams::try_get_input_data_ptr_or_request_impl(const int index)
|
||||
return value;
|
||||
}
|
||||
|
||||
void *BasicParams::get_output_data_ptr_impl(const int index)
|
||||
void *BasicParams::get_output_data_ptr_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
return outputs_[index].get();
|
||||
}
|
||||
|
||||
void BasicParams::output_set_impl(const int index)
|
||||
void BasicParams::output_set_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
set_outputs_[index] = true;
|
||||
}
|
||||
|
||||
bool BasicParams::output_was_set_impl(const int index) const
|
||||
bool BasicParams::output_was_set_impl(const int index, const ChannelID /*channel*/) const
|
||||
{
|
||||
return set_outputs_[index];
|
||||
}
|
||||
|
||||
ValueUsage BasicParams::get_output_usage_impl(const int index) const
|
||||
ValueUsage BasicParams::get_output_usage_impl(const int index, const ChannelID /*channel*/) const
|
||||
{
|
||||
return output_usages_[index];
|
||||
}
|
||||
|
||||
void BasicParams::set_input_unused_impl(const int index)
|
||||
void BasicParams::set_input_unused_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
input_usages_[index] = ValueUsage::Unused;
|
||||
}
|
||||
|
||||
int BasicParams::get_input_channels_num_impl(const int /*index*/) const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int BasicParams::get_output_channels_num_impl(const int /*index*/) const
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
IndexRange BasicParams::add_output_channels_impl(const int /*index*/, const int /*num*/)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
bool BasicParams::try_enable_multi_threading_impl()
|
||||
{
|
||||
return true;
|
||||
@ -92,41 +108,59 @@ RemappedParams::RemappedParams(const LazyFunction &fn,
|
||||
{
|
||||
}
|
||||
|
||||
void *RemappedParams::try_get_input_data_ptr_impl(const int index) const
|
||||
void *RemappedParams::try_get_input_data_ptr_impl(const int index,
|
||||
const ChannelID /*channel*/) const
|
||||
{
|
||||
return base_params_.try_get_input_data_ptr(input_map_[index]);
|
||||
}
|
||||
|
||||
void *RemappedParams::try_get_input_data_ptr_or_request_impl(const int index)
|
||||
void *RemappedParams::try_get_input_data_ptr_or_request_impl(const int index,
|
||||
const ChannelID /*channel*/)
|
||||
{
|
||||
return base_params_.try_get_input_data_ptr_or_request(input_map_[index]);
|
||||
}
|
||||
|
||||
void *RemappedParams::get_output_data_ptr_impl(const int index)
|
||||
void *RemappedParams::get_output_data_ptr_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
return base_params_.get_output_data_ptr(output_map_[index]);
|
||||
}
|
||||
|
||||
void RemappedParams::output_set_impl(const int index)
|
||||
void RemappedParams::output_set_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
return base_params_.output_set(output_map_[index]);
|
||||
}
|
||||
|
||||
bool RemappedParams::output_was_set_impl(const int index) const
|
||||
bool RemappedParams::output_was_set_impl(const int index, const ChannelID /*channel*/) const
|
||||
{
|
||||
return base_params_.output_was_set(output_map_[index]);
|
||||
}
|
||||
|
||||
lf::ValueUsage RemappedParams::get_output_usage_impl(const int index) const
|
||||
/**
 * Retrieves the usage of the remapped output from the wrapped params. The channel is forwarded
 * unchanged so that the usage of the requested channel is reported.
 */
lf::ValueUsage RemappedParams::get_output_usage_impl(const int index,
                                                     const ChannelID channel) const
{
  return base_params_.get_output_usage(output_map_[index], channel);
}
|
||||
|
||||
void RemappedParams::set_input_unused_impl(const int index)
|
||||
void RemappedParams::set_input_unused_impl(const int index, const ChannelID /*channel*/)
|
||||
{
|
||||
return base_params_.set_input_unused(input_map_[index]);
|
||||
}
|
||||
|
||||
int RemappedParams::get_input_channels_num_impl(const int index) const
|
||||
{
|
||||
return base_params_.get_input_channels_num(input_map_[index]);
|
||||
}
|
||||
|
||||
int RemappedParams::get_output_channels_num_impl(const int index) const
|
||||
{
|
||||
return base_params_.get_output_channels_num(output_map_[index]);
|
||||
}
|
||||
|
||||
IndexRange RemappedParams::add_output_channels_impl(const int index, const int num)
|
||||
{
|
||||
return base_params_.add_output_channels(output_map_[index], num);
|
||||
}
|
||||
|
||||
bool RemappedParams::try_enable_multi_threading_impl()
|
||||
{
|
||||
if (multi_threading_enabled_) {
|
||||
|
@ -47,6 +47,7 @@
|
||||
#include "BLI_compute_context.hh"
|
||||
#include "BLI_enumerable_thread_specific.hh"
|
||||
#include "BLI_function_ref.hh"
|
||||
#include "BLI_linear_allocator_chunked_list.hh"
|
||||
#include "BLI_task.h"
|
||||
#include "BLI_task.hh"
|
||||
#include "BLI_timeit.hh"
|
||||
@ -130,6 +131,14 @@ struct OutputState {
|
||||
void *value = nullptr;
|
||||
};
|
||||
|
||||
struct DynamicChannelStates {
|
||||
Vector<Vector<InputState *>> input_states;
|
||||
Vector<Vector<OutputState *>> output_states;
|
||||
|
||||
Vector<Vector<InputState *>> input_states_for_execution;
|
||||
Vector<Vector<OutputState *>> output_states_for_execution;
|
||||
};
|
||||
|
||||
struct NodeState {
|
||||
/**
|
||||
* Needs to be locked when any data in this state is accessed that is not explicitly marked as
|
||||
@ -144,6 +153,7 @@ struct NodeState {
|
||||
*/
|
||||
InputState *inputs;
|
||||
OutputState *outputs;
|
||||
|
||||
/**
|
||||
* Counts the number of inputs that still have to be provided to this node, until it should run
|
||||
* again. This is used as an optimization so that nodes are not scheduled unnecessarily in many
|
||||
@ -184,6 +194,8 @@ struct NodeState {
|
||||
* Custom storage of the node.
|
||||
*/
|
||||
void *storage = nullptr;
|
||||
|
||||
DynamicChannelStates *dynamic_channels = nullptr;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -205,8 +217,8 @@ struct LockedNode {
|
||||
*
|
||||
* The notifications will be sent right after the node is not locked anymore.
|
||||
*/
|
||||
Vector<const OutputSocket *> delayed_required_outputs;
|
||||
Vector<const OutputSocket *> delayed_unused_outputs;
|
||||
Vector<std::pair<const OutputSocket *, ChannelID>> delayed_required_outputs;
|
||||
Vector<std::pair<const OutputSocket *, ChannelID>> delayed_unused_outputs;
|
||||
|
||||
LockedNode(const Node &node, NodeState &node_state) : node(node), node_state(node_state) {}
|
||||
};
|
||||
@ -486,22 +498,28 @@ class Executor {
|
||||
if (params_->output_was_set(graph_output_index)) {
|
||||
continue;
|
||||
}
|
||||
const ValueUsage output_usage = params_->get_output_usage(graph_output_index);
|
||||
if (output_usage == ValueUsage::Maybe) {
|
||||
continue;
|
||||
const int channels_num = params_->get_output_channels_num(graph_output_index);
|
||||
for (int main_or_dynamic_index = -1; main_or_dynamic_index < channels_num;
|
||||
main_or_dynamic_index++)
|
||||
{
|
||||
const ChannelID channel = ChannelID::from_main_or_dynamic_index(main_or_dynamic_index);
|
||||
const ValueUsage output_usage = params_->get_output_usage(graph_output_index, channel);
|
||||
if (output_usage == ValueUsage::Maybe) {
|
||||
continue;
|
||||
}
|
||||
const InputSocket &socket = *self_.graph_outputs_[graph_output_index];
|
||||
const Node &node = socket.node();
|
||||
NodeState &node_state = *node_states_[node.index_in_graph()];
|
||||
this->with_locked_node(
|
||||
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
|
||||
if (output_usage == ValueUsage::Used) {
|
||||
this->set_input_required(locked_node, socket, channel);
|
||||
}
|
||||
else {
|
||||
this->set_input_unused(locked_node, socket, channel);
|
||||
}
|
||||
});
|
||||
}
|
||||
const InputSocket &socket = *self_.graph_outputs_[graph_output_index];
|
||||
const Node &node = socket.node();
|
||||
NodeState &node_state = *node_states_[node.index_in_graph()];
|
||||
this->with_locked_node(
|
||||
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
|
||||
if (output_usage == ValueUsage::Used) {
|
||||
this->set_input_required(locked_node, socket);
|
||||
}
|
||||
else {
|
||||
this->set_input_unused(locked_node, socket);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -636,25 +654,38 @@ class Executor {
|
||||
void forward_newly_provided_inputs(CurrentTask ¤t_task, const LocalData &local_data)
|
||||
{
|
||||
for (const int graph_input_index : self_.graph_inputs_.index_range()) {
|
||||
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
|
||||
if (was_loaded.load()) {
|
||||
continue;
|
||||
const int channels_num = params_->get_input_channels_num(graph_input_index);
|
||||
for (int main_or_dynamic_index = -1; main_or_dynamic_index < channels_num;
|
||||
main_or_dynamic_index++)
|
||||
{
|
||||
const ChannelID channel = ChannelID::from_main_or_dynamic_index(main_or_dynamic_index);
|
||||
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
|
||||
if (channel.is_main()) {
|
||||
if (was_loaded.load()) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
void *input_data = params_->try_get_input_data_ptr(graph_input_index);
|
||||
if (input_data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (channel.is_main()) {
|
||||
/* TODO: Don't forward dynamic channel value again. */
|
||||
if (was_loaded.fetch_or(1)) {
|
||||
/* The value was forwarded before. */
|
||||
continue;
|
||||
}
|
||||
}
|
||||
this->forward_newly_provided_input(
|
||||
current_task, local_data, graph_input_index, channel, input_data);
|
||||
}
|
||||
void *input_data = params_->try_get_input_data_ptr(graph_input_index);
|
||||
if (input_data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (was_loaded.fetch_or(1)) {
|
||||
/* The value was forwarded before. */
|
||||
continue;
|
||||
}
|
||||
this->forward_newly_provided_input(current_task, local_data, graph_input_index, input_data);
|
||||
}
|
||||
}
|
||||
|
||||
void forward_newly_provided_input(CurrentTask ¤t_task,
|
||||
const LocalData &local_data,
|
||||
const int graph_input_index,
|
||||
const ChannelID channel,
|
||||
void *input_data)
|
||||
{
|
||||
const OutputSocket &socket = *self_.graph_inputs_[graph_input_index];
|
||||
@ -665,31 +696,40 @@ class Executor {
|
||||
}
|
||||
|
||||
void notify_output_required(const OutputSocket &socket,
|
||||
const ChannelID channel,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
{
|
||||
const Node &node = socket.node();
|
||||
const int index_in_node = socket.index();
|
||||
NodeState &node_state = *node_states_[node.index_in_graph()];
|
||||
OutputState &output_state = node_state.outputs[index_in_node];
|
||||
OutputState &output_state =
|
||||
channel.is_main() ?
|
||||
node_state.outputs[index_in_node] :
|
||||
*node_state.dynamic_channels->output_states[index_in_node][channel.dynamic_index()];
|
||||
|
||||
/* The notified output socket might be an input of the entire graph. In this case, notify the
|
||||
* caller that the input is required. */
|
||||
if (node.is_interface()) {
|
||||
const int graph_input_index = self_.graph_input_index_by_socket_index_[socket.index()];
|
||||
std::atomic<uint8_t> &was_loaded = loaded_inputs_[graph_input_index];
|
||||
if (was_loaded.load()) {
|
||||
return;
|
||||
if (channel.is_main()) {
|
||||
if (was_loaded.load()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
void *input_data = params_->try_get_input_data_ptr_or_request(graph_input_index);
|
||||
void *input_data = params_->try_get_input_data_ptr_or_request(graph_input_index, channel);
|
||||
if (input_data == nullptr) {
|
||||
return;
|
||||
}
|
||||
if (was_loaded.fetch_or(1)) {
|
||||
/* The value was forwarded already. */
|
||||
return;
|
||||
if (channel.is_main()) {
|
||||
if (was_loaded.fetch_or(1)) {
|
||||
/* The value was forwarded already. */
|
||||
return;
|
||||
}
|
||||
}
|
||||
this->forward_newly_provided_input(current_task, local_data, graph_input_index, input_data);
|
||||
this->forward_newly_provided_input(
|
||||
current_task, local_data, graph_input_index, channel, input_data);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -705,6 +745,7 @@ class Executor {
|
||||
}
|
||||
|
||||
void notify_output_unused(const OutputSocket &socket,
|
||||
const ChannelID channel,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
{
|
||||
@ -788,21 +829,23 @@ class Executor {
|
||||
locked_node.delayed_unused_outputs, current_task, local_data);
|
||||
}
|
||||
|
||||
void send_output_required_notifications(const Span<const OutputSocket *> sockets,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
void send_output_required_notifications(
|
||||
const Span<std::pair<const OutputSocket *, ChannelID>> sockets,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
{
|
||||
for (const OutputSocket *socket : sockets) {
|
||||
this->notify_output_required(*socket, current_task, local_data);
|
||||
for (auto &&[socket, channel] : sockets) {
|
||||
this->notify_output_required(*socket, channel, current_task, local_data);
|
||||
}
|
||||
}
|
||||
|
||||
void send_output_unused_notifications(const Span<const OutputSocket *> sockets,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
void send_output_unused_notifications(
|
||||
const Span<std::pair<const OutputSocket *, ChannelID>> sockets,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
{
|
||||
for (const OutputSocket *socket : sockets) {
|
||||
this->notify_output_unused(*socket, current_task, local_data);
|
||||
for (auto &&[socket, channel] : sockets) {
|
||||
this->notify_output_unused(*socket, channel, current_task, local_data);
|
||||
}
|
||||
}
|
||||
|
||||
@ -865,7 +908,7 @@ class Executor {
|
||||
if (fn_input.usage == ValueUsage::Used) {
|
||||
const InputSocket &input_socket = node.input(input_index);
|
||||
if (input_socket.origin() != nullptr) {
|
||||
this->set_input_required(locked_node, input_socket);
|
||||
this->set_input_required(locked_node, input_socket, ChannelID());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -873,6 +916,7 @@ class Executor {
|
||||
node_state.always_used_inputs_requested = true;
|
||||
}
|
||||
|
||||
/* TODO: Handle dynamic channels. */
|
||||
for (const int input_index : node.inputs().index_range()) {
|
||||
InputState &input_state = node_state.inputs[input_index];
|
||||
if (input_state.was_ready_for_execution) {
|
||||
@ -1044,11 +1088,16 @@ class Executor {
|
||||
});
|
||||
}
|
||||
|
||||
void set_input_unused(LockedNode &locked_node, const InputSocket &input_socket)
|
||||
void set_input_unused(LockedNode &locked_node,
|
||||
const InputSocket &input_socket,
|
||||
const ChannelID channel)
|
||||
{
|
||||
NodeState &node_state = locked_node.node_state;
|
||||
const int input_index = input_socket.index();
|
||||
InputState &input_state = node_state.inputs[input_index];
|
||||
InputState &input_state =
|
||||
channel.is_main() ?
|
||||
node_state.inputs[input_index] :
|
||||
*node_state.dynamic_channels->input_states[input_index][channel.dynamic_index()];
|
||||
|
||||
BLI_assert(input_state.usage != ValueUsage::Used);
|
||||
if (input_state.usage == ValueUsage::Unused) {
|
||||
@ -1062,13 +1111,14 @@ class Executor {
|
||||
}
|
||||
const OutputSocket *origin = input_socket.origin();
|
||||
if (origin != nullptr) {
|
||||
locked_node.delayed_unused_outputs.append(origin);
|
||||
locked_node.delayed_unused_outputs.append({origin, channel});
|
||||
}
|
||||
}
|
||||
|
||||
void *set_input_required_during_execution(const Node &node,
|
||||
NodeState &node_state,
|
||||
const int input_index,
|
||||
const ChannelID channel,
|
||||
CurrentTask ¤t_task,
|
||||
const LocalData &local_data)
|
||||
{
|
||||
@ -1076,17 +1126,22 @@ class Executor {
|
||||
void *result;
|
||||
this->with_locked_node(
|
||||
node, node_state, current_task, local_data, [&](LockedNode &locked_node) {
|
||||
result = this->set_input_required(locked_node, input_socket);
|
||||
result = this->set_input_required(locked_node, input_socket, channel);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
void *set_input_required(LockedNode &locked_node, const InputSocket &input_socket)
|
||||
void *set_input_required(LockedNode &locked_node,
|
||||
const InputSocket &input_socket,
|
||||
const ChannelID channel)
|
||||
{
|
||||
BLI_assert(&locked_node.node == &input_socket.node());
|
||||
NodeState &node_state = locked_node.node_state;
|
||||
const int input_index = input_socket.index();
|
||||
InputState &input_state = node_state.inputs[input_index];
|
||||
InputState &input_state =
|
||||
channel.is_main() ?
|
||||
node_state.inputs[input_index] :
|
||||
*node_state.dynamic_channels->input_states[input_index][channel.dynamic_index()];
|
||||
|
||||
BLI_assert(input_state.usage != ValueUsage::Unused);
|
||||
|
||||
@ -1103,7 +1158,7 @@ class Executor {
|
||||
const OutputSocket *origin_socket = input_socket.origin();
|
||||
/* Unlinked inputs are always loaded in advance. */
|
||||
BLI_assert(origin_socket != nullptr);
|
||||
locked_node.delayed_required_outputs.append(origin_socket);
|
||||
locked_node.delayed_required_outputs.append({origin_socket, channel});
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@ -1224,8 +1279,8 @@ class Executor {
|
||||
return true;
|
||||
}
|
||||
#ifdef FN_LAZY_FUNCTION_DEBUG_THREADS
|
||||
/* Only the current main thread is allowed to enable multi-threading, because the executor is
|
||||
* still in single-threaded mode. */
|
||||
/* Only the current main thread is allowed to enable multi-threading, because the executor
|
||||
* is still in single-threaded mode. */
|
||||
if (current_main_thread_ != std::this_thread::get_id()) {
|
||||
BLI_assert_unreachable();
|
||||
}
|
||||
@ -1342,26 +1397,36 @@ class GraphExecutorLFParams final : public Params {
|
||||
return executor_.get_local_data();
|
||||
}
|
||||
|
||||
void *try_get_input_data_ptr_impl(const int index) const override
|
||||
/**
 * Look up the input state that belongs to input `index` on the given channel.
 *
 * For the main channel this is `node_state_.inputs[index]`. Otherwise the state is read from
 * `node_state_.dynamic_channels->input_states_for_execution[index]` at
 * `channel.dynamic_index()`; `dynamic_channels` is asserted to be set in that case.
 *
 * NOTE(review): this span comes from a rendered diff. The `input_state` declaration right after
 * the opening brace looks like a removed line of the previous code rather than part of this
 * helper -- confirm against the actual patch.
 */
const InputState &get_channel_input_state(const int index, const ChannelID channel) const
|
||||
{
|
||||
const InputState &input_state = node_state_.inputs[index];
|
||||
if (channel.is_main()) {
|
||||
return node_state_.inputs[index];
|
||||
}
|
||||
BLI_assert(node_state_.dynamic_channels);
|
||||
return *node_state_.dynamic_channels
|
||||
->input_states_for_execution[index][channel.dynamic_index()];
|
||||
}
|
||||
|
||||
/**
 * Return the value pointer of input `index` on `channel` if that input state was marked as
 * ready for the current execution (`was_ready_for_execution`), otherwise `nullptr`.
 * The state is resolved through #get_channel_input_state; nothing is requested here.
 */
void *try_get_input_data_ptr_impl(const int index, const ChannelID channel) const override
|
||||
{
|
||||
const InputState &input_state = this->get_channel_input_state(index, channel);
|
||||
if (input_state.was_ready_for_execution) {
|
||||
return input_state.value;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/**
 * Return the value pointer of input `index` if its state was ready for the current execution.
 * Otherwise the call is forwarded to `executor_.set_input_required_during_execution` and that
 * function's result is returned.
 *
 * NOTE(review): this span comes from a rendered diff, so pre- and post-change rows appear side
 * by side (two signatures, two `input_state` declarations, two argument lists). The rows that
 * mention `channel` appear to be the new version -- confirm against the actual patch.
 */
void *try_get_input_data_ptr_or_request_impl(const int index) override
|
||||
void *try_get_input_data_ptr_or_request_impl(const int index, const ChannelID channel) override
|
||||
{
|
||||
const InputState &input_state = node_state_.inputs[index];
|
||||
const InputState &input_state = this->get_channel_input_state(index, channel);
|
||||
if (input_state.was_ready_for_execution) {
|
||||
return input_state.value;
|
||||
}
|
||||
return executor_.set_input_required_during_execution(
|
||||
node_, node_state_, index, current_task_, this->get_local_data());
|
||||
node_, node_state_, index, channel, current_task_, this->get_local_data());
|
||||
}
|
||||
|
||||
void *get_output_data_ptr_impl(const int index) override
|
||||
void *get_output_data_ptr_impl(const int index, const ChannelID /*channel*/) override
|
||||
{
|
||||
OutputState &output_state = node_state_.outputs[index];
|
||||
BLI_assert(!output_state.has_been_computed);
|
||||
@ -1373,7 +1438,7 @@ class GraphExecutorLFParams final : public Params {
|
||||
return output_state.value;
|
||||
}
|
||||
|
||||
void output_set_impl(const int index) override
|
||||
void output_set_impl(const int index, const ChannelID /*channel*/) override
|
||||
{
|
||||
OutputState &output_state = node_state_.outputs[index];
|
||||
BLI_assert(!output_state.has_been_computed);
|
||||
@ -1387,24 +1452,45 @@ class GraphExecutorLFParams final : public Params {
|
||||
output_state.has_been_computed = true;
|
||||
}
|
||||
|
||||
/**
 * Report whether output `index` has been computed, as recorded in
 * `node_state_.outputs[index].has_been_computed`.
 *
 * NOTE(review): rendered diff -- the first signature row appears to be the pre-change
 * declaration, the second its replacement. In the second one the channel parameter is unnamed
 * and unused, so only the main output state is consulted; confirm that is intended.
 */
bool output_was_set_impl(const int index) const override
|
||||
bool output_was_set_impl(const int index, const ChannelID /*channel*/) const override
|
||||
{
|
||||
const OutputState &output_state = node_state_.outputs[index];
|
||||
return output_state.has_been_computed;
|
||||
}
|
||||
|
||||
/**
 * Return `usage_for_execution` of the output state stored at `node_state_.outputs[index]`.
 *
 * NOTE(review): rendered diff -- the first signature row appears to be the pre-change
 * declaration, the second its replacement. In the second one the channel parameter is unnamed
 * and unused, so only the main output state is consulted; confirm that is intended.
 */
ValueUsage get_output_usage_impl(const int index) const override
|
||||
ValueUsage get_output_usage_impl(const int index, const ChannelID /*channel*/) const override
|
||||
{
|
||||
const OutputState &output_state = node_state_.outputs[index];
|
||||
return output_state.usage_for_execution;
|
||||
}
|
||||
|
||||
/**
 * Forward the "input is unused" notification for input `index` to
 * `executor_.set_input_unused_during_execution`, together with the node, its state, the current
 * task and the local data.
 *
 * NOTE(review): rendered diff -- the first signature row appears to be the pre-change
 * declaration, the second its replacement. The channel parameter is unnamed and is not passed
 * on, while the request path (`set_input_required_during_execution`) does receive a channel.
 * Whether `set_input_unused_during_execution` expects one is not visible here -- confirm the
 * channel should not be forwarded.
 */
void set_input_unused_impl(const int index) override
|
||||
void set_input_unused_impl(const int index, const ChannelID /*channel*/) override
|
||||
{
|
||||
executor_.set_input_unused_during_execution(
|
||||
node_, node_state_, index, current_task_, this->get_local_data());
|
||||
}
|
||||
|
||||
/**
 * Number of dynamic channel entries stored for input `index`.
 * Returns 0 when `node_state_.dynamic_channels` is not set, otherwise the size of
 * `input_states_for_execution[index]`.
 */
int get_input_channels_num_impl(const int index) const override
|
||||
{
|
||||
if (!node_state_.dynamic_channels) {
|
||||
return 0;
|
||||
}
|
||||
return node_state_.dynamic_channels->input_states_for_execution[index].size();
|
||||
}
|
||||
|
||||
/**
 * Number of dynamic channel entries stored for output `index`.
 * Returns 0 when `node_state_.dynamic_channels` is not set, otherwise the size of
 * `output_states_for_execution[index]`.
 */
int get_output_channels_num_impl(const int index) const override
|
||||
{
|
||||
if (!node_state_.dynamic_channels) {
|
||||
return 0;
|
||||
}
|
||||
return node_state_.dynamic_channels->output_states_for_execution[index].size();
|
||||
}
|
||||
|
||||
/**
 * Currently ignores both the output index and the requested channel count and returns an empty
 * (default-constructed) range.
 *
 * NOTE(review): presumably a placeholder in this work-in-progress patch. Unlike the neighbouring
 * `*_impl` methods this one is not marked `override` -- confirm whether it is meant to override
 * a virtual in the base class.
 */
IndexRange add_output_channels_impl(const int /*index*/, const int /*num*/)
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
bool try_enable_multi_threading_impl() override
|
||||
{
|
||||
const bool success = executor_.try_enable_multi_threading();
|
||||
@ -1434,9 +1520,9 @@ inline void Executor::execute_node(const FunctionNode &node,
|
||||
self_.logger_->log_before_node_execute(node, node_params, fn_context);
|
||||
}
|
||||
|
||||
/* This is run when the execution of the node calls `lazy_threading::send_hint` to indicate that
|
||||
* the execution will take a while. In this case, other tasks waiting on this thread should be
|
||||
* allowed to be picked up by another thread. */
|
||||
/* This is run when the execution of the node calls `lazy_threading::send_hint` to indicate
|
||||
* that the execution will take a while. In this case, other tasks waiting on this thread
|
||||
* should be allowed to be picked up by another thread. */
|
||||
auto blocking_hint_fn = [&]() {
|
||||
if (!current_task.has_scheduled_nodes.load()) {
|
||||
return;
|
||||
|
Loading…
Reference in New Issue
Block a user