WIP: Field type inferencing using a constraint solver method #120420

Draft
Lukas Tönne wants to merge 52 commits from LukasTonne/blender:socket-type-inference into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
1 changed files with 119 additions and 15 deletions
Showing only changes of commit 8c976aa345 - Show all commits

View File

@ -913,8 +913,6 @@ static BitGroupVector<> solve_constraints(const ConstraintSet &constraints,
enum DomainValue {
Single,
Field,
/* Distinct from Field to indicate group inputs that become fields if unconnected. */
ImplicitField,
NumDomainValues
};
@ -1005,6 +1003,86 @@ static void add_node_type_constraints(const bNodeTree &tree,
}
}
/**
* Check what the group output socket depends on. Potentially traverses the node tree
* to figure out if it is always a field or if it depends on any group inputs.
*/
static OutputFieldDependency find_group_output_dependencies(
const bNodeSocket &group_output_socket,
const Span<const FieldInferencingInterface *> interface_by_node,
const BitGroupVector<> &field_state_by_socket_id)
{
if (!is_field_socket_type(group_output_socket)) {
return OutputFieldDependency::ForDataSource();
}
/* Use a Set here instead of an array indexed by socket id, because we my only need to look at
* very few sockets. */
Set<const bNodeSocket *> handled_sockets;
Stack<const bNodeSocket *> sockets_to_check;
handled_sockets.add(&group_output_socket);
sockets_to_check.push(&group_output_socket);
/* Keeps track of group input indices that are (indirectly) connected to the output. */
Vector<int> linked_input_indices;
while (!sockets_to_check.is_empty()) {
const bNodeSocket *input_socket = sockets_to_check.pop();
const BitSpan input_state = field_state_by_socket_id[input_socket->index_in_tree()];
const bool can_be_single = input_state[DomainValue::Single];
const bool can_be_field = input_state[DomainValue::Field];
if (!input_socket->is_directly_linked() && can_be_field && !can_be_single) {
/* This socket uses a field as input by default. */
return OutputFieldDependency::ForFieldSource();
}
return OutputFieldDependency::ForFieldSource();
for (const bNodeSocket *origin_socket : input_socket->directly_linked_sockets()) {
const bNode &origin_node = origin_socket->owner_node();
const BitSpan origin_state = field_state_by_socket_id[origin_socket->index_in_tree()];
const bool origin_can_be_single = origin_state[DomainValue::Single];
const bool origin_can_be_field = origin_state[DomainValue::Field];
if (origin_can_be_field && !origin_can_be_single) {
if (origin_node.type == NODE_GROUP_INPUT) {
/* Found a group input that the group output depends on. */
linked_input_indices.append_non_duplicates(origin_socket->index());
}
else {
/* Found a field source that is not the group input. So the output is always a field. */
return OutputFieldDependency::ForFieldSource();
}
}
else if (!origin_can_be_single) {
const FieldInferencingInterface &inferencing_interface =
*interface_by_node[origin_node.index()];
const OutputFieldDependency &field_dependency =
inferencing_interface.outputs[origin_socket->index()];
/* Propagate search further to the left. */
for (const bNodeSocket *origin_input_socket :
gather_input_socket_dependencies(field_dependency, origin_node))
{
const BitSpan origin_input_state =
field_state_by_socket_id[origin_input_socket->index_in_tree()];
const bool origin_input_can_be_single = origin_input_state[DomainValue::Single];
if (!origin_input_socket->is_available()) {
continue;
}
if (!origin_input_can_be_single) {
if (handled_sockets.add(origin_input_socket)) {
sockets_to_check.push(origin_input_socket);
}
}
}
}
}
}
return OutputFieldDependency::ForPartiallyDependentField(std::move(linked_input_indices));
}
static void test_ac3_field_inferencing(
const bNodeTree &tree,
const Span<const FieldInferencingInterface *> interface_by_node,
@ -1012,6 +1090,11 @@ static void test_ac3_field_inferencing(
{
using Interrupter = ac3::PrintInterupter;
const bNode *group_output_node = tree.group_output_node();
if (!group_output_node) {
return;
}
/* Index ranges of variables for node sockets and interface sockets.
* Group input/output nodes use the tree interface variables directly,
* their socket variables are unused. */
@ -1025,12 +1108,18 @@ static void test_ac3_field_inferencing(
ac3::ConstraintSet constraints;
for (const bNode *node : nodes) {
/* Special case: Group inputs and outputs use the interface variables directly. */
const bool use_interface_vars = node->is_group_input() || node->is_group_output();
const IndexRange interface_inputs = node->is_group_input() ?
tree.interface_inputs().index_range() :
IndexRange();
const IndexRange interface_outputs = node->is_group_output() ?
tree.interface_outputs().index_range() :
IndexRange();
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
for (const bNodeSocket *output_socket : node->output_sockets()) {
const int var_index = use_interface_vars ? tree_output_vars[output_socket->index()] :
socket_vars[output_socket->index_in_tree()];
const int var_index = interface_inputs.contains(output_socket->index()) ?
tree_input_vars[output_socket->index()] :
socket_vars[output_socket->index_in_tree()];
const bNodeSocketType *typeinfo = output_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
@ -1053,8 +1142,7 @@ static void test_ac3_field_inferencing(
for (const bNodeSocket *target_socket : output_socket->directly_linked_sockets()) {
if (target_socket->is_available()) {
constraints.add(var_index, target_socket->index_in_tree(), [](int value_a, int value_b) {
return value_a == DomainValue::Single ||
ELEM(value_b, DomainValue::Field, DomainValue::ImplicitField);
return value_a == DomainValue::Single || value_b == DomainValue::Field;
});
}
}
@ -1093,8 +1181,9 @@ static void test_ac3_field_inferencing(
/* Some inputs do not require fields independent of what the outputs are connected to. */
for (const bNodeSocket *input_socket : node->input_sockets()) {
const int var_index = use_interface_vars ? tree_input_vars[input_socket->index()] :
socket_vars[input_socket->index_in_tree()];
const int var_index = interface_outputs.contains(input_socket->index()) ?
tree_output_vars[input_socket->index()] :
socket_vars[input_socket->index_in_tree()];
const bNodeSocketType *typeinfo = input_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
@ -1119,16 +1208,31 @@ static void test_ac3_field_inferencing(
/* Setup inferencing interface for the tree. */
for (const int i : tree.interface_inputs().index_range()) {
const bNodeTreeInterfaceSocket *group_input = tree.interface_inputs()[i];
const bNodeSocketType *typeinfo = group_input->socket_typeinfo();
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) : SOCK_CUSTOM;
const int var_index = tree_input_vars[i];
const BitSpan state = result[var_index];
if (state[DomainValue::ImplicitField]) {
if (state[DomainValue::Single]) {
if (state[DomainValue::Field]) {
inferencing_interface.inputs[i] = InputSocketFieldType::IsSupported;
}
else {
inferencing_interface.inputs[i] = InputSocketFieldType::None;
}
}
else {
if (state[DomainValue::Field]) {
inferencing_interface.inputs[i] = InputSocketFieldType::Implicit;
}
else {
/* Error: No supported field type. */
BLI_assert_unreachable();
}
}
}
inferencing_interface.inputs[i] = InputSocketFieldType::None;
for (const bNodeSocket *group_output_socket : group_output_node->input_sockets().drop_back(1)) {
OutputFieldDependency field_dependency = find_group_output_dependencies(
*group_output_socket, interface_by_node, result);
inferencing_interface.outputs[group_output_socket->index()] = std::move(field_dependency);
}
}