WIP: Field type inferencing using a constraint solver method #120420

Draft
Lukas Tönne wants to merge 52 commits from LukasTonne/blender:socket-type-inference into main

When changing the target branch, be careful to rebase the branch in your fork to match. See documentation.
1 changed files with 158 additions and 131 deletions
Showing only changes of commit aef69e3f98 - Show all commits

View File

@ -716,6 +716,69 @@ using UnaryConstraintFn = std::function<bool(int value)>;
/** Binary constraint function, returns true if both values are compatible. */
using BinaryConstraintFn = std::function<bool(int value_a, int value_b)>;
class ConstraintSet {
public:
struct Target {
int variable;
BinaryConstraintFn constraint;
};
struct Source {
int variable;
BinaryConstraintFn constraint;
};
private:
MultiValueMap<int, UnaryConstraintFn> unary_;
MultiValueMap<int, Target> forward_;
MultiValueMap<int, Source> reverse_;
public:
const MultiValueMap<int, UnaryConstraintFn> &all_unary_constraints() const
{
return unary_;
}
const MultiValueMap<int, Target> &binary_constraints_by_source() const
{
return forward_;
}
const MultiValueMap<int, Source> &binary_constraints_by_target() const
{
return reverse_;
}
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
{
return unary_.lookup(source_key);
}
Span<Target> get_target_constraints(const int source_key) const
{
return forward_.lookup(source_key);
}
Span<Source> get_source_constraints(const int target_key) const
{
return reverse_.lookup(target_key);
}
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
{
for (const Target &target : forward_.lookup(source_key)) {
if (target.variable == target_key) {
return target.constraint;
}
}
return nullptr;
}
void add(const int variable, UnaryConstraintFn constraint)
{
unary_.add(variable, constraint);
}
void add(const int source, const int target, BinaryConstraintFn constraint)
{
forward_.add(source, {target, constraint});
reverse_.add(target, {source, constraint});
}
};
/* Remove all domain values that are not allowed by the constraint. */
static void reduce_unary(const UnaryConstraintFn &constraint, MutableBitSpan domain)
{
@ -757,23 +820,20 @@ struct NullLogger {
void on_start(StringRef /*message*/) {}
void on_end() {}
static void set_variable_names(const int /*num_vars*/,
FunctionRef<std::string(int)> /*names_fn*/)
{
}
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
void declare_constraints(const int /*num_cons*/, FunctionRef<int2(int)> /*nodes_fn*/) {}
static void notify(StringRef /*message*/) {}
void notify(StringRef /*message*/) {}
static void on_solve_start() {}
static void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
static void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
void on_solve_start() {}
void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
static void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
static void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
static void on_domain_empty(const int /*var*/) {}
void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
void on_domain_empty(const int /*var*/) {}
static void on_variable_state_changed(BitGroupVector<> & /*variable_domains*/) {}
static void on_solve_end() {}
void on_solve_end() {}
};
struct PrintLogger {
@ -783,59 +843,41 @@ struct PrintLogger {
}
void on_end() {}
static void set_variable_names(const int /*num_vars*/,
FunctionRef<std::string(int)> /*names_fn*/)
{
}
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
void declare_constraints(const int /*num_cons*/, FunctionRef<int2(int)> /*nodes_fn*/) {}
static void notify(StringRef message)
void notify(StringRef message)
{
std::cout << message << std::endl;
}
static void on_solve_start() {}
void on_solve_start() {}
static void on_worklist_extended(const int src, const int dst)
void on_worklist_extended(const int src, const int dst)
{
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
}
static void on_binary_constraint_applied(const int src, const int dst)
void on_binary_constraint_applied(const int src, const int dst)
{
std::cout << " Applying " << src << ", " << dst << std::endl;
}
static void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
{
std::cout << " Initialized domain" << std::endl;
}
static void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
{
std::cout << " Reduced domain!" << std::endl;
}
static void on_domain_empty(const int var)
void on_domain_empty(const int var)
{
std::cout << " FAILED! No possible values for " << var << std::endl;
}
static void on_variable_state_changed(BitGroupVector<> &variable_domains)
{
for (const int i : variable_domains.index_range()) {
std::string s = std::to_string(i) + ": ";
for (const int d : IndexRange(variable_domains.group_size())) {
if (variable_domains[i][d]) {
s += std::to_string(d) + ", ";
}
else {
s += " ";
}
}
std::cout << s << std::endl;
}
}
static void on_solve_end() {}
};
@ -853,7 +895,34 @@ struct JSONLogger {
this->stream << "}" << std::endl;
}
void set_variable_names(const int num_vars, FunctionRef<std::string(int)> names_fn)
template<typename Fn> void on_event(StringRef type, Fn fn)
{
if (has_events) {
this->stream << ", ";
}
this->stream << "{\"type\": \"" << type << "\", ";
fn();
this->stream << "}" << std::endl;
has_events = true;
}
std::string stringify(StringRef s)
{
return "\"" + s + "\"";
}
int domain_as_int(const BitSpan domain)
{
int d = 0;
for (const int k : domain.index_range()) {
if (domain[k]) {
d |= (1 << k);
}
}
return d;
}
void declare_variables(const int num_vars, FunctionRef<std::string(int)> names_fn)
{
variable_names.reinitialize(num_vars);
this->stream << "\"variables\": [";
@ -862,7 +931,29 @@ struct JSONLogger {
if (i > 0) {
this->stream << ", ";
}
this->stream << "{\"name\": \"" << variable_names[i] << "\"}" << std::endl;
this->stream << "{\"name\": " << stringify(variable_names[i]) << "}" << std::endl;
}
this->stream << "]" << std::endl;
}
void declare_constraints(const ConstraintSet &constraints)
{
bool has_constraints = false;
this->stream << ", \"constraints\": [";
for (const auto &item : constraints.binary_constraints_by_source().items()) {
const int src = item.key;
for (const ConstraintSet::Target &target : item.value) {
const int dst = target.variable;
if (has_constraints) {
this->stream << ", ";
}
const std::string src_name = variable_names[src];
const std::string dst_name = variable_names[dst];
this->stream << "{\"name\": " << stringify(src_name + "--" + dst_name)
<< ", \"source\": " << stringify(src_name)
<< ", \"target\": " << stringify(dst_name) << "}" << std::endl;
has_constraints = true;
}
}
this->stream << "]" << std::endl;
}
@ -876,54 +967,40 @@ struct JSONLogger {
void on_worklist_extended(const int src, const int dst)
{
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
on_event("worklist added", [&]() {
this->stream << "\"source\": " << stringify(variable_names[src]) << ", \"target\": \""
<< variable_names[dst] << "\"";
});
}
void on_binary_constraint_applied(const int src, const int dst)
{
std::cout << " Applying " << src << ", " << dst << std::endl;
on_event("constraint applied", [&]() {
this->stream << "\"source\": " << stringify(variable_names[src])
<< ", \"target\": " << stringify(variable_names[dst]);
});
}
void on_domain_init(const int var, const BitSpan domain)
{
if (has_events) {
this->stream << ", ";
}
int d = 0;
for (const int k : domain.index_range()) {
if (domain[k]) {
d |= (1 << k);
}
}
this->stream << "{\"type\": \"domain init\", \"variable\": \"" << variable_names[var]
<< "\", \"domain\": " << d << "}" << std::endl;
has_events = true;
on_event("domain init", [&]() {
this->stream << "\"variable\": " << stringify(variable_names[var])
<< ", \"domain\": " << domain_as_int(domain);
});
}
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
void on_domain_reduced(const int var, const BitSpan domain)
{
std::cout << " Reduced domain!" << std::endl;
on_event("domain reduced", [&]() {
this->stream << "\"variable\": " << stringify(variable_names[var])
<< ", \"domain\": " << domain_as_int(domain);
});
}
void on_domain_empty(const int var)
{
std::cout << " FAILED! No possible values for " << var << std::endl;
}
void on_variable_state_changed(BitGroupVector<> &variable_domains)
{
for (const int i : variable_domains.index_range()) {
std::string s = std::to_string(i) + ": ";
for (const int d : IndexRange(variable_domains.group_size())) {
if (variable_domains[i][d]) {
s += std::to_string(d) + ", ";
}
else {
s += " ";
}
}
std::cout << s << std::endl;
}
on_event("domain empty",
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var]); });
}
void on_solve_end()
@ -932,56 +1009,6 @@ struct JSONLogger {
}
};
class ConstraintSet {
public:
struct Target {
int variable;
BinaryConstraintFn constraint;
};
struct Source {
int variable;
BinaryConstraintFn constraint;
};
private:
MultiValueMap<int, UnaryConstraintFn> unary_;
MultiValueMap<int, Target> forward_;
MultiValueMap<int, Source> reverse_;
public:
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
{
return unary_.lookup(source_key);
}
Span<Target> get_target_constraints(const int source_key) const
{
return forward_.lookup(source_key);
}
Span<Source> get_source_constraints(const int target_key) const
{
return reverse_.lookup(target_key);
}
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
{
for (const Target &target : forward_.lookup(source_key)) {
if (target.variable == target_key) {
return target.constraint;
}
}
return nullptr;
}
void add(const int variable, UnaryConstraintFn constraint)
{
unary_.add(variable, constraint);
}
void add(const int source, const int target, BinaryConstraintFn constraint)
{
forward_.add(source, {target, constraint});
reverse_.add(target, {source, constraint});
}
};
/* Apply all unitary constraints. */
template<typename Logger = NullLogger>
static void solve_unary_constraints(const ConstraintSet &constraints,
@ -1012,7 +1039,6 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
logger.on_worklist_extended(i, target.variable);
}
}
logger.on_variable_state_changed(variable_domains);
while (!worklist.is_empty()) {
const int2 key = worklist.pop();
@ -1027,7 +1053,6 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
logger.on_domain_empty(key[0]);
break;
}
logger.on_variable_state_changed(variable_domains);
/* Add arcs to A from all dependant variables (except B). */
for (const ConstraintSet::Source &source : constraints.get_source_constraints(key[0])) {
@ -1058,8 +1083,6 @@ static BitGroupVector<> solve_constraints(const ConstraintSet &constraints,
solve_binary_constraints<Logger>(constraints, variable_domains, logger);
logger.on_solve_end();
logger.on_variable_state_changed(variable_domains);
return variable_domains;
}
@ -1257,8 +1280,9 @@ static void test_ac3_field_inferencing(
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
const int num_vars = tree_output_vars.one_after_last();
tree.ensure_topology_cache();
logger.set_variable_names(num_vars, [&](const int var) -> std::string {
auto variable_name = [&](const int var) -> std::string {
if (socket_vars.contains(var)) {
const bNodeSocket &socket = *tree.all_sockets()[var];
const bNode &node = socket.owner_node();
@ -1276,7 +1300,9 @@ static void test_ac3_field_inferencing(
return std::string("O:") + iosocket.identifier;
}
return "";
});
};
logger.declare_variables(num_vars,
[&](const int var) -> std::string { return variable_name(var); });
const Span<const bNode *> nodes = tree.toposort_right_to_left();
@ -1377,6 +1403,7 @@ static void test_ac3_field_inferencing(
/* Constraints for consistent field type across zones. */
add_node_type_constraints(tree, *node, constraints);
}
logger.declare_constraints(constraints);
BitGroupVector<> result = ac3::solve_constraints(constraints, num_vars, NumDomainValues, logger);