WIP: Field type inferencing using a constraint solver method #120420
|
@ -709,15 +709,217 @@ static void prepare_inferencing_interfaces(
|
|||
}
|
||||
}
|
||||
|
||||
enum DomainValue { Single, Field, NumDomainValues };
|
||||
enum DomainValue {
|
||||
/* Socket is a single value. */
|
||||
Single,
|
||||
/* Socket is a field. */
|
||||
Field,
|
||||
|
||||
/* Keep last. */
|
||||
NumDomainValues
|
||||
};
|
||||
|
||||
namespace csp = constraint_satisfaction;
|
||||
|
||||
static void add_node_type_constraints(const bNodeTree &tree,
|
||||
const bNode &node,
|
||||
csp::ConstraintSet &constraints)
|
||||
struct NodeTreeVariables {
|
||||
/* Index ranges of variables for node sockets and interface sockets.
|
||||
* Group input/output nodes use the tree interface variables directly,
|
||||
* their socket variables are unused. */
|
||||
IndexRange socket_vars;
|
||||
IndexRange tree_input_vars;
|
||||
IndexRange tree_output_vars;
|
||||
|
||||
NodeTreeVariables(const bNodeTree &tree)
|
||||
{
|
||||
this->socket_vars = tree.all_sockets().index_range();
|
||||
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
}
|
||||
|
||||
int num_vars() const
|
||||
{
|
||||
return tree_output_vars.one_after_last();
|
||||
}
|
||||
|
||||
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
|
||||
{
|
||||
if (socket.is_output()) {
|
||||
/* Use tree variables directly for group input nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_input() &&
|
||||
tree_input_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_input_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
else {
|
||||
/* Use tree variables directly for group output nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_output() &&
|
||||
tree_output_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_output_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static std::string input_field_type_name(const InputSocketFieldType type)
|
||||
{
|
||||
tree.ensure_topology_cache();
|
||||
switch (type) {
|
||||
case InputSocketFieldType::None:
|
||||
return "None";
|
||||
case InputSocketFieldType::IsSupported:
|
||||
return "IsSupported";
|
||||
case InputSocketFieldType::Implicit:
|
||||
return "Implicit";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
static std::string output_field_type_name(const OutputSocketFieldType type)
|
||||
{
|
||||
switch (type) {
|
||||
case OutputSocketFieldType::None:
|
||||
return "None";
|
||||
case OutputSocketFieldType::FieldSource:
|
||||
return "FieldSource";
|
||||
case OutputSocketFieldType::DependentField:
|
||||
return "DependentField";
|
||||
case OutputSocketFieldType::PartiallyDependent:
|
||||
return "PartiallyDependent";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
static bool verify_field_inferencing_csp_result(
|
||||
const bNodeTree &tree,
|
||||
const Span<const FieldInferencingInterface *> interface_by_node,
|
||||
const BitGroupVector<> csp_result,
|
||||
const FieldInferencingInterface &inferencing_interface)
|
||||
{
|
||||
/* Use the old propagation method to provide "ground truth" to compare against. */
|
||||
const bool use_propagation_result = true;
|
||||
|
||||
NodeTreeVariables variables(tree);
|
||||
|
||||
Array<SocketFieldState> field_state_by_socket_id;
|
||||
std::unique_ptr<FieldInferencingInterface> tmp_inferencing_interface;
|
||||
if (use_propagation_result) {
|
||||
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
|
||||
field_state_by_socket_id.reinitialize(tree.all_sockets().size());
|
||||
|
||||
/* Temp local inferencing interface to avoid overwriting the actual interface.
|
||||
* The propagation method directly writes to the interface. */
|
||||
tmp_inferencing_interface = std::make_unique<FieldInferencingInterface>();
|
||||
tmp_inferencing_interface->inputs.resize(tree.interface_inputs().size(),
|
||||
InputSocketFieldType::IsSupported);
|
||||
tmp_inferencing_interface->outputs.resize(tree.interface_outputs().size(),
|
||||
OutputFieldDependency::ForDataSource());
|
||||
|
||||
propagate_data_requirements_from_right_to_left(
|
||||
tree, interface_by_node, field_state_by_socket_id);
|
||||
determine_group_input_states(tree, *tmp_inferencing_interface, field_state_by_socket_id);
|
||||
propagate_field_status_from_left_to_right(tree, interface_by_node, field_state_by_socket_id);
|
||||
determine_group_output_states(
|
||||
tree, *tmp_inferencing_interface, interface_by_node, field_state_by_socket_id);
|
||||
}
|
||||
|
||||
std::cout << "Verify field type inferencing for tree " << tree.id.name << std::endl;
|
||||
bool error = false;
|
||||
for (const bNodeSocket *socket : tree.all_sockets()) {
|
||||
if (!socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
auto log_error = [&](StringRef message) {
|
||||
const std::string socket_address = std::string(socket->owner_node().name) +
|
||||
(socket->is_output() ? "|>" : "<|") + socket->identifier;
|
||||
std::cout << " [Error] " << socket_address << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
const int num_values = int(state[DomainValue::Single]) + int(state[DomainValue::Field]);
|
||||
if (num_values == 0)
|
||||
{
|
||||
log_error("No valid result");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (use_propagation_result) {
|
||||
const SocketFieldState &old_state = field_state_by_socket_id[socket->index_in_tree()];
|
||||
if (old_state.is_always_single) {
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field])
|
||||
{
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!old_state.is_single) {
|
||||
if (state[DomainValue::Single] || !state[DomainValue::Field])
|
||||
{
|
||||
log_error("Should only be field");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (old_state.requires_single) {
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field])
|
||||
{
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!state[DomainValue::Single] || !state[DomainValue::Field])
|
||||
{
|
||||
log_error("Should be single value or field");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (use_propagation_result) {
|
||||
for (const int i : tree.interface_inputs().index_range()) {
|
||||
auto log_error = [&](StringRef message) {
|
||||
std::cout << " [Error] " << tree.interface_inputs()[i]->identifier << ": " << message
|
||||
<< std::endl;
|
||||
error = true;
|
||||
};
|
||||
const InputSocketFieldType old_field_type = tmp_inferencing_interface->inputs[i];
|
||||
const InputSocketFieldType new_field_type = inferencing_interface.inputs[i];
|
||||
if (old_field_type != new_field_type) {
|
||||
log_error("Input field type is " + input_field_type_name(new_field_type) + ", expected " +
|
||||
input_field_type_name(old_field_type));
|
||||
}
|
||||
}
|
||||
for (const int i : tree.interface_outputs().index_range()) {
|
||||
auto log_error = [&](StringRef message) {
|
||||
std::cout << " [Error] " << tree.interface_outputs()[i]->identifier << ": " << message
|
||||
<< std::endl;
|
||||
error = true;
|
||||
};
|
||||
const OutputFieldDependency &old_field_dep = tmp_inferencing_interface->outputs[i];
|
||||
const OutputFieldDependency &new_field_dep = inferencing_interface.outputs[i];
|
||||
if (old_field_dep.field_type() != new_field_dep.field_type()) {
|
||||
log_error("Output field type is " + output_field_type_name(new_field_dep.field_type()) +
|
||||
", expected " + output_field_type_name(old_field_dep.field_type()));
|
||||
}
|
||||
if (old_field_dep.linked_input_indices() != new_field_dep.linked_input_indices()) {
|
||||
log_error("Output field dependencies don't match");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!error) {
|
||||
std::cout << " OK!" << std::endl;
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
static void add_node_constraints(const bNodeTree &tree,
|
||||
const bNode &node,
|
||||
const FieldInferencingInterface &inferencing_interface,
|
||||
csp::ConstraintSet &constraints)
|
||||
{
|
||||
NodeTreeVariables variables(tree);
|
||||
|
||||
/* Constraint is satisfied if both inputs or outputs of a zone node pair are the same type. */
|
||||
auto shared_field_type_constraint = [](const int value_a, const int value_b) {
|
||||
|
@ -748,6 +950,121 @@ static void add_node_type_constraints(const bNodeTree &tree,
|
|||
}
|
||||
};
|
||||
|
||||
for (const bNodeSocket *output_socket : node.output_sockets()) {
|
||||
if (!output_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = variables.get_socket_variable(*output_socket);
|
||||
const bNodeSocketType *typeinfo = output_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) : SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const OutputFieldDependency &field_dependency =
|
||||
inferencing_interface.outputs[output_socket->index()];
|
||||
switch (field_dependency.field_type()) {
|
||||
/* Fixed single value output. */
|
||||
case OutputSocketFieldType::None:
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
break;
|
||||
/* Fixed field source output. */
|
||||
case OutputSocketFieldType::FieldSource:
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Field; });
|
||||
break;
|
||||
/* Internal dependency on one or more inputs. */
|
||||
case OutputSocketFieldType::DependentField:
|
||||
case OutputSocketFieldType::PartiallyDependent:
|
||||
for (const bNodeSocket *input_socket :
|
||||
gather_input_socket_dependencies(field_dependency, node))
|
||||
{
|
||||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int input_var = variables.get_socket_variable(*input_socket);
|
||||
/* The output must be a field if the input it depends on is a field. */
|
||||
constraints.add_binary(var, input_var, [](const int value_dst, const int value_src) {
|
||||
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
/* The output must be a single value when it is connected to any input that does
|
||||
* not support fields. */
|
||||
for (const bNodeSocket *src_socket : output_socket->directly_linked_sockets()) {
|
||||
if (src_socket->is_available()) {
|
||||
const int src_var = variables.get_socket_variable(*src_socket);
|
||||
constraints.add_binary(var, src_var, [](int value_dst, int value_src) {
|
||||
return value_dst == DomainValue::Single || value_src == DomainValue::Field;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// if (state.requires_single) {
|
||||
// bool any_input_is_field_implicitly = false;
|
||||
// const Vector<const bNodeSocket *> connected_inputs = gather_input_socket_dependencies(
|
||||
// field_dependency, *node);
|
||||
// for (const bNodeSocket *input_socket : connected_inputs) {
|
||||
// if (!input_socket->is_available()) {
|
||||
// continue;
|
||||
// }
|
||||
// if (inferencing_interface.inputs[input_socket->index()] ==
|
||||
// InputSocketFieldType::Implicit)
|
||||
// {
|
||||
// if (!input_socket->is_logically_linked()) {
|
||||
// any_input_is_field_implicitly = true;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if (any_input_is_field_implicitly) {
|
||||
// /* This output isn't a single value actually. */
|
||||
// state.requires_single = false;
|
||||
// }
|
||||
// else {
|
||||
// /* If the output is required to be a single value, the connected inputs in the same
|
||||
// * node must not be fields as well. */
|
||||
// for (const bNodeSocket *input_socket : connected_inputs) {
|
||||
// field_state_by_socket_id[input_socket->index_in_tree()].requires_single = true;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/* Some inputs do not require fields independent of what the outputs are connected to. */
|
||||
for (const bNodeSocket *input_socket : node.input_sockets()) {
|
||||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = variables.get_socket_variable(*input_socket);
|
||||
const bNodeSocketType *typeinfo = input_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) : SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
|
||||
if (field_type == InputSocketFieldType::None) {
|
||||
constraints.add_unary(input_socket->index_in_tree(),
|
||||
[](int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
/* The input must be a field value when it is connected to any output that can't be a single
|
||||
* value. */
|
||||
for (const bNodeSocket *src_socket : input_socket->directly_linked_sockets()) {
|
||||
if (src_socket->is_available()) {
|
||||
const int src_var = variables.get_socket_variable(*src_socket);
|
||||
constraints.add_binary(var, src_var, [](int value_dst, int value_src) {
|
||||
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Special constraints for certain node types. */
|
||||
switch (node.type) {
|
||||
case GEO_NODE_SIMULATION_INPUT: {
|
||||
const NodeGeometrySimulationInput &data = *static_cast<const NodeGeometrySimulationInput *>(
|
||||
|
@ -763,17 +1080,6 @@ static void add_node_type_constraints(const bNodeTree &tree,
|
|||
}
|
||||
case GEO_NODE_SIMULATION_OUTPUT: {
|
||||
/* Already handled in the input node case. */
|
||||
// for (const bNode *input_node : tree.nodes_by_type("GeometryNodeSimulationInput")) {
|
||||
// const NodeGeometrySimulationInput &data =
|
||||
// *static_cast<const NodeGeometrySimulationInput *>(input_node->storage);
|
||||
// if (node.identifier == data.output_node_id) {
|
||||
// /* First input node output is Delta Time which does not appear in the output node. */
|
||||
// add_zone_constraints(input_node->input_sockets(),
|
||||
// input_node->output_sockets().drop_front(1),
|
||||
// node.input_sockets(),
|
||||
// node.output_sockets());
|
||||
// }
|
||||
// }
|
||||
break;
|
||||
}
|
||||
case GEO_NODE_REPEAT_INPUT: {
|
||||
|
@ -789,21 +1095,34 @@ static void add_node_type_constraints(const bNodeTree &tree,
|
|||
}
|
||||
case GEO_NODE_REPEAT_OUTPUT: {
|
||||
/* Already handled in the input node case. */
|
||||
// for (const bNode *input_node : tree.nodes_by_type("GeometryNodeRepeatInput")) {
|
||||
// const NodeGeometryRepeatInput &data = *static_cast<const NodeGeometryRepeatInput *>(
|
||||
// input_node->storage);
|
||||
// if (node.identifier == data.output_node_id) {
|
||||
// add_zone_constraints(input_node->input_sockets(),
|
||||
// input_node->output_sockets(),
|
||||
// node.input_sockets(),
|
||||
// node.output_sockets());
|
||||
// }
|
||||
// }
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void update_socket_shapes(const bNodeTree &tree,
|
||||
const NodeTreeVariables &variables,
|
||||
const BitGroupVector<> csp_result)
|
||||
{
|
||||
auto get_shape_for_state = [](const BitSpan state) {
|
||||
return state[DomainValue::Field] ?
|
||||
(state[DomainValue::Single] ? SOCK_DISPLAY_SHAPE_DIAMOND_DOT :
|
||||
SOCK_DISPLAY_SHAPE_DIAMOND) :
|
||||
SOCK_DISPLAY_SHAPE_CIRCLE;
|
||||
};
|
||||
|
||||
for (const bNodeSocket *socket : tree.all_input_sockets()) {
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
|
||||
}
|
||||
for (const bNodeSocket *socket : tree.all_sockets()) {
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check what the group output socket depends on. Potentially traverses the node tree
|
||||
* to figure out if it is always a field or if it depends on any group inputs.
|
||||
|
@ -838,7 +1157,6 @@ static OutputFieldDependency find_group_output_dependencies(
|
|||
/* This socket uses a field as input by default. */
|
||||
return OutputFieldDependency::ForFieldSource();
|
||||
}
|
||||
return OutputFieldDependency::ForFieldSource();
|
||||
|
||||
for (const bNodeSocket *origin_socket : input_socket->directly_linked_sockets()) {
|
||||
const bNode &origin_node = origin_socket->owner_node();
|
||||
|
@ -884,205 +1202,6 @@ static OutputFieldDependency find_group_output_dependencies(
|
|||
return OutputFieldDependency::ForPartiallyDependentField(std::move(linked_input_indices));
|
||||
}
|
||||
|
||||
struct NodeTreeVariables {
|
||||
/* Index ranges of variables for node sockets and interface sockets.
|
||||
* Group input/output nodes use the tree interface variables directly,
|
||||
* their socket variables are unused. */
|
||||
IndexRange socket_vars;
|
||||
IndexRange tree_input_vars;
|
||||
IndexRange tree_output_vars;
|
||||
|
||||
NodeTreeVariables(const bNodeTree &tree)
|
||||
{
|
||||
this->socket_vars = tree.all_sockets().index_range();
|
||||
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
}
|
||||
|
||||
int num_vars() const
|
||||
{
|
||||
return tree_output_vars.one_after_last();
|
||||
}
|
||||
|
||||
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
|
||||
{
|
||||
if (socket.is_output()) {
|
||||
/* Use tree variables directly for group input nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_input() &&
|
||||
tree_input_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_input_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
else {
|
||||
/* Use tree variables directly for group output nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_output() &&
|
||||
tree_output_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_output_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/* Verify inferencing result by comparing to the old propagation method. */
|
||||
static bool verify_field_inferencing_csp_result(
|
||||
const bNodeTree &tree,
|
||||
const Span<const FieldInferencingInterface *> interface_by_node,
|
||||
const BitGroupVector<> csp_result)
|
||||
{
|
||||
NodeTreeVariables variables(tree);
|
||||
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
|
||||
Array<SocketFieldState> field_state_by_socket_id(tree.all_sockets().size());
|
||||
|
||||
/* Temp local inferencing interface to avoid overwriting the actual interface.
|
||||
* The propagation method directly writes to the interface. */
|
||||
std::unique_ptr<FieldInferencingInterface> tmp_inferencing_interface =
|
||||
std::make_unique<FieldInferencingInterface>();
|
||||
tmp_inferencing_interface->inputs.resize(tree.interface_inputs().size(),
|
||||
InputSocketFieldType::IsSupported);
|
||||
tmp_inferencing_interface->outputs.resize(tree.interface_outputs().size(),
|
||||
OutputFieldDependency::ForDataSource());
|
||||
|
||||
propagate_data_requirements_from_right_to_left(
|
||||
tree, interface_by_node, field_state_by_socket_id);
|
||||
determine_group_input_states(tree, *tmp_inferencing_interface, field_state_by_socket_id);
|
||||
propagate_field_status_from_left_to_right(tree, interface_by_node, field_state_by_socket_id);
|
||||
determine_group_output_states(
|
||||
tree, *tmp_inferencing_interface, interface_by_node, field_state_by_socket_id);
|
||||
|
||||
std::cout << "Verify field type inferencing for tree " << tree.id.name << std::endl;
|
||||
bool error = false;
|
||||
for (const bNodeSocket *socket : tree.all_sockets()) {
|
||||
if (!socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
auto log_error = [&](StringRef message) {
|
||||
const std::string socket_address = std::string(socket->owner_node().name) +
|
||||
(socket->is_output() ? "|>" : "<|") + socket->identifier;
|
||||
std::cout << " [Error] " << socket_address << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("Neither single value nor field");
|
||||
continue;
|
||||
}
|
||||
const SocketFieldState &old_state = field_state_by_socket_id[socket->index_in_tree()];
|
||||
if (old_state.is_always_single) {
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!old_state.is_single) {
|
||||
if (state[DomainValue::Single] || !state[DomainValue::Field]) {
|
||||
log_error("Should only be field");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (old_state.requires_single) {
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
|
||||
log_error("Should be both single value and field");
|
||||
}
|
||||
}
|
||||
|
||||
for (const int i : tree.interface_inputs().index_range()) {
|
||||
auto log_error = [&](StringRef message) {
|
||||
std::cout << " [Error] " << tree.interface_inputs()[i]->identifier << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var = variables.tree_input_vars[i];
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("Neither single value nor field");
|
||||
continue;
|
||||
}
|
||||
const InputSocketFieldType &old_state = tmp_inferencing_interface->inputs[i];
|
||||
switch (old_state) {
|
||||
case InputSocketFieldType::None:
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
break;
|
||||
case InputSocketFieldType::IsSupported:
|
||||
case InputSocketFieldType::Implicit:
|
||||
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
|
||||
log_error("Should be both single value and field");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (const int i : tree.interface_outputs().index_range()) {
|
||||
auto log_error = [&](StringRef message) {
|
||||
std::cout << " [Error] " << tree.interface_outputs()[i]->identifier << ": " << message
|
||||
<< std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var = variables.tree_output_vars[i];
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("Neither single value nor field");
|
||||
continue;
|
||||
}
|
||||
const OutputFieldDependency &old_state = tmp_inferencing_interface->outputs[i];
|
||||
switch (old_state.field_type()) {
|
||||
case OutputSocketFieldType::None:
|
||||
if (!state[DomainValue::Single] || state[DomainValue::Field]) {
|
||||
log_error("Should only be single value");
|
||||
}
|
||||
break;
|
||||
case OutputSocketFieldType::FieldSource:
|
||||
if (state[DomainValue::Single] || !state[DomainValue::Field]) {
|
||||
log_error("Should only be field");
|
||||
}
|
||||
break;
|
||||
case OutputSocketFieldType::DependentField:
|
||||
case OutputSocketFieldType::PartiallyDependent:
|
||||
if (!state[DomainValue::Single] || !state[DomainValue::Field]) {
|
||||
log_error("Should be both single value and field");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!error) {
|
||||
std::cout << " OK!" << std::endl;
|
||||
}
|
||||
return error;
|
||||
}
|
||||
|
||||
static void update_socket_shapes(const bNodeTree &tree,
|
||||
const NodeTreeVariables &variables,
|
||||
const BitGroupVector<> csp_result)
|
||||
{
|
||||
auto get_shape_for_state = [](const BitSpan state) {
|
||||
return state[DomainValue::Field] ?
|
||||
(state[DomainValue::Single] ? SOCK_DISPLAY_SHAPE_DIAMOND_DOT :
|
||||
SOCK_DISPLAY_SHAPE_DIAMOND) :
|
||||
SOCK_DISPLAY_SHAPE_CIRCLE;
|
||||
};
|
||||
|
||||
for (const bNodeSocket *socket : tree.all_input_sockets()) {
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
|
||||
}
|
||||
for (const bNodeSocket *socket : tree.all_sockets()) {
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
const_cast<bNodeSocket *>(socket)->display_shape = get_shape_for_state(state);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Logger>
|
||||
static void solve_field_inferencing_constraints(
|
||||
const bNodeTree &tree,
|
||||
|
@ -1097,7 +1216,8 @@ static void solve_field_inferencing_constraints(
|
|||
|
||||
tree.ensure_topology_cache();
|
||||
NodeTreeVariables variables(tree);
|
||||
auto variable_name = [&](const csp::VariableIndex var) -> std::string {
|
||||
|
||||
logger.declare_variables(variables.num_vars(), [&](const csp::VariableIndex var) -> std::string {
|
||||
if (variables.socket_vars.contains(var)) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var - variables.socket_vars.start()];
|
||||
const bNode &node = socket.owner_node();
|
||||
|
@ -1115,150 +1235,19 @@ static void solve_field_inferencing_constraints(
|
|||
return std::string("O:") + iosocket.identifier;
|
||||
}
|
||||
return "";
|
||||
};
|
||||
logger.declare_variables(variables.num_vars(), variable_name);
|
||||
});
|
||||
|
||||
const Span<const bNode *> nodes = tree.toposort_right_to_left();
|
||||
|
||||
csp::ConstraintSet constraints;
|
||||
for (const bNode *node : nodes) {
|
||||
/* Special case: Group inputs and outputs use the interface variables directly. */
|
||||
const IndexRange interface_inputs = node->is_group_input() ?
|
||||
tree.interface_inputs().index_range() :
|
||||
IndexRange();
|
||||
const IndexRange interface_outputs = node->is_group_output() ?
|
||||
tree.interface_outputs().index_range() :
|
||||
IndexRange();
|
||||
|
||||
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
|
||||
for (const bNodeSocket *output_socket : node->output_sockets()) {
|
||||
if (!output_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = variables.get_socket_variable(*output_socket);
|
||||
const bNodeSocketType *typeinfo = output_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const OutputFieldDependency &field_dependency =
|
||||
inferencing_interface.outputs[output_socket->index()];
|
||||
switch (field_dependency.field_type()) {
|
||||
/* Fixed single value output. */
|
||||
case OutputSocketFieldType::None:
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
break;
|
||||
/* Fixed field source output. */
|
||||
case OutputSocketFieldType::FieldSource:
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Field; });
|
||||
break;
|
||||
/* Internal dependency on one or more inputs. */
|
||||
case OutputSocketFieldType::DependentField:
|
||||
case OutputSocketFieldType::PartiallyDependent:
|
||||
for (const bNodeSocket *input_socket :
|
||||
gather_input_socket_dependencies(field_dependency, *node))
|
||||
{
|
||||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int input_var = variables.get_socket_variable(*input_socket);
|
||||
/* The output must be a field if the input it depends on is a field. */
|
||||
constraints.add_binary(var, input_var, [](const int value_dst, const int value_src) {
|
||||
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
/* The output must be a single value when it is connected to any input that does
|
||||
* not support fields. */
|
||||
for (const bNodeSocket *src_socket : output_socket->directly_linked_sockets()) {
|
||||
if (src_socket->is_available()) {
|
||||
constraints.add_binary(
|
||||
var, src_socket->index_in_tree(), [](int value_dst, int value_src) {
|
||||
return value_dst == DomainValue::Single || value_src == DomainValue::Field;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// if (state.requires_single) {
|
||||
// bool any_input_is_field_implicitly = false;
|
||||
// const Vector<const bNodeSocket *> connected_inputs = gather_input_socket_dependencies(
|
||||
// field_dependency, *node);
|
||||
// for (const bNodeSocket *input_socket : connected_inputs) {
|
||||
// if (!input_socket->is_available()) {
|
||||
// continue;
|
||||
// }
|
||||
// if (inferencing_interface.inputs[input_socket->index()] ==
|
||||
// InputSocketFieldType::Implicit)
|
||||
// {
|
||||
// if (!input_socket->is_logically_linked()) {
|
||||
// any_input_is_field_implicitly = true;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if (any_input_is_field_implicitly) {
|
||||
// /* This output isn't a single value actually. */
|
||||
// state.requires_single = false;
|
||||
// }
|
||||
// else {
|
||||
// /* If the output is required to be a single value, the connected inputs in the same
|
||||
// * node must not be fields as well. */
|
||||
// for (const bNodeSocket *input_socket : connected_inputs) {
|
||||
// field_state_by_socket_id[input_socket->index_in_tree()].requires_single = true;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
/* Some inputs do not require fields independent of what the outputs are connected to. */
|
||||
for (const bNodeSocket *input_socket : node->input_sockets()) {
|
||||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = variables.get_socket_variable(*input_socket);
|
||||
const bNodeSocketType *typeinfo = input_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add_unary(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
|
||||
if (field_type == InputSocketFieldType::None) {
|
||||
constraints.add_unary(input_socket->index_in_tree(),
|
||||
[](int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
/* The input must be a field value when it is connected to any output that can't be a single value. */
|
||||
for (const bNodeSocket *src_socket : input_socket->directly_linked_sockets()) {
|
||||
if (src_socket->is_available()) {
|
||||
constraints.add_binary(
|
||||
var, src_socket->index_in_tree(), [](int value_dst, int value_src) {
|
||||
return value_dst == DomainValue::Field || value_src == DomainValue::Single;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Constraints for consistent field type across zones. */
|
||||
add_node_type_constraints(tree, *node, constraints);
|
||||
add_node_constraints(tree, *node, *interface_by_node[node->index()], constraints);
|
||||
}
|
||||
logger.declare_constraints(constraints);
|
||||
|
||||
BitGroupVector<> result = csp::solve_constraints_with_logger(
|
||||
constraints, variables.num_vars(), NumDomainValues, logger);
|
||||
|
||||
/* Perform old propagation method as well to verify the result. */
|
||||
if (true) {
|
||||
verify_field_inferencing_csp_result(tree, interface_by_node, result);
|
||||
}
|
||||
|
||||
/* Setup inferencing interface for the tree. */
|
||||
for (const int i : tree.interface_inputs().index_range()) {
|
||||
const int var = variables.tree_input_vars[i];
|
||||
|
@ -1288,6 +1277,9 @@ static void solve_field_inferencing_constraints(
|
|||
inferencing_interface.outputs[group_output_socket->index()] = std::move(field_dependency);
|
||||
}
|
||||
|
||||
/* Verify the result. */
|
||||
verify_field_inferencing_csp_result(tree, interface_by_node, result, inferencing_interface);
|
||||
|
||||
update_socket_shapes(tree, variables, result);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue