WIP: Field type inferencing using a constraint solver method #120420
|
@ -721,51 +721,111 @@ using UnaryConstraintFn = std::function<bool(int value)>;
|
|||
/** Binary constraint function, returns true if both values are compatible. */
|
||||
using BinaryConstraintFn = std::function<bool(int value_a, int value_b)>;
|
||||
|
||||
class VariableRef {
|
||||
int index_;
|
||||
|
||||
public:
|
||||
explicit VariableRef(const int index) : index_(index) {}
|
||||
|
||||
int index() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
operator int() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
bool operator==(const VariableRef other) const
|
||||
{
|
||||
return index_ == other.index_;
|
||||
}
|
||||
|
||||
uint64_t hash() const
|
||||
{
|
||||
return get_default_hash(index_);
|
||||
}
|
||||
};
|
||||
|
||||
struct MutableVariableRef {
|
||||
int index_;
|
||||
|
||||
public:
|
||||
explicit MutableVariableRef(const int index) : index_(index) {}
|
||||
|
||||
int index() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
operator int() const
|
||||
{
|
||||
return index_;
|
||||
}
|
||||
|
||||
operator VariableRef() const
|
||||
{
|
||||
return VariableRef(index_);
|
||||
}
|
||||
|
||||
bool operator==(const MutableVariableRef other) const
|
||||
{
|
||||
return index_ == other.index_;
|
||||
}
|
||||
|
||||
uint64_t hash() const
|
||||
{
|
||||
return get_default_hash(index_);
|
||||
}
|
||||
};
|
||||
|
||||
class ConstraintSet {
|
||||
public:
|
||||
struct Target {
|
||||
int variable;
|
||||
MutableVariableRef variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
struct Source {
|
||||
int variable;
|
||||
VariableRef variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
|
||||
private:
|
||||
MultiValueMap<int, UnaryConstraintFn> unary_;
|
||||
MultiValueMap<int, Target> forward_;
|
||||
MultiValueMap<int, Source> reverse_;
|
||||
MultiValueMap<MutableVariableRef, UnaryConstraintFn> unary_;
|
||||
MultiValueMap<VariableRef, Target> binary_by_source_;
|
||||
MultiValueMap<MutableVariableRef, Source> binary_by_target_;
|
||||
|
||||
public:
|
||||
const MultiValueMap<int, UnaryConstraintFn> &all_unary_constraints() const
|
||||
const MultiValueMap<MutableVariableRef, UnaryConstraintFn> &all_unary_constraints() const
|
||||
{
|
||||
return unary_;
|
||||
}
|
||||
const MultiValueMap<int, Target> &binary_constraints_by_source() const
|
||||
const MultiValueMap<VariableRef, Target> &binary_constraints_by_source() const
|
||||
{
|
||||
return forward_;
|
||||
return binary_by_source_;
|
||||
}
|
||||
const MultiValueMap<int, Source> &binary_constraints_by_target() const
|
||||
const MultiValueMap<MutableVariableRef, Source> &binary_constraints_by_target() const
|
||||
{
|
||||
return reverse_;
|
||||
return binary_by_target_;
|
||||
}
|
||||
|
||||
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
|
||||
Span<UnaryConstraintFn> get_unary_constraints(const MutableVariableRef source) const
|
||||
{
|
||||
return unary_.lookup(source_key);
|
||||
return unary_.lookup(source);
|
||||
}
|
||||
Span<Target> get_target_constraints(const int source_key) const
|
||||
Span<Target> get_target_constraints(const VariableRef source) const
|
||||
{
|
||||
return forward_.lookup(source_key);
|
||||
return binary_by_source_.lookup(source);
|
||||
}
|
||||
Span<Source> get_source_constraints(const int target_key) const
|
||||
Span<Source> get_source_constraints(const MutableVariableRef target) const
|
||||
{
|
||||
return reverse_.lookup(target_key);
|
||||
return binary_by_target_.lookup(target);
|
||||
}
|
||||
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
|
||||
BinaryConstraintFn get_binary_constraint(const VariableRef source_key,
|
||||
const MutableVariableRef target_key) const
|
||||
{
|
||||
for (const Target &target : forward_.lookup(source_key)) {
|
||||
for (const Target &target : binary_by_source_.lookup(source_key)) {
|
||||
if (target.variable == target_key) {
|
||||
return target.constraint;
|
||||
}
|
||||
|
@ -773,14 +833,16 @@ class ConstraintSet {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void add(const int variable, UnaryConstraintFn constraint)
|
||||
void add(const MutableVariableRef variable, UnaryConstraintFn constraint)
|
||||
{
|
||||
unary_.add(variable, constraint);
|
||||
}
|
||||
void add(const int source, const int target, BinaryConstraintFn constraint)
|
||||
void add(const MutableVariableRef target,
|
||||
const VariableRef source,
|
||||
BinaryConstraintFn constraint)
|
||||
{
|
||||
forward_.add(source, {target, constraint});
|
||||
reverse_.add(target, {source, constraint});
|
||||
binary_by_source_.add(source, {target, constraint});
|
||||
binary_by_target_.add(target, {source, constraint});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -799,22 +861,22 @@ static void reduce_unary(const UnaryConstraintFn &constraint, MutableBitSpan dom
|
|||
|
||||
/* Remove all domain values from A that can not be paired with any value in B. */
|
||||
static bool reduce_binary(const BinaryConstraintFn &constraint,
|
||||
MutableBitSpan domain_a,
|
||||
BitSpan domain_b)
|
||||
BitSpan domain_src,
|
||||
MutableBitSpan domain_dst)
|
||||
{
|
||||
bool changed = false;
|
||||
for (const int i : domain_a.index_range()) {
|
||||
if (!domain_a[i]) {
|
||||
for (const int i : domain_dst.index_range()) {
|
||||
if (!domain_dst[i]) {
|
||||
continue;
|
||||
}
|
||||
bool valid = false;
|
||||
for (const int j : domain_b.index_range()) {
|
||||
if (domain_b[j] && constraint(i, j)) {
|
||||
for (const int j : domain_src.index_range()) {
|
||||
if (domain_src[j] && constraint(i, j)) {
|
||||
valid = true;
|
||||
}
|
||||
}
|
||||
if (!valid) {
|
||||
domain_a[i].reset();
|
||||
domain_dst[i].reset();
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
@ -825,18 +887,18 @@ struct NullLogger {
|
|||
void on_start(StringRef /*message*/) {}
|
||||
void on_end() {}
|
||||
|
||||
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
|
||||
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(VariableRef)> /*names_fn*/) {}
|
||||
void declare_constraints(const ConstraintSet &/*constraints*/) {}
|
||||
|
||||
void notify(StringRef /*message*/) {}
|
||||
|
||||
void on_solve_start() {}
|
||||
void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
|
||||
void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
|
||||
void on_worklist_extended(VariableRef /*src*/, VariableRef /*dst*/) {}
|
||||
void on_binary_constraint_applied(VariableRef /*src*/, VariableRef /*dst*/) {}
|
||||
|
||||
void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_empty(const int /*var*/) {}
|
||||
void on_domain_init(VariableRef /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_reduced(VariableRef /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_empty(VariableRef /*var*/) {}
|
||||
|
||||
void on_solve_end() {}
|
||||
};
|
||||
|
@ -848,7 +910,10 @@ struct PrintLogger {
|
|||
}
|
||||
void on_end() {}
|
||||
|
||||
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
|
||||
void declare_variables(const int /*num_vars*/,
|
||||
FunctionRef<std::string(VariableRef)> /*names_fn*/)
|
||||
{
|
||||
}
|
||||
void declare_constraints(const ConstraintSet & /*constraints*/) {}
|
||||
|
||||
void notify(StringRef message)
|
||||
|
@ -858,29 +923,29 @@ struct PrintLogger {
|
|||
|
||||
void on_solve_start() {}
|
||||
|
||||
void on_worklist_extended(const int src, const int dst)
|
||||
void on_worklist_extended(VariableRef src, VariableRef dst)
|
||||
{
|
||||
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
|
||||
std::cout << " Worklist extended: " << src.index() << ", " << dst.index() << std::endl;
|
||||
}
|
||||
|
||||
void on_binary_constraint_applied(const int src, const int dst)
|
||||
void on_binary_constraint_applied(VariableRef src, VariableRef dst)
|
||||
{
|
||||
std::cout << " Applying " << src << ", " << dst << std::endl;
|
||||
std::cout << " Applying " << src.index() << ", " << dst.index() << std::endl;
|
||||
}
|
||||
|
||||
void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
|
||||
void on_domain_init(VariableRef /*var*/, const BitSpan /*domain*/)
|
||||
{
|
||||
std::cout << " Initialized domain" << std::endl;
|
||||
}
|
||||
|
||||
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
|
||||
void on_domain_reduced(VariableRef /*var*/, const BitSpan /*domain*/)
|
||||
{
|
||||
std::cout << " Reduced domain!" << std::endl;
|
||||
}
|
||||
|
||||
void on_domain_empty(const int var)
|
||||
void on_domain_empty(VariableRef var)
|
||||
{
|
||||
std::cout << " FAILED! No possible values for " << var << std::endl;
|
||||
std::cout << " FAILED! No possible values for " << var.index() << std::endl;
|
||||
}
|
||||
|
||||
static void on_solve_end() {}
|
||||
|
@ -927,12 +992,12 @@ struct JSONLogger {
|
|||
return d;
|
||||
}
|
||||
|
||||
void declare_variables(const int num_vars, FunctionRef<std::string(int)> names_fn)
|
||||
void declare_variables(const int num_vars, FunctionRef<std::string(VariableRef)> names_fn)
|
||||
{
|
||||
variable_names.reinitialize(num_vars);
|
||||
this->stream << "\"variables\": [";
|
||||
for (const int i : IndexRange(num_vars)) {
|
||||
variable_names[i] = names_fn(i);
|
||||
variable_names[i] = names_fn(VariableRef(i));
|
||||
if (i > 0) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
|
@ -946,14 +1011,14 @@ struct JSONLogger {
|
|||
bool has_constraints = false;
|
||||
this->stream << ", \"constraints\": [";
|
||||
for (const auto &item : constraints.binary_constraints_by_source().items()) {
|
||||
const int src = item.key;
|
||||
const VariableRef src = item.key;
|
||||
for (const ConstraintSet::Target &target : item.value) {
|
||||
const int dst = target.variable;
|
||||
const VariableRef dst = target.variable;
|
||||
if (has_constraints) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
const std::string src_name = variable_names[src];
|
||||
const std::string dst_name = variable_names[dst];
|
||||
const std::string src_name = variable_names[src.index()];
|
||||
const std::string dst_name = variable_names[dst.index()];
|
||||
this->stream << "{\"name\": " << stringify(src_name + "--" + dst_name)
|
||||
<< ", \"source\": " << stringify(src_name)
|
||||
<< ", \"target\": " << stringify(dst_name) << "}" << std::endl;
|
||||
|
@ -970,42 +1035,42 @@ struct JSONLogger {
|
|||
this->stream << ", \"events\": [";
|
||||
}
|
||||
|
||||
void on_worklist_extended(const int src, const int dst)
|
||||
void on_worklist_extended(VariableRef src, VariableRef dst)
|
||||
{
|
||||
on_event("worklist added", [&]() {
|
||||
this->stream << "\"source\": " << stringify(variable_names[src]) << ", \"target\": \""
|
||||
<< variable_names[dst] << "\"";
|
||||
this->stream << "\"source\": " << stringify(variable_names[src.index()]) << ", \"target\": \""
|
||||
<< variable_names[dst.index()] << "\"";
|
||||
});
|
||||
}
|
||||
|
||||
void on_binary_constraint_applied(const int src, const int dst)
|
||||
void on_binary_constraint_applied(VariableRef src, VariableRef dst)
|
||||
{
|
||||
on_event("constraint applied", [&]() {
|
||||
this->stream << "\"source\": " << stringify(variable_names[src])
|
||||
<< ", \"target\": " << stringify(variable_names[dst]);
|
||||
this->stream << "\"source\": " << stringify(variable_names[src.index()])
|
||||
<< ", \"target\": " << stringify(variable_names[dst.index()]);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_init(const int var, const BitSpan domain)
|
||||
void on_domain_init(VariableRef var, const BitSpan domain)
|
||||
{
|
||||
on_event("domain init", [&]() {
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var])
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var.index()])
|
||||
<< ", \"domain\": " << domain_as_int(domain);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_reduced(const int var, const BitSpan domain)
|
||||
void on_domain_reduced(VariableRef var, const BitSpan domain)
|
||||
{
|
||||
on_event("domain reduced", [&]() {
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var])
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var.index()])
|
||||
<< ", \"domain\": " << domain_as_int(domain);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_empty(const int var)
|
||||
void on_domain_empty(VariableRef var)
|
||||
{
|
||||
on_event("domain empty",
|
||||
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var]); });
|
||||
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var.index()]); });
|
||||
}
|
||||
|
||||
void on_solve_end()
|
||||
|
@ -1021,7 +1086,8 @@ static void solve_unary_constraints(const ConstraintSet &constraints,
|
|||
Logger & /*logger*/)
|
||||
{
|
||||
for (const int i : variable_domains.index_range()) {
|
||||
for (const UnaryConstraintFn &constraint : constraints.get_unary_constraints(i)) {
|
||||
const MutableVariableRef var(i);
|
||||
for (const UnaryConstraintFn &constraint : constraints.get_unary_constraints(var)) {
|
||||
reduce_unary(constraint, variable_domains[i]);
|
||||
}
|
||||
}
|
||||
|
@ -1036,36 +1102,41 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
|
|||
* by reducing unnecessary repetition of constraints.
|
||||
* Using the topological sorting of sockets should make a decent "preconditioner".
|
||||
* This is similar to what the current R-L/L-R solver does. */
|
||||
Stack<int2> worklist;
|
||||
struct BinaryKey {
|
||||
VariableRef source;
|
||||
MutableVariableRef target;
|
||||
};
|
||||
Stack<BinaryKey> worklist;
|
||||
logger.notify("Binary Constraint Solve");
|
||||
for (const int i : variable_domains.index_range()) {
|
||||
for (const ConstraintSet::Target &target : constraints.get_target_constraints(i)) {
|
||||
worklist.push({i, target.variable});
|
||||
logger.on_worklist_extended(i, target.variable);
|
||||
VariableRef source_var(i);
|
||||
for (const ConstraintSet::Target &target : constraints.get_target_constraints(source_var)) {
|
||||
worklist.push({source_var, target.variable});
|
||||
logger.on_worklist_extended(source_var, target.variable);
|
||||
}
|
||||
}
|
||||
|
||||
while (!worklist.is_empty()) {
|
||||
const int2 key = worklist.pop();
|
||||
logger.on_binary_constraint_applied(key[0], key[1]);
|
||||
const BinaryConstraintFn &constraint = constraints.get_binary_constraint(key[0], key[1]);
|
||||
const MutableBitSpan domain_a = variable_domains[key[0]];
|
||||
const BitSpan domain_b = variable_domains[key[1]];
|
||||
if (reduce_binary(constraint, domain_a, domain_b)) {
|
||||
logger.on_domain_reduced(key[0], domain_a);
|
||||
if (!bits::any_bit_set(domain_a)) {
|
||||
const BinaryKey key = worklist.pop();
|
||||
logger.on_binary_constraint_applied(key.source, key.target);
|
||||
const BinaryConstraintFn &constraint = constraints.get_binary_constraint(key.source, key.target);
|
||||
const BitSpan domain_src = variable_domains[key.source];
|
||||
const MutableBitSpan domain_dst = variable_domains[key.target];
|
||||
if (reduce_binary(constraint, domain_src, domain_dst)) {
|
||||
logger.on_domain_reduced(key.source, domain_src);
|
||||
if (!bits::any_bit_set(domain_src)) {
|
||||
/* TODO FAILURE CASE! */
|
||||
logger.on_domain_empty(key[0]);
|
||||
logger.on_domain_empty(key.target);
|
||||
break;
|
||||
}
|
||||
|
||||
/* Add arcs to A from all dependant variables (except B). */
|
||||
for (const ConstraintSet::Source &source : constraints.get_source_constraints(key[0])) {
|
||||
if (source.variable == key[1]) {
|
||||
/* Add arcs from target to all dependant variables (except the source). */
|
||||
for (const ConstraintSet::Target &target : constraints.get_target_constraints(key.target)) {
|
||||
if (target.variable == key.source) {
|
||||
continue;
|
||||
}
|
||||
logger.on_worklist_extended(source.variable, key[0]);
|
||||
worklist.push({source.variable, key[0]});
|
||||
logger.on_worklist_extended(key.target, target.variable);
|
||||
worklist.push({key.source, target.variable});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1081,7 +1152,7 @@ static BitGroupVector<> solve_constraints(const ConstraintSet &constraints,
|
|||
|
||||
logger.on_solve_start();
|
||||
for (const int i : variable_domains.index_range()) {
|
||||
logger.on_domain_init(i, variable_domains[i]);
|
||||
logger.on_domain_init(VariableRef(i), variable_domains[i]);
|
||||
}
|
||||
|
||||
solve_unary_constraints<Logger>(constraints, variable_domains, logger);
|
||||
|
@ -1116,8 +1187,8 @@ static void add_node_type_constraints(const bNodeTree &tree,
|
|||
if (!input_inputs[i]->is_available() || !output_inputs[i]->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var_a = input_inputs[i]->index_in_tree();
|
||||
const int var_b = output_inputs[i]->index_in_tree();
|
||||
const ac3::MutableVariableRef var_a(input_inputs[i]->index_in_tree());
|
||||
const ac3::MutableVariableRef var_b(output_inputs[i]->index_in_tree());
|
||||
constraints.add(var_a, var_b, shared_field_type_constraint);
|
||||
constraints.add(var_b, var_a, shared_field_type_constraint);
|
||||
}
|
||||
|
@ -1125,8 +1196,8 @@ static void add_node_type_constraints(const bNodeTree &tree,
|
|||
if (!input_outputs[i]->is_available() || !output_outputs[i]->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var_a = input_outputs[i]->index_in_tree();
|
||||
const int var_b = output_outputs[i]->index_in_tree();
|
||||
const ac3::MutableVariableRef var_a(input_outputs[i]->index_in_tree());
|
||||
const ac3::MutableVariableRef var_b(output_outputs[i]->index_in_tree());
|
||||
constraints.add(var_a, var_b, shared_field_type_constraint);
|
||||
constraints.add(var_b, var_a, shared_field_type_constraint);
|
||||
}
|
||||
|
@ -1289,27 +1360,26 @@ static void test_ac3_field_inferencing(
|
|||
const int num_vars = tree_output_vars.one_after_last();
|
||||
|
||||
tree.ensure_topology_cache();
|
||||
auto variable_name = [&](const int var) -> std::string {
|
||||
if (socket_vars.contains(var)) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var];
|
||||
auto variable_name = [&](ac3::VariableRef var) -> std::string {
|
||||
if (socket_vars.contains(var.index())) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var.index()];
|
||||
const bNode &node = socket.owner_node();
|
||||
return socket.is_output() ? std::string(node.name) + ":O:" + socket.identifier :
|
||||
std::string(node.name) + ":I:" + socket.identifier;
|
||||
}
|
||||
if (tree_input_vars.contains(var)) {
|
||||
if (tree_input_vars.contains(var.index())) {
|
||||
const bNodeTreeInterfaceSocket &iosocket =
|
||||
*tree.interface_inputs()[var - tree_input_vars.start()];
|
||||
*tree.interface_inputs()[var.index() - tree_input_vars.start()];
|
||||
return std::string("I:") + iosocket.identifier;
|
||||
}
|
||||
if (tree_output_vars.contains(var)) {
|
||||
if (tree_output_vars.contains(var.index())) {
|
||||
const bNodeTreeInterfaceSocket &iosocket =
|
||||
*tree.interface_outputs()[var - tree_output_vars.start()];
|
||||
*tree.interface_outputs()[var.index() - tree_output_vars.start()];
|
||||
return std::string("O:") + iosocket.identifier;
|
||||
}
|
||||
return "";
|
||||
};
|
||||
logger.declare_variables(num_vars,
|
||||
[&](const int var) -> std::string { return variable_name(var); });
|
||||
logger.declare_variables(num_vars, variable_name);
|
||||
|
||||
const Span<const bNode *> nodes = tree.toposort_right_to_left();
|
||||
|
||||
|
@ -1322,37 +1392,56 @@ static void test_ac3_field_inferencing(
|
|||
const IndexRange interface_outputs = node->is_group_output() ?
|
||||
tree.interface_outputs().index_range() :
|
||||
IndexRange();
|
||||
auto get_socket_variable = [&](const bNodeSocket &socket) -> ac3::MutableVariableRef {
|
||||
if (socket.is_output()) {
|
||||
if (interface_inputs.contains(socket.index())) {
|
||||
return ac3::MutableVariableRef(tree_input_vars[socket.index()]);
|
||||
}
|
||||
else {
|
||||
return ac3::MutableVariableRef(socket_vars[socket.index_in_tree()]);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (interface_outputs.contains(socket.index())) {
|
||||
return ac3::MutableVariableRef(tree_output_vars[socket.index()]);
|
||||
}
|
||||
else {
|
||||
return ac3::MutableVariableRef(socket_vars[socket.index_in_tree()]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const FieldInferencingInterface &inferencing_interface = *interface_by_node[node->index()];
|
||||
for (const bNodeSocket *output_socket : node->output_sockets()) {
|
||||
if (!output_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var_index = interface_inputs.contains(output_socket->index()) ?
|
||||
tree_input_vars[output_socket->index()] :
|
||||
socket_vars[output_socket->index_in_tree()];
|
||||
const ac3::MutableVariableRef var = get_socket_variable(*output_socket);
|
||||
const bNodeSocketType *typeinfo = output_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
|
||||
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const OutputFieldDependency &field_dependency =
|
||||
inferencing_interface.outputs[output_socket->index()];
|
||||
if (field_dependency.field_type() == OutputSocketFieldType::FieldSource) {
|
||||
constraints.add(var_index, [](const int value) { return value == DomainValue::Field; });
|
||||
constraints.add(var,
|
||||
[](const int value) { return value == DomainValue::Field; });
|
||||
}
|
||||
if (field_dependency.field_type() == OutputSocketFieldType::None) {
|
||||
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
|
||||
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
/* The output is required to be a single value when it is connected to any input that does
|
||||
* not support fields. */
|
||||
for (const bNodeSocket *target_socket : output_socket->directly_linked_sockets()) {
|
||||
if (target_socket->is_available()) {
|
||||
constraints.add(var_index, target_socket->index_in_tree(), [](int value_a, int value_b) {
|
||||
constraints.add(var,
|
||||
ac3::VariableRef{target_socket->index_in_tree()},
|
||||
[](int value_a, int value_b) {
|
||||
return value_a == DomainValue::Single || value_b == DomainValue::Field;
|
||||
});
|
||||
}
|
||||
|
@ -1395,21 +1484,19 @@ static void test_ac3_field_inferencing(
|
|||
if (!input_socket->is_available()) {
|
||||
continue;
|
||||
}
|
||||
const int var_index = interface_outputs.contains(input_socket->index()) ?
|
||||
tree_output_vars[input_socket->index()] :
|
||||
socket_vars[input_socket->index_in_tree()];
|
||||
const ac3::MutableVariableRef var = get_socket_variable(*input_socket);
|
||||
const bNodeSocketType *typeinfo = input_socket->typeinfo;
|
||||
const eNodeSocketDatatype type = typeinfo ? eNodeSocketDatatype(typeinfo->type) :
|
||||
SOCK_CUSTOM;
|
||||
|
||||
if (!nodes::socket_type_supports_fields(type)) {
|
||||
constraints.add(var_index, [](const int value) { return value == DomainValue::Single; });
|
||||
constraints.add(var, [](const int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
|
||||
const InputSocketFieldType field_type = inferencing_interface.inputs[input_socket->index()];
|
||||
if (field_type == InputSocketFieldType::None) {
|
||||
const int var_index = input_socket->index_in_tree();
|
||||
constraints.add(var_index, [](int value) { return value == DomainValue::Single; });
|
||||
constraints.add(ac3::MutableVariableRef(input_socket->index_in_tree()),
|
||||
[](int value) { return value == DomainValue::Single; });
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue