WIP: Field type inferencing using a constraint solver method #120420

Draft
Lukas Tönne wants to merge 52 commits from LukasTonne/blender:socket-type-inference into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
1 changed files with 195 additions and 108 deletions
Showing only changes of commit fbcfd48d2a - Show all commits

View File

@ -721,51 +721,111 @@ using UnaryConstraintFn = std::function<bool(int value)>;
/** Binary constraint function, returns true if both values are compatible. */
using BinaryConstraintFn = std::function<bool(int value_a, int value_b)>;
class VariableRef {
int index_;
public:
explicit VariableRef(const int index) : index_(index) {}
int index() const
{
return index_;
}
operator int() const
{
return index_;
}
bool operator==(const VariableRef other) const
{
return index_ == other.index_;
}
uint64_t hash() const
{
return get_default_hash(index_);
}
};
struct MutableVariableRef {
int index_;
public:
explicit MutableVariableRef(const int index) : index_(index) {}
int index() const
{
return index_;
}
operator int() const
{
return index_;
}
operator VariableRef() const
{
return VariableRef(index_);
}
bool operator==(const MutableVariableRef other) const
{
return index_ == other.index_;
}
uint64_t hash() const
{
return get_default_hash(index_);
}
};
class ConstraintSet {
public:
struct Target {
int variable;
MutableVariableRef variable;
BinaryConstraintFn constraint;
};
struct Source {
int variable;
VariableRef variable;
BinaryConstraintFn constraint;
};
private:
MultiValueMap<int, UnaryConstraintFn> unary_;
MultiValueMap<int, Target> forward_;
MultiValueMap<int, Source> reverse_;
MultiValueMap<MutableVariableRef, UnaryConstraintFn> unary_;
MultiValueMap<VariableRef, Target> binary_by_source_;
MultiValueMap<MutableVariableRef, Source> binary_by_target_;
public:
const MultiValueMap<int, UnaryConstraintFn> &all_unary_constraints() const
const MultiValueMap<MutableVariableRef, UnaryConstraintFn> &all_unary_constraints() const
{
return unary_;
}
const MultiValueMap<int, Target> &binary_constraints_by_source() const
const MultiValueMap<VariableRef, Target> &binary_constraints_by_source() const
{
return forward_;
return binary_by_source_;
}
const MultiValueMap<int, Source> &binary_constraints_by_target() const
const MultiValueMap<MutableVariableRef, Source> &binary_constraints_by_target() const
{
return reverse_;
return binary_by_target_;
}
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
Span<UnaryConstraintFn> get_unary_constraints(const MutableVariableRef source) const
{
return unary_.lookup(source_key);
return unary_.lookup(source);
}
Span<Target> get_target_constraints(const int source_key) const
Span<Target> get_target_constraints(const VariableRef source) const
{
return forward_.lookup(source_key);
return binary_by_source_.lookup(source);
}
Span<Source> get_source_constraints(const int target_key) const
Span<Source> get_source_constraints(const MutableVariableRef target) const
{
return reverse_.lookup(target_key);
return binary_by_target_.lookup(target);
}
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
BinaryConstraintFn get_binary_constraint(const VariableRef source_key,
const MutableVariableRef target_key) const
{
for (const Target &target : forward_.lookup(source_key)) {
for (const Target &target : binary_by_source_.lookup(source_key)) {
if (target.variable == target_key) {
return target.constraint;
}
@ -773,14 +833,16 @@ class ConstraintSet {
return nullptr;
}
void add(const int variable, UnaryConstraintFn constraint)
void add(const MutableVariableRef variable, UnaryConstraintFn constraint)
{
unary_.add(variable, constraint);
}
void add(const int source, const int target, BinaryConstraintFn constraint)
void add(const MutableVariableRef target,
const VariableRef source,
BinaryConstraintFn constraint)
{
forward_.add(source, {target, constraint});
reverse_.add(target, {source, constraint});
binary_by_source_.add(source, {target, constraint});
binary_by_target_.add(target, {source, constraint});
}
};
@ -799,22 +861,22 @@ static void reduce_unary(const UnaryConstraintFn &constraint, MutableBitSpan dom
/* Remove all domain values from A that can not be paired with any value in B. */
static bool reduce_binary(const BinaryConstraintFn &constraint,
MutableBitSpan domain_a,
BitSpan domain_b)
BitSpan domain_src,
MutableBitSpan domain_dst)
{
bool changed = false;
for (const int i : domain_a.index_range()) {
if (!domain_a[i]) {
for (const int i : domain_dst.index_range()) {
if (!domain_dst[i]) {
continue;
}
bool valid = false;
for (const int j : domain_b.index_range()) {
if (domain_b[j] && constraint(i, j)) {
for (const int j : domain_src.index_range()) {
if (domain_src[j] && constraint(i, j)) {
valid = true;
}
}
if (!valid) {
domain_a[i].reset();
domain_dst[i].reset();
changed = true;
}
}
@ -825,18 +887,18 @@ struct NullLogger {
void on_start(StringRef /*message*/) {}
void on_end() {}
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(VariableRef)> /*names_fn*/) {}
void declare_constraints(const ConstraintSet &/*constraints*/) {}
void notify(StringRef /*message*/) {}
void on_solve_start() {}
void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
void on_worklist_extended(VariableRef /*src*/, VariableRef /*dst*/) {}
void on_binary_constraint_applied(VariableRef /*src*/, VariableRef /*dst*/) {}
void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
void on_domain_empty(const int /*var*/) {}
void on_domain_init(VariableRef /*var*/, const BitSpan /*domain*/) {}
void on_domain_reduced(VariableRef /*var*/, const BitSpan /*domain*/) {}
void on_domain_empty(VariableRef /*var*/) {}
void on_solve_end() {}
};
@ -848,7 +910,10 @@ struct PrintLogger {
}
void on_end() {}
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
void declare_variables(const int /*num_vars*/,
FunctionRef<std::string(VariableRef)> /*names_fn*/)
{
}
void declare_constraints(const ConstraintSet & /*constraints*/) {}
void notify(StringRef message)
@ -858,29 +923,29 @@ struct PrintLogger {
void on_solve_start() {}
void on_worklist_extended(const int src, const int dst)
void on_worklist_extended(VariableRef src, VariableRef dst)
{
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
std::cout << " Worklist extended: " << src.index() << ", " << dst.index() << std::endl;
}
void on_binary_constraint_applied(const int src, const int dst)
void on_binary_constraint_applied(VariableRef src, VariableRef dst)
{
std::cout << " Applying " << src << ", " << dst << std::endl;
std::cout << " Applying " << src.index() << ", " << dst.index() << std::endl;
}
void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
void on_domain_init(VariableRef /*var*/, const BitSpan /*domain*/)
{
std::cout << " Initialized domain" << std::endl;
}
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
void on_domain_reduced(VariableRef /*var*/, const BitSpan /*domain*/)
{
std::cout << " Reduced domain!" << std::endl;
}
void on_domain_empty(const int var)
void on_domain_empty(VariableRef var)
{
std::cout << " FAILED! No possible values for " << var << std::endl;
std::cout << " FAILED! No possible values for " << var.index() << std::endl;
}
static void on_solve_end() {}
@ -927,12 +992,12 @@ struct JSONLogger {
return d;
}
void declare_variables(const int num_vars, FunctionRef<std::string(int)> names_fn)
void declare_variables(const int num_vars, FunctionRef<std::string(VariableRef)> names_fn)
{
variable_names.reinitialize(num_vars);
this->stream << "\"variables\": [";
for (const int i : IndexRange(num_vars)) {
variable_names[i] = names_fn(i);
variable_names[i] = names_fn(VariableRef(i));
if (i > 0) {
this->stream << ", ";
}
@ -946,14 +1011,14 @@ struct JSONLogger {
bool has_constraints = false;
this->stream << ", \"constraints\": [";
for (const auto &item : constraints.binary_constraints_by_source().items()) {
const int src = item.key;
const VariableRef src = item.key;
for (const ConstraintSet::Target &target : item.value) {
const int dst = target.variable;
const VariableRef dst = target.variable;
if (has_constraints) {
this->stream << ", ";
}
const std::string src_name = variable_names[src];
const std::string dst_name = variable_names[dst];
const std::string src_name = variable_names[src.index()];
const std::string dst_name = variable_names[dst.index()];
this->stream << "{\"name\": " << stringify(src_name + "--" + dst_name)
<< ", \"source\": " << stringify(src_name)
<< ", \"target\": " << stringify(dst_name) << "}" << std::endl;
@ -970,42 +1035,42 @@ struct JSONLogger {
this->stream << ", \"events\": [";
}
void on_worklist_extended(const int src, const int dst)
void on_worklist_extended(VariableRef src, VariableRef dst)
{
on_event("worklist added", [&]() {
this->stream << "\"source\": " << stringify(variable_names[src]) << ", \"target\": \""
<< variable_names[dst] << "\"";
this->stream << "\"source\": " << stringify(variable_names[src.index()]) << ", \"target\": \""
<< variable_names[dst.index()] << "\"";
});
}
void on_binary_constraint_applied(const int src, const int dst)
void on_binary_constraint_applied(VariableRef src, VariableRef dst)
{
on_event("constraint applied", [&]() {
this->stream << "\"source\": " << stringify(variable_names[src])
<< ", \"target\": " << stringify(variable_names[dst]);
this->stream << "\"source\": " << stringify(variable_names[src.index()])
<< ", \"target\": " << stringify(variable_names[dst.index()]);
});
}
void on_domain_init(const int var, const BitSpan domain)
void on_domain_init(VariableRef var, const BitSpan domain)
{
on_event("domain init", [&]() {
this->stream << "\"variable\": " << stringify(variable_names[var])
this->stream << "\"variable\": " << stringify(variable_names[var.index()])
<< ", \"domain\": " << domain_as_int(domain);
});
}
void on_domain_reduced(const int var, const BitSpan domain)
void on_domain_reduced(VariableRef var, const BitSpan domain)
{
on_event("domain reduced", [&]() {
this->stream << "\"variable\": " << stringify(variable_names[var])
this->stream << "\"variable\": " << stringify(variable_names[var.index()])
<< ", \"domain\": " << domain_as_int(domain);
});
}
void on_domain_empty(const int var)
void on_domain_empty(VariableRef var)
{
on_event("domain empty",
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var]); });
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var.index()]); });
}
void on_solve_end()
@ -1021,7 +1086,8 @@ static void solve_unary_constraints(const ConstraintSet &constraints,
Logger & /*logger*/)
{
for (const int i : variable_domains.index_range()) {
for (const UnaryConstraintFn &constraint : constraints.get_unary_constraints(i)) {
const MutableVariableRef var(i);
for (const UnaryConstraintFn &constraint : constraints.get_unary_constraints(var)) {
reduce_unary(constraint, variable_domains[i]);
}
}
@ -1036,36 +1102,41 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
* by reducing unnecessary repetition of constraints.
* Using the topological sorting of sockets should make a decent "preconditioner".
* This is similar to what the current R-L/L-R solver does. */
Stack<int2> worklist;
struct BinaryKey {
VariableRef source;
MutableVariableRef target;
};
Stack<BinaryKey> worklist;
logger.notify("Binary Constraint Solve");
for (const int i : variable_domains.index_range()) {
for (const ConstraintSet::Target &target : constraints.get_target_constraints(i)) {
worklist.push({i, target.variable});
logger.on_worklist_extended(i, target.variable);
VariableRef source_var(i);
for (const ConstraintSet::Target &target : constraints.get_target_constraints(source_var)) {
worklist.push({source_var, target.variable});
logger.on_worklist_extended(source_var, target.variable);
}
}
while (!worklist.is_empty()) {
const int2 key = worklist.pop();
logger.on_binary_constraint_applied(key[0], key[1]);
const BinaryConstraintFn &constraint = constraints.get_binary_constraint(key[0], key[1]);
const MutableBitSpan domain_a = variable_domains[key[0]];
const BitSpan domain_b = variable_domains[key[1]];
if (reduce_binary(constraint, domain_a, domain_b)) {
logger.on_domain_reduced(key[0], domain_a);
if (!bits::any_bit_set(domain_a)) {
const BinaryKey key = worklist.pop();
logger.on_binary_constraint_applied(key.source, key.target);
const BinaryConstraintFn &constraint = constraints.get_binary_constraint(key.source, key.target);
const BitSpan domain_src = variable_domains[key.source];
const MutableBitSpan domain_dst = variable_domains[key.target];
if (reduce_binary(constraint, domain_src, domain_dst)) {
logger.on_domain_reduced(key.source, domain_src);
if (!bits::any_bit_set(domain_src)) {
/* TODO FAILURE CASE! */
logger.on_domain_empty(key[0]);
logger.on_domain_empty(key.target);
break;
}
/* Add arcs to A from all dependant variables (except B). */
for (const ConstraintSet::Source &source : constraints.get_source_constraints(key[0])) {
if (source.variable == key[1]) {
/* Add arcs from target to all dependant variables (except the source). */
for (const ConstraintSet::Target &target : constraints.get_target_constraints(key.target)) {
if (target.variable == key.source) {
continue;
}
logger.on_worklist_extended(source.variable, key[0]);
worklist.push({source.variable, key[0]});
logger.on_worklist_extended(key.target, target.variable);
worklist.push({key.source, target.variable});
}
}
}
@ -1081,7 +1152,7 @@ static BitGroupVector<> solve_constraints(const ConstraintSet &constraints,
logger.on_solve_start();
for (const int i : variable_domains.index_range()) {
logger.on_domain_init(i, variable_domains[i]);
logger.on_domain_init(VariableRef(i), variable_domains[i]);
}
solve_unary_constraints<Logger>(constraints, variable_domains, logger);
@ -1116,8 +1187,8 @@ static void add_node_type_constraints(const bNodeTree &tree,
if (!input_inputs[i]->is_available() || !output_inputs[i]->is_available()) {
continue;
}
const int var_a = input_inputs[i]->index_in_tree();
const int var_b = output_inputs[i]->index_in_tree();
const ac3::MutableVariableRef var_a(input_inputs[i]->index_in_tree());
const ac3::MutableVariableRef var_b(output_inputs[i]->index_in_tree());
constraints.add(var_a, var_b, shared_field_type_constraint);
constraints.add(var_b, var_a, shared_field_type_constraint);
}
@ -1125,8 +1196,8 @@ static void add_node_type_constraints(const bNodeTree &tree,
if (!input_outputs[i]->is_available() || !output_outputs[i]->is_available()) {
continue;
}
const int var_a = input_outputs[i]->index_in_tree();
const int var_b = output_outputs[i]->index_in_tree();
const ac3::MutableVariableRef var_a(input_outputs[i]->index_in_tree());
const ac3::MutableVariableRef var_b(output_outputs[i]->index_in_tree());
constraints.add(var_a, var_b, shared_field_type_constraint);
constraints.add(var_b, var_a, shared_field_type_constraint);
}
@ -1289,27 +1360,26 @@ static void test_ac3_field_inferencing(
const int num_vars = tree_output_vars.one_after_last();
tree.ensure_topology_cache();
auto variable_name = [&](const int var) -> std::string {
if (socket_vars.contains(var)) {
const bNodeSocket &socket = *tree.all_sockets()[var];
auto variable_name = [&](ac3::VariableRef var) -> std::string {
if (socket_vars.contains(var.index())) {
const bNodeSocket &socket = *tree.all_sockets()[var.index()];
const bNode &node = socket.owner_node();
return socket.is_output() ? std::string(node.name) + ":O:" + socket.identifier :
std::string(node.name) + ":I:" + socket.identifier;
}
if (tree_input_vars.contains(var)) {
if (tree_input_vars.contains(var.index())) {
const bNodeTreeInterfaceSocket &iosocket =
*tree.interface_inputs()[var - tree_input_vars.start()];
*tree.interface_inputs()[var.index() - tree_input_vars.start()];
return std::string("I:") + iosocket.identifier;
}
if (tree_output_vars.contains(var)) {
if (tree_output_vars.contains(var.index())) {
const bNodeTreeInterfaceSocket &iosocket =
*tree.interface_outputs()[var - tree_output_vars.start()];
*tree.interface_outputs()[var.index() - tree_output_vars.start()];
return std::string("O:") + iosocket.identifier;
}
return "";
};
logger.declare_variables(num_vars,
[&](const int var) -> std::string { return variable_name(var); });
logger.declare_variables(num_vars, variable_name);
const Span<const bNode *> nodes = tree.toposort_right_to_left();
@ -1322,37 +1392,56 @@ static void test_ac3_field_inferencing(
const IndexRange interface_outputs = node->is_group_output() ?
tree.interface_outputs().index_range() :
IndexRange();
auto get_socket_variable = [&](const bNodeSocket &socket) -> ac3::MutableVariableRef {
if (socket.is_output()) {
if (interface_inputs.contains(socket.index())) {
return ac3::MutableVariableRef(tree_input_vars[socket.index()]);
}
else {
return ac3::MutableVariableRef(socket_vars[socket.index_in_tree()]);
}
}
else {
if (interface_outputs.contains(socket.index())) {
return ac3::MutableVariableRef(tree_output_vars[socket.index()]);
}
else {
return ac3::MutableVariableRef(socket_vars[socket.index_in_tree()]);
}
}
};
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
for (const bNodeSocket *output_socket : node->output_sockets()) {
if (!output_socket->is_available()) {
continue;
}
const int var_index = interface_inputs.contains(output_socket->index()) ?
tree_input_vars[output_socket->index()] :
socket_vars[output_socket->index_in_tree()];
const ac3::MutableVariableRef var = get_socket_variable(*output_socket);
const bNodeSocketType *typeinfo = output_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
}
const OutputFieldDependency &field_dependency =
inferencing_interface.outputs[output_socket->index()];
if (field_dependency.field_type() == OutputSocketFieldType::FieldSource) {
constraints.add(var_index, [](const int value) { return value == DomainValue::Field; });
constraints.add(var,
[](const int value) { return value == DomainValue::Field; });
}
if (field_dependency.field_type() == OutputSocketFieldType::None) {
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
}
/* The output is required to be a single value when it is connected to any input that does
* not support fields. */
for (const bNodeSocket *target_socket : output_socket->directly_linked_sockets()) {
if (target_socket->is_available()) {
constraints.add(var_index, target_socket->index_in_tree(), [](int value_a, int value_b) {
constraints.add(var,
ac3::VariableRef{target_socket->index_in_tree()},
[](int value_a, int value_b) {
return value_a == DomainValue::Single || value_b == DomainValue::Field;
});
}
@ -1395,21 +1484,19 @@ static void test_ac3_field_inferencing(
if (!input_socket->is_available()) {
continue;
}
const int var_index = interface_outputs.contains(input_socket->index()) ?
tree_output_vars[input_socket->index()] :
socket_vars[input_socket->index_in_tree()];
const ac3::MutableVariableRef var = get_socket_variable(*input_socket);
const bNodeSocketType *typeinfo = input_socket->typeinfo;
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
SOCK_CUSTOM;
if (!nodes::socket_type_supports_fields(type)) {
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
}
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
if (field_type == InputSocketFieldType::None) {
const int var_index = input_socket->index_in_tree();
constraints.add(var_index, [](int value) { return value == DomainValue::Single; });
constraints.add(ac3::MutableVariableRef(input_socket->index_in_tree()),
[](int value) { return value == DomainValue::Single; });
}
}