WIP: Field type inferencing using a constraint solver method #120420
|
@ -886,16 +886,56 @@ static OutputFieldDependency find_group_output_dependencies(
|
|||
return OutputFieldDependency::ForPartiallyDependentField(std::move(linked_input_indices));
|
||||
}
|
||||
|
||||
struct NodeTreeVariables {
|
||||
/* Index ranges of variables for node sockets and interface sockets.
|
||||
* Group input/output nodes use the tree interface variables directly,
|
||||
* their socket variables are unused. */
|
||||
IndexRange socket_vars;
|
||||
IndexRange tree_input_vars;
|
||||
IndexRange tree_output_vars;
|
||||
|
||||
NodeTreeVariables(const bNodeTree &tree)
|
||||
{
|
||||
this->socket_vars = tree.all_sockets().index_range();
|
||||
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
}
|
||||
|
||||
int num_vars() const
|
||||
{
|
||||
return tree_output_vars.one_after_last();
|
||||
}
|
||||
|
||||
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
|
||||
{
|
||||
if (socket.is_output()) {
|
||||
/* Use tree variables directly for group input nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_input() &&
|
||||
tree_input_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_input_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
else {
|
||||
/* Use tree variables directly for group output nodes (except the extension socket). */
|
||||
if (socket.owner_node().is_group_output() &&
|
||||
tree_output_vars.index_range().contains(socket.index()))
|
||||
{
|
||||
return tree_output_vars[socket.index()];
|
||||
}
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/* Verify inferencing result by comparing to the old propagation method. */
|
||||
static bool verify_field_inferencing_csp_result(
|
||||
const bNodeTree &tree,
|
||||
const Span<const FieldInferencingInterface *> interface_by_node,
|
||||
const BitGroupVector<> csp_result)
|
||||
{
|
||||
const IndexRange socket_vars = tree.all_sockets().index_range();
|
||||
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
|
||||
NodeTreeVariables variables(tree);
|
||||
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
|
||||
Array<SocketFieldState> field_state_by_socket_id(tree.all_sockets().size());
|
||||
|
||||
|
@ -927,8 +967,8 @@ static bool verify_field_inferencing_csp_result(
|
|||
std::cout << socket_address << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var_index = socket_vars[socket->index_in_tree()];
|
||||
const BitSpan state = csp_result[var_index];
|
||||
const int var = variables.get_socket_variable(*socket);
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("ERROR: neither single value nor field");
|
||||
continue;
|
||||
|
@ -962,8 +1002,8 @@ static bool verify_field_inferencing_csp_result(
|
|||
std::cout << tree.interface_inputs()[i]->identifier << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var_index = tree_input_vars[i];
|
||||
const BitSpan state = csp_result[var_index];
|
||||
const int var = variables.tree_input_vars[i];
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("ERROR: neither single value nor field");
|
||||
continue;
|
||||
|
@ -988,8 +1028,8 @@ static bool verify_field_inferencing_csp_result(
|
|||
std::cout << tree.interface_outputs()[i]->identifier << ": " << message << std::endl;
|
||||
error = true;
|
||||
};
|
||||
const int var_index = tree_output_vars[i];
|
||||
const BitSpan state = csp_result[var_index];
|
||||
const int var = variables.tree_output_vars[i];
|
||||
const BitSpan state = csp_result[var];
|
||||
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
|
||||
log_error("ERROR: neither single value nor field");
|
||||
continue;
|
||||
|
@ -1030,35 +1070,28 @@ static void solve_field_inferencing_constraints(
|
|||
return;
|
||||
}
|
||||
|
||||
/* Index ranges of variables for node sockets and interface sockets.
|
||||
* Group input/output nodes use the tree interface variables directly,
|
||||
* their socket variables are unused. */
|
||||
const IndexRange socket_vars = tree.all_sockets().index_range();
|
||||
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
const int num_vars = tree_output_vars.one_after_last();
|
||||
|
||||
tree.ensure_topology_cache();
|
||||
NodeTreeVariables variables(tree);
|
||||
auto variable_name = [&](const csp::VariableIndex var) -> std::string {
|
||||
if (socket_vars.contains(var)) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var];
|
||||
if (variables.socket_vars.contains(var)) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var - variables.socket_vars.start()];
|
||||
const bNode &node = socket.owner_node();
|
||||
return socket.is_output() ? std::string(node.name) + ":O:" + socket.identifier :
|
||||
std::string(node.name) + ":I:" + socket.identifier;
|
||||
}
|
||||
if (tree_input_vars.contains(var)) {
|
||||
if (variables.tree_input_vars.contains(var)) {
|
||||
const bNodeTreeInterfaceSocket &iosocket =
|
||||
*tree.interface_inputs()[var - tree_input_vars.start()];
|
||||
*tree.interface_inputs()[var - variables.tree_input_vars.start()];
|
||||
return std::string("I:") + iosocket.identifier;
|
||||
}
|
||||
if (tree_output_vars.contains(var)) {
|
||||
if (variables.tree_output_vars.contains(var)) {
|
||||
const bNodeTreeInterfaceSocket &iosocket =
|
||||
*tree.interface_outputs()[var - tree_output_vars.start()];
|
||||
*tree.interface_outputs()[var - variables.tree_output_vars.start()];
|
||||
return std::string("O:") + iosocket.identifier;
|
||||
}
|
||||
return "";
|
||||
};
|
||||
logger.declare_variables(num_vars, variable_name);
|
||||
logger.declare_variables(variables.num_vars(), variable_name);
|
||||
|
||||
const Span<const bNode *> nodes = tree.toposort_right_to_left();
|
||||
|
||||
|
@ -1071,31 +1104,13 @@ static void solve_field_inferencing_constraints(
|
|||
const IndexRange interface_outputs = node->is_group_output() ?
|
||||
tree.interface_outputs().index_range() :
|
||||
IndexRange();
|
||||
auto get_socket_variable = [&](const bNodeSocket &socket) -> csp::VariableIndex {
|
||||
if (socket.is_output()) {
|
||||
if (interface_inputs.contains(socket.index())) {
|
||||
return tree_input_vars[socket.index()];
|
||||
}
|
||||
else {
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (interface_outputs.contains(socket.index())) {
|
||||
return tree_output_vars[socket.index()];
|
||||
}
|
||||
else {
|
||||
return socket_vars[socket.index_in_tree()];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
|
||||
for (const bNodeSocket *output_socket : node->output_sockets()) {
|
||||
if (!output_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = get_socket_variable(*output_socket);
|
||||
const int var = variables.get_socket_variable(*output_socket);
|
||||
const bNodeSocketType *typeinfo = output_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
@ -1161,7 +1176,7 @@ static void solve_field_inferencing_constraints(
|
|||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var = get_socket_variable(*input_socket);
|
||||
const int var = variables.get_socket_variable(*input_socket);
|
||||
const bNodeSocketType *typeinfo = input_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
@ -1183,12 +1198,12 @@ static void solve_field_inferencing_constraints(
|
|||
logger.declare_constraints(constraints);
|
||||
|
||||
BitGroupVector<> result = csp::solve_constraints_with_logger(
|
||||
constraints, num_vars, NumDomainValues, logger);
|
||||
constraints, variables.num_vars(), NumDomainValues, logger);
|
||||
|
||||
/* Setup inferencing interface for the tree. */
|
||||
for (const int i : tree.interface_inputs().index_range()) {
|
||||
const int var_index = tree_input_vars[i];
|
||||
const BitSpan state = result[var_index];
|
||||
const int var = variables.tree_input_vars[i];
|
||||
const BitSpan state = result[var];
|
||||
if (state[DomainValue::Single]) {
|
||||
if (state[DomainValue::Field]) {
|
||||
inferencing_interface.inputs[i] = InputSocketFieldType::IsSupported;
|
||||
|
|
Loading…
Reference in New Issue