WIP: Field type inferencing using a constraint solver method #120420

Draft
Lukas Tönne wants to merge 52 commits from LukasTonne/blender:socket-type-inference into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
1 changed files with 352 additions and 360 deletions
Showing only changes of commit b3128a83e7 - Show all commits

View File

@ -709,15 +709,217 @@ static void prepare_inferencing_interfaces(
}
}
enum DomainValue { Single, Field, NumDomainValues };
enum DomainValue {
/* Socket is a single value. */
Single,
/* Socket is a field. */
Field,
/* Keep last. */
NumDomainValues
};
namespace csp = constraint_satisfaction;
static void add_node_type_constraints(const bNodeTree &tree,
const bNode &node,
csp::ConstraintSet &constraints)
struct NodeTreeVariables {
/* Index ranges of variables for node sockets and interface sockets.
* Group input/output nodes use the tree interface variables directly,
* their socket variables are unused. */
IndexRange socket_vars;
IndexRange tree_input_vars;
IndexRange tree_output_vars;
NodeTreeVariables(const bNodeTree &tree)
{
this->socket_vars = tree.all_sockets().index_range();
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
}
int num_vars() const
{
return tree_output_vars.one_after_last();
}
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
{
if (socket.is_output()) {
/* Use tree variables directly for group input nodes (except the extension socket). */
if (socket.owner_node().is_group_input() &&
tree_input_vars.index_range().contains(socket.index()))
{
return tree_input_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
else {
/* Use tree variables directly for group output nodes (except the extension socket). */
if (socket.owner_node().is_group_output() &&
tree_output_vars.index_range().contains(socket.index()))
{
return tree_output_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
}
};
static std::string input_field_type_name(const InputSocketFieldType type)
{
tree.ensure_topology_cache();
switch (type) {
case InputSocketFieldType::None:
return "None";
case InputSocketFieldType::IsSupported:
return "IsSupported";
case InputSocketFieldType::Implicit:
return "Implicit";
}
return "";
}
static std::string output_field_type_name(const OutputSocketFieldType type)
{
switch (type) {
case OutputSocketFieldType::None:
return "None";
case OutputSocketFieldType::FieldSource:
return "FieldSource";
case OutputSocketFieldType::DependentField:
return "DependentField";
case OutputSocketFieldType::PartiallyDependent:
return "PartiallyDependent";
}
return "";
}
static bool verify_field_inferencing_csp_result(
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
const BitGroupVector<> csp_result,
const FieldInferencingInterface &inferencing_interface)
{
/* Use the old propagation method to provide "ground truth" to compare against. */
const bool use_propagation_result = true;
NodeTreeVariables variables(tree);
Array<SocketFieldState> field_state_by_socket_id;
std::unique_ptr<FieldInferencingInterface> tmp_inferencing_interface;
if (use_propagation_result) {
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
field_state_by_socket_id.reinitialize(tree.all_sockets().size());
/* Temp local inferencing interface to avoid overwriting the actual interface.
* The propagation method directly writes to the interface. */
tmp_inferencing_interface = std::make_unique<FieldInferencingInterface>();
tmp_inferencing_interface->inputs.resize(tree.interface_inputs().size(),
InputSocketFieldType::IsSupported);
tmp_inferencing_interface->outputs.resize(tree.interface_outputs().size(),
OutputFieldDependency::ForDataSource());
propagate_data_requirements_from_right_to_left(
tree, interface_by_node, field_state_by_socket_id);
determine_group_input_states(tree, *tmp_inferencing_interface, field_state_by_socket_id);
propagate_field_status_from_left_to_right(tree, interface_by_node, field_state_by_socket_id);
determine_group_output_states(
tree, *tmp_inferencing_interface, interface_by_node, field_state_by_socket_id);
}
std::cout << "Verify field type inferencing for tree " << tree.id.name << std::endl;
bool error = false;
for (const bNodeSocket *socket : tree.all_sockets()) {
if (!socket->is_available()) {
continue;
}
auto log_error = [&](StringRef message) {
const std::string socket_address = std::string(socket->owner_node().name) +
(socket->is_output() ? "|>" : "<|") + socket->identifier;
std::cout << " [Error] " << socket_address << ": " << message << std::endl;
error = true;
};
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
const int num_values = int(state[DomainValue::Single]) + int(state[DomainValue::Field]);
if (num_values == 0)
{
log_error("No valid result");
continue;
}
if (use_propagation_result) {
const SocketFieldState &old_state = field_state_by_socket_id[socket->index_in_tree()];
if (old_state.is_always_single) {
if (!state[DomainValue::Single] || state[DomainValue::Field])
{
log_error("Should only be single value");
}
continue;
}
if (!old_state.is_single) {
if (state[DomainValue::Single] || !state[DomainValue::Field])
{
log_error("Should only be field");
}
continue;
}
if (old_state.requires_single) {
if (!state[DomainValue::Single] || state[DomainValue::Field])
{
log_error("Should only be single value");
}
continue;
}
if (!state[DomainValue::Single] || !state[DomainValue::Field])
{
log_error("Should be single value or field");
}
}
}
if (use_propagation_result) {
for (const int i : tree.interface_inputs().index_range()) {
auto log_error = [&](StringRef message) {
std::cout << " [Error] " << tree.interface_inputs()[i]->identifier << ": " << message
<< std::endl;
error = true;
};
const InputSocketFieldType old_field_type = tmp_inferencing_interface->inputs[i];
const InputSocketFieldType new_field_type = inferencing_interface.inputs[i];
if (old_field_type != new_field_type) {
log_error("Input field type is " + input_field_type_name(new_field_type) + ", expected " +
input_field_type_name(old_field_type));
}
}
for (const int i : tree.interface_outputs().index_range()) {
auto log_error = [&](StringRef message) {
std::cout << " [Error] " << tree.interface_outputs()[i]->identifier << ": " << message
<< std::endl;
error = true;
};
const OutputFieldDependency &old_field_dep = tmp_inferencing_interface->outputs[i];
const OutputFieldDependency &new_field_dep = inferencing_interface.outputs[i];
if (old_field_dep.field_type() != new_field_dep.field_type()) {
log_error("Output field type is " + output_field_type_name(new_field_dep.field_type()) +
", expected " + output_field_type_name(old_field_dep.field_type()));
}
if (old_field_dep.linked_input_indices() != new_field_dep.linked_input_indices()) {
log_error("Output field dependencies don't match");
}
}
}
if (!error) {
std::cout << " OK!" << std::endl;
}
return error;
}
static void add_node_constraints(const bNodeTree &tree,
const bNode &node,
const FieldInferencingInterface &inferencing_interface,
csp::ConstraintSet &constraints)
{
NodeTreeVariables variables(tree);
/* Constraint is satisfied if both inputs or outputs of a zone node pair are the same type. */
auto shared_field_type_constraint = [](const int value_a, const int value_b) {
@ -748,6 +950,121 @@ static void add_node_type_constraints(const bNodeTree &tree,
}
};
for (const bNodeSocket *output_socket : node.output_sockets()) {
if (!output_socket->is_available()) {
continue;
}
const int var = variables.get_socket_variable(*output_socket);
const bNodeSocketType *typeinfo = output_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) : SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
}
const OutputFieldDependency &field_dependency =
inferencing_interface.outputs[output_socket->index()];
switch (field_dependency.field_type()) {
/* Fixed single value output. */
case OutputSocketFieldType::None:
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
break;
/* Fixed field source output. */
case OutputSocketFieldType::FieldSource:
constraints.add_unary(var, [](const int value) { return value == DomainValue::Field; });
break;
/* Internal dependency on one or more inputs. */
case OutputSocketFieldType::DependentField:
case OutputSocketFieldType::PartiallyDependent:
for (const bNodeSocket *input_socket :
gather_input_socket_dependencies(field_dependency, node))
{
if (!input_socket->is_available()) {
continue;
}
const int input_var = variables.get_socket_variable(*input_socket);
/* The output must be a field if the input it depends on is a field. */
constraints.add_binary(var, input_var, [](const int value_dst, const int value_src) {
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
});
}
break;
}
/* The output must be a single value when it is connected to any input that does
* not support fields. */
for (const bNodeSocket *src_socket : output_socket->directly_linked_sockets()) {
if (src_socket->is_available()) {
const int src_var = variables.get_socket_variable(*src_socket);
constraints.add_binary(var, src_var, [](int value_dst, int value_src) {
return value_dst == DomainValue::Single || value_src == DomainValue::Field;
});
}
}
// if (state.requires_single) {
// bool any_input_is_field_implicitly = false;
// const Vector<const bNodeSocket *> connected_inputs = gather_input_socket_dependencies(
// field_dependency, *node);
// for (const bNodeSocket *input_socket : connected_inputs) {
// if (!input_socket->is_available()) {
// continue;
// }
// if (inferencing_interface.inputs[input_socket->index()] ==
// InputSocketFieldType::Implicit)
// {
// if (!input_socket->is_logically_linked()) {
// any_input_is_field_implicitly = true;
// break;
// }
// }
// }
// if (any_input_is_field_implicitly) {
// /* This output isn't a single value actually. */
// state.requires_single = false;
// }
// else {
// /* If the output is required to be a single value, the connected inputs in the same
// * node must not be fields as well. */
// for (const bNodeSocket *input_socket : connected_inputs) {
// field_state_by_socket_id[input_socket->index_in_tree()].requires_single = true;
// }
// }
// }
}
/* Some inputs do not require fields independent of what the outputs are connected to. */
for (const bNodeSocket *input_socket : node.input_sockets()) {
if (!input_socket->is_available()) {
continue;
}
const int var = variables.get_socket_variable(*input_socket);
const bNodeSocketType *typeinfo = input_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) : SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
}
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
if (field_type == InputSocketFieldType::None) {
constraints.add_unary(input_socket->index_in_tree(),
[](int value) { return value == DomainValue::Single; });
}
/* The input must be a field value when it is connected to any output that can't be a single
* value. */
for (const bNodeSocket *src_socket : input_socket->directly_linked_sockets()) {
if (src_socket->is_available()) {
const int src_var = variables.get_socket_variable(*src_socket);
constraints.add_binary(var, src_var, [](int value_dst, int value_src) {
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
});
}
}
}
/* Special constraints for certain node types. */
switch (node.type) {
case GEO_NODE_SIMULATION_INPUT: {
const NodeGeometrySimulationInput &data = *static_cast<const NodeGeometrySimulationInput *>(
@ -763,17 +1080,6 @@ static void add_node_type_constraints(const bNodeTree &tree,
}
case GEO_NODE_SIMULATION_OUTPUT: {
/* Already handled in the input node case. */
// for (const bNode *input_node : tree.nodes_by_type("GeometryNodeSimulationInput")) {
// const NodeGeometrySimulationInput &data =
// *static_cast<const NodeGeometrySimulationInput *>(input_node->storage);
// if (node.identifier == data.output_node_id) {
// /* First input node output is Delta Time which does not appear in the output node. */
// add_zone_constraints(input_node->input_sockets(),
// input_node->output_sockets().drop_front(1),
// node.input_sockets(),
// node.output_sockets());
// }
// }
break;
}
case GEO_NODE_REPEAT_INPUT: {
@ -789,21 +1095,34 @@ static void add_node_type_constraints(const bNodeTree &tree,
}
case GEO_NODE_REPEAT_OUTPUT: {
/* Already handled in the input node case. */
// for (const bNode *input_node : tree.nodes_by_type("GeometryNodeRepeatInput")) {
// const NodeGeometryRepeatInput &data = *static_cast<const NodeGeometryRepeatInput *>(
// input_node->storage);
// if (node.identifier == data.output_node_id) {
// add_zone_constraints(input_node->input_sockets(),
// input_node->output_sockets(),
// node.input_sockets(),
// node.output_sockets());
// }
// }
break;
}
}
}
static void update_socket_shapes(const bNodeTree &tree,
const NodeTreeVariables &variables,
const BitGroupVector<> csp_result)
{
auto get_shape_for_state = [](const BitSpan state) {
return state[DomainValue::Field] ?
(state[DomainValue::Single] ? SOCK_DISPLAY_SHAPE_DIAMOND_DOT :
SOCK_DISPLAY_SHAPE_DIAMOND) :
SOCK_DISPLAY_SHAPE_CIRCLE;
};
for (const bNodeSocket *socket : tree.all_input_sockets()) {
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
}
for (const bNodeSocket *socket : tree.all_sockets()) {
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
}
}
/**
* Check what the group output socket depends on. Potentially traverses the node tree
* to figure out if it is always a field or if it depends on any group inputs.
@ -838,7 +1157,6 @@ static OutputFieldDependency find_group_output_dependencies(
/* This socket uses a field as input by default. */
return OutputFieldDependency::ForFieldSource();
}
return OutputFieldDependency::ForFieldSource();
for (const bNodeSocket *origin_socket : input_socket->directly_linked_sockets()) {
const bNode &origin_node = origin_socket->owner_node();
@ -884,205 +1202,6 @@ static OutputFieldDependency find_group_output_dependencies(
return OutputFieldDependency::ForPartiallyDependentField(std::move(linked_input_indices));
}
struct NodeTreeVariables {
/* Index ranges of variables for node sockets and interface sockets.
* Group input/output nodes use the tree interface variables directly,
* their socket variables are unused. */
IndexRange socket_vars;
IndexRange tree_input_vars;
IndexRange tree_output_vars;
NodeTreeVariables(const bNodeTree &tree)
{
this->socket_vars = tree.all_sockets().index_range();
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
}
int num_vars() const
{
return tree_output_vars.one_after_last();
}
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
{
if (socket.is_output()) {
/* Use tree variables directly for group input nodes (except the extension socket). */
if (socket.owner_node().is_group_input() &&
tree_input_vars.index_range().contains(socket.index()))
{
return tree_input_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
else {
/* Use tree variables directly for group output nodes (except the extension socket). */
if (socket.owner_node().is_group_output() &&
tree_output_vars.index_range().contains(socket.index()))
{
return tree_output_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
}
};
/* Verify inferencing result by comparing to the old propagation method. */
static bool verify_field_inferencing_csp_result(
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
const BitGroupVector<> csp_result)
{
NodeTreeVariables variables(tree);
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
Array<SocketFieldState> field_state_by_socket_id(tree.all_sockets().size());
/* Temp local inferencing interface to avoid overwriting the actual interface.
* The propagation method directly writes to the interface. */
std::unique_ptr<FieldInferencingInterface> tmp_inferencing_interface =
std::make_unique<FieldInferencingInterface>();
tmp_inferencing_interface->inputs.resize(tree.interface_inputs().size(),
InputSocketFieldType::IsSupported);
tmp_inferencing_interface->outputs.resize(tree.interface_outputs().size(),
OutputFieldDependency::ForDataSource());
propagate_data_requirements_from_right_to_left(
tree, interface_by_node, field_state_by_socket_id);
determine_group_input_states(tree, *tmp_inferencing_interface, field_state_by_socket_id);
propagate_field_status_from_left_to_right(tree, interface_by_node, field_state_by_socket_id);
determine_group_output_states(
tree, *tmp_inferencing_interface, interface_by_node, field_state_by_socket_id);
std::cout << "Verify field type inferencing for tree " << tree.id.name << std::endl;
bool error = false;
for (const bNodeSocket *socket : tree.all_sockets()) {
if (!socket->is_available()) {
continue;
}
auto log_error = [&](StringRef message) {
const std::string socket_address = std::string(socket->owner_node().name) +
(socket->is_output() ? "|>" : "<|") + socket->identifier;
std::cout << " [Error] " << socket_address << ": " << message << std::endl;
error = true;
};
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("Neither single value nor field");
continue;
}
const SocketFieldState &old_state = field_state_by_socket_id[socket->index_in_tree()];
if (old_state.is_always_single) {
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
log_error("Should only be single value");
}
continue;
}
if (!old_state.is_single) {
if (state[DomainValue::Single] || !state[DomainValue::Field]) {
log_error("Should only be field");
}
continue;
}
if (old_state.requires_single) {
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
log_error("Should only be single value");
}
continue;
}
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
log_error("Should be both single value and field");
}
}
for (const int i : tree.interface_inputs().index_range()) {
auto log_error = [&](StringRef message) {
std::cout << " [Error] " << tree.interface_inputs()[i]->identifier << ": " << message << std::endl;
error = true;
};
const int var = variables.tree_input_vars[i];
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("Neither single value nor field");
continue;
}
const InputSocketFieldType &old_state = tmp_inferencing_interface->inputs[i];
switch (old_state) {
case InputSocketFieldType::None:
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
log_error("Should only be single value");
}
break;
case InputSocketFieldType::IsSupported:
case InputSocketFieldType::Implicit:
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
log_error("Should be both single value and field");
}
break;
}
}
for (const int i : tree.interface_outputs().index_range()) {
auto log_error = [&](StringRef message) {
std::cout << " [Error] " << tree.interface_outputs()[i]->identifier << ": " << message
<< std::endl;
error = true;
};
const int var = variables.tree_output_vars[i];
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("Neither single value nor field");
continue;
}
const OutputFieldDependency &old_state = tmp_inferencing_interface->outputs[i];
switch (old_state.field_type()) {
case OutputSocketFieldType::None:
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
log_error("Should only be single value");
}
break;
case OutputSocketFieldType::FieldSource:
if (state[DomainValue::Single] || !state[DomainValue::Field]) {
log_error("Should only be field");
}
break;
case OutputSocketFieldType::DependentField:
case OutputSocketFieldType::PartiallyDependent:
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
log_error("Should be both single value and field");
}
break;
}
}
if (!error) {
std::cout << " OK!" << std::endl;
}
return error;
}
static void update_socket_shapes(const bNodeTree &tree,
const NodeTreeVariables &variables,
const BitGroupVector<> csp_result)
{
auto get_shape_for_state = [](const BitSpan state) {
return state[DomainValue::Field] ?
(state[DomainValue::Single] ? SOCK_DISPLAY_SHAPE_DIAMOND_DOT :
SOCK_DISPLAY_SHAPE_DIAMOND) :
SOCK_DISPLAY_SHAPE_CIRCLE;
};
for (const bNodeSocket *socket : tree.all_input_sockets()) {
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
}
for (const bNodeSocket *socket : tree.all_sockets()) {
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
}
}
template<typename Logger>
static void solve_field_inferencing_constraints(
const bNodeTree &tree,
@ -1097,7 +1216,8 @@ static void solve_field_inferencing_constraints(
tree.ensure_topology_cache();
NodeTreeVariables variables(tree);
auto variable_name = [&](const csp::VariableIndex var) -> std::string {
logger.declare_variables(variables.num_vars(), [&](const csp::VariableIndex var) -> std::string {
if (variables.socket_vars.contains(var)) {
const bNodeSocket &socket = *tree.all_sockets()[var - variables.socket_vars.start()];
const bNode &node = socket.owner_node();
@ -1115,150 +1235,19 @@ static void solve_field_inferencing_constraints(
return std::string("O:") + iosocket.identifier;
}
return "";
};
logger.declare_variables(variables.num_vars(), variable_name);
});
const Span<const bNode *> nodes = tree.toposort_right_to_left();
csp::ConstraintSet constraints;
for (const bNode *node : nodes) {
/* Special case: Group inputs and outputs use the interface variables directly. */
const IndexRange interface_inputs = node->is_group_input() ?
tree.interface_inputs().index_range() :
IndexRange();
const IndexRange interface_outputs = node->is_group_output() ?
tree.interface_outputs().index_range() :
IndexRange();
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
for (const bNodeSocket *output_socket : node->output_sockets()) {
if (!output_socket->is_available()) {
continue;
}
const int var = variables.get_socket_variable(*output_socket);
const bNodeSocketType *typeinfo = output_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
}
const OutputFieldDependency &field_dependency =
inferencing_interface.outputs[output_socket->index()];
switch (field_dependency.field_type()) {
/* Fixed single value output. */
case OutputSocketFieldType::None:
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
break;
/* Fixed field source output. */
case OutputSocketFieldType::FieldSource:
constraints.add_unary(var, [](const int value) { return value == DomainValue::Field; });
break;
/* Internal dependency on one or more inputs. */
case OutputSocketFieldType::DependentField:
case OutputSocketFieldType::PartiallyDependent:
for (const bNodeSocket *input_socket :
gather_input_socket_dependencies(field_dependency, *node))
{
if (!input_socket->is_available()) {
continue;
}
const int input_var = variables.get_socket_variable(*input_socket);
/* The output must be a field if the input it depends on is a field. */
constraints.add_binary(var, input_var, [](const int value_dst, const int value_src) {
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
});
}
break;
}
/* The output must be a single value when it is connected to any input that does
* not support fields. */
for (const bNodeSocket *src_socket : output_socket->directly_linked_sockets()) {
if (src_socket->is_available()) {
constraints.add_binary(
var, src_socket->index_in_tree(), [](int value_dst, int value_src) {
return value_dst == DomainValue::Single || value_src == DomainValue::Field;
});
}
}
// if (state.requires_single) {
// bool any_input_is_field_implicitly = false;
// const Vector<const bNodeSocket *> connected_inputs = gather_input_socket_dependencies(
// field_dependency, *node);
// for (const bNodeSocket *input_socket : connected_inputs) {
// if (!input_socket->is_available()) {
// continue;
// }
// if (inferencing_interface.inputs[input_socket->index()] ==
// InputSocketFieldType::Implicit)
// {
// if (!input_socket->is_logically_linked()) {
// any_input_is_field_implicitly = true;
// break;
// }
// }
// }
// if (any_input_is_field_implicitly) {
// /* This output isn't a single value actually. */
// state.requires_single = false;
// }
// else {
// /* If the output is required to be a single value, the connected inputs in the same
// * node must not be fields as well. */
// for (const bNodeSocket *input_socket : connected_inputs) {
// field_state_by_socket_id[input_socket->index_in_tree()].requires_single = true;
// }
// }
// }
}
/* Some inputs do not require fields independent of what the outputs are connected to. */
for (const bNodeSocket *input_socket : node->input_sockets()) {
if (!input_socket->is_available()) {
continue;
}
const int var = variables.get_socket_variable(*input_socket);
const bNodeSocketType *typeinfo = input_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
}
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
if (field_type == InputSocketFieldType::None) {
constraints.add_unary(input_socket->index_in_tree(),
[](int value) { return value == DomainValue::Single; });
}
/* The input must be a field value when it is connected to any output that can't be a single value. */
for (const bNodeSocket *src_socket : input_socket->directly_linked_sockets()) {
if (src_socket->is_available()) {
constraints.add_binary(
var, src_socket->index_in_tree(), [](int value_dst, int value_src) {
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
});
}
}
}
/* Constraints for consistent field type across zones. */
add_node_type_constraints(tree, *node, constraints);
add_node_constraints(tree, *node, *interface_by_node[node->index()], constraints);
}
logger.declare_constraints(constraints);
BitGroupVector<> result = csp::solve_constraints_with_logger(
constraints, variables.num_vars(), NumDomainValues, logger);
/* Perform old propagation method as well to verify the result. */
if (true) {
verify_field_inferencing_csp_result(tree, interface_by_node, result);
}
/* Setup inferencing interface for the tree. */
for (const int i : tree.interface_inputs().index_range()) {
const int var = variables.tree_input_vars[i];
@ -1288,6 +1277,9 @@ static void solve_field_inferencing_constraints(
inferencing_interface.outputs[group_output_socket->index()] = std::move(field_dependency);
}
/* Verify the result. */
verify_field_inferencing_csp_result(tree, interface_by_node, result, inferencing_interface);
update_socket_shapes(tree, variables, result);
}