WIP: Field type inferencing using a constraint solver method #120420

Draft
Lukas Tönne wants to merge 52 commits from LukasTonne/blender:socket-type-inference into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
1 changed files with 63 additions and 48 deletions
Showing only changes of commit 67a45c00f7 - Show all commits

View File

@ -886,16 +886,56 @@ static OutputFieldDependency find_group_output_dependencies(
return OutputFieldDependency::ForPartiallyDependentField(std::move(linked_input_indices));
}
struct NodeTreeVariables {
/* Index ranges of variables for node sockets and interface sockets.
* Group input/output nodes use the tree interface variables directly,
* their socket variables are unused. */
IndexRange socket_vars;
IndexRange tree_input_vars;
IndexRange tree_output_vars;
NodeTreeVariables(const bNodeTree &tree)
{
this->socket_vars = tree.all_sockets().index_range();
this->tree_input_vars = socket_vars.after(tree.interface_inputs().size());
this->tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
}
int num_vars() const
{
return tree_output_vars.one_after_last();
}
csp::VariableIndex get_socket_variable(const bNodeSocket &socket) const
{
if (socket.is_output()) {
/* Use tree variables directly for group input nodes (except the extension socket). */
if (socket.owner_node().is_group_input() &&
tree_input_vars.index_range().contains(socket.index()))
{
return tree_input_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
else {
/* Use tree variables directly for group output nodes (except the extension socket). */
if (socket.owner_node().is_group_output() &&
tree_output_vars.index_range().contains(socket.index()))
{
return tree_output_vars[socket.index()];
}
return socket_vars[socket.index_in_tree()];
}
}
};
/* Verify inferencing result by comparing to the old propagation method. */
static bool verify_field_inferencing_csp_result(
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
const BitGroupVector<> csp_result)
{
const IndexRange socket_vars = tree.all_sockets().index_range();
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
NodeTreeVariables variables(tree);
/* Keep track of the state of all sockets. The index into this array is #SocketRef::id(). */
Array<SocketFieldState> field_state_by_socket_id(tree.all_sockets().size());
@ -927,8 +967,8 @@ static bool verify_field_inferencing_csp_result(
std::cout << socket_address << ": " << message << std::endl;
error = true;
};
const int var_index = socket_vars[socket->index_in_tree()];
const BitSpan state = csp_result[var_index];
const int var = variables.get_socket_variable(*socket);
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("ERROR: neither single value nor field");
continue;
@ -962,8 +1002,8 @@ static bool verify_field_inferencing_csp_result(
std::cout << tree.interface_inputs()[i]->identifier << ": " << message << std::endl;
error = true;
};
const int var_index = tree_input_vars[i];
const BitSpan state = csp_result[var_index];
const int var = variables.tree_input_vars[i];
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("ERROR: neither single value nor field");
continue;
@ -988,8 +1028,8 @@ static bool verify_field_inferencing_csp_result(
std::cout << tree.interface_outputs()[i]->identifier << ": " << message << std::endl;
error = true;
};
const int var_index = tree_output_vars[i];
const BitSpan state = csp_result[var_index];
const int var = variables.tree_output_vars[i];
const BitSpan state = csp_result[var];
if (!state[DomainValue::Single] && !state[DomainValue::Field]) {
log_error("ERROR: neither single value nor field");
continue;
@ -1030,35 +1070,28 @@ static void solve_field_inferencing_constraints(
return;
}
/* Index ranges of variables for node sockets and interface sockets.
* Group input/output nodes use the tree interface variables directly,
* their socket variables are unused. */
const IndexRange socket_vars = tree.all_sockets().index_range();
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
const int num_vars = tree_output_vars.one_after_last();
tree.ensure_topology_cache();
NodeTreeVariables variables(tree);
auto variable_name = [&](const csp::VariableIndex var) -> std::string {
if (socket_vars.contains(var)) {
const bNodeSocket &socket = *tree.all_sockets()[var];
if (variables.socket_vars.contains(var)) {
const bNodeSocket &socket = *tree.all_sockets()[var - variables.socket_vars.start()];
const bNode &node = socket.owner_node();
return socket.is_output() ? std::string(node.name) + ":O:" + socket.identifier :
std::string(node.name) + ":I:" + socket.identifier;
}
if (tree_input_vars.contains(var)) {
if (variables.tree_input_vars.contains(var)) {
const bNodeTreeInterfaceSocket &iosocket =
*tree.interface_inputs()[var - tree_input_vars.start()];
*tree.interface_inputs()[var - variables.tree_input_vars.start()];
return std::string("I:") + iosocket.identifier;
}
if (tree_output_vars.contains(var)) {
if (variables.tree_output_vars.contains(var)) {
const bNodeTreeInterfaceSocket &iosocket =
*tree.interface_outputs()[var - tree_output_vars.start()];
*tree.interface_outputs()[var - variables.tree_output_vars.start()];
return std::string("O:") + iosocket.identifier;
}
return "";
};
logger.declare_variables(num_vars, variable_name);
logger.declare_variables(variables.num_vars(), variable_name);
const Span<const bNode *> nodes = tree.toposort_right_to_left();
@ -1071,31 +1104,13 @@ static void solve_field_inferencing_constraints(
const IndexRange interface_outputs = node->is_group_output() ?
tree.interface_outputs().index_range() :
IndexRange();
auto get_socket_variable = [&](const bNodeSocket &socket) -> csp::VariableIndex {
if (socket.is_output()) {
if (interface_inputs.contains(socket.index())) {
return tree_input_vars[socket.index()];
}
else {
return socket_vars[socket.index_in_tree()];
}
}
else {
if (interface_outputs.contains(socket.index())) {
return tree_output_vars[socket.index()];
}
else {
return socket_vars[socket.index_in_tree()];
}
}
};
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
for (const bNodeSocket *output_socket : node->output_sockets()) {
if (!output_socket->is_available()) {
continue;
}
const int var = get_socket_variable(*output_socket);
const int var = variables.get_socket_variable(*output_socket);
const bNodeSocketType *typeinfo = output_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
@ -1161,7 +1176,7 @@ static void solve_field_inferencing_constraints(
if (!input_socket->is_available()) {
continue;
}
const int var = get_socket_variable(*input_socket);
const int var = variables.get_socket_variable(*input_socket);
const bNodeSocketType *typeinfo = input_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
@ -1183,12 +1198,12 @@ static void solve_field_inferencing_constraints(
logger.declare_constraints(constraints);
BitGroupVector<> result = csp::solve_constraints_with_logger(
constraints, num_vars, NumDomainValues, logger);
constraints, variables.num_vars(), NumDomainValues, logger);
/* Setup inferencing interface for the tree. */
for (const int i : tree.interface_inputs().index_range()) {
const int var_index = tree_input_vars[i];
const BitSpan state = result[var_index];
const int var = variables.tree_input_vars[i];
const BitSpan state = result[var];
if (state[DomainValue::Single]) {
if (state[DomainValue::Field]) {
inferencing_interface.inputs[i] = InputSocketFieldType::IsSupported;