WIP: Field type inferencing using a constraint solver method #120420
|
@ -716,6 +716,69 @@ using UnaryConstraintFn = std::function<bool(int value)>;
|
|||
/** Binary constraint function, returns true if both values are compatible. */
|
||||
using BinaryConstraintFn = std::function<bool(int value_a, int value_b)>;
|
||||
|
||||
class ConstraintSet {
|
||||
public:
|
||||
struct Target {
|
||||
int variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
struct Source {
|
||||
int variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
|
||||
private:
|
||||
MultiValueMap<int, UnaryConstraintFn> unary_;
|
||||
MultiValueMap<int, Target> forward_;
|
||||
MultiValueMap<int, Source> reverse_;
|
||||
|
||||
public:
|
||||
const MultiValueMap<int, UnaryConstraintFn> &all_unary_constraints() const
|
||||
{
|
||||
return unary_;
|
||||
}
|
||||
const MultiValueMap<int, Target> &binary_constraints_by_source() const
|
||||
{
|
||||
return forward_;
|
||||
}
|
||||
const MultiValueMap<int, Source> &binary_constraints_by_target() const
|
||||
{
|
||||
return reverse_;
|
||||
}
|
||||
|
||||
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
|
||||
{
|
||||
return unary_.lookup(source_key);
|
||||
}
|
||||
Span<Target> get_target_constraints(const int source_key) const
|
||||
{
|
||||
return forward_.lookup(source_key);
|
||||
}
|
||||
Span<Source> get_source_constraints(const int target_key) const
|
||||
{
|
||||
return reverse_.lookup(target_key);
|
||||
}
|
||||
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
|
||||
{
|
||||
for (const Target &target : forward_.lookup(source_key)) {
|
||||
if (target.variable == target_key) {
|
||||
return target.constraint;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void add(const int variable, UnaryConstraintFn constraint)
|
||||
{
|
||||
unary_.add(variable, constraint);
|
||||
}
|
||||
void add(const int source, const int target, BinaryConstraintFn constraint)
|
||||
{
|
||||
forward_.add(source, {target, constraint});
|
||||
reverse_.add(target, {source, constraint});
|
||||
}
|
||||
};
|
||||
|
||||
/* Remove all domain values that are not allowed by the constraint. */
|
||||
static void reduce_unary(const UnaryConstraintFn &constraint, MutableBitSpan domain)
|
||||
{
|
||||
|
@ -757,23 +820,20 @@ struct NullLogger {
|
|||
void on_start(StringRef /*message*/) {}
|
||||
void on_end() {}
|
||||
|
||||
static void set_variable_names(const int /*num_vars*/,
|
||||
FunctionRef<std::string(int)> /*names_fn*/)
|
||||
{
|
||||
}
|
||||
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
|
||||
void declare_constraints(const int /*num_cons*/, FunctionRef<int2(int)> /*nodes_fn*/) {}
|
||||
|
||||
static void notify(StringRef /*message*/) {}
|
||||
void notify(StringRef /*message*/) {}
|
||||
|
||||
static void on_solve_start() {}
|
||||
static void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
|
||||
static void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
|
||||
void on_solve_start() {}
|
||||
void on_worklist_extended(const int /*var_src*/, const int /*var_dst*/) {}
|
||||
void on_binary_constraint_applied(const int /*src*/, const int /*dst*/) {}
|
||||
|
||||
static void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
static void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
static void on_domain_empty(const int /*var*/) {}
|
||||
void on_domain_init(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/) {}
|
||||
void on_domain_empty(const int /*var*/) {}
|
||||
|
||||
static void on_variable_state_changed(BitGroupVector<> & /*variable_domains*/) {}
|
||||
static void on_solve_end() {}
|
||||
void on_solve_end() {}
|
||||
};
|
||||
|
||||
struct PrintLogger {
|
||||
|
@ -783,59 +843,41 @@ struct PrintLogger {
|
|||
}
|
||||
void on_end() {}
|
||||
|
||||
static void set_variable_names(const int /*num_vars*/,
|
||||
FunctionRef<std::string(int)> /*names_fn*/)
|
||||
{
|
||||
}
|
||||
void declare_variables(const int /*num_vars*/, FunctionRef<std::string(int)> /*names_fn*/) {}
|
||||
void declare_constraints(const int /*num_cons*/, FunctionRef<int2(int)> /*nodes_fn*/) {}
|
||||
|
||||
static void notify(StringRef message)
|
||||
void notify(StringRef message)
|
||||
{
|
||||
std::cout << message << std::endl;
|
||||
}
|
||||
|
||||
static void on_solve_start() {}
|
||||
void on_solve_start() {}
|
||||
|
||||
static void on_worklist_extended(const int src, const int dst)
|
||||
void on_worklist_extended(const int src, const int dst)
|
||||
{
|
||||
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
|
||||
}
|
||||
|
||||
static void on_binary_constraint_applied(const int src, const int dst)
|
||||
void on_binary_constraint_applied(const int src, const int dst)
|
||||
{
|
||||
std::cout << " Applying " << src << ", " << dst << std::endl;
|
||||
}
|
||||
|
||||
static void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
|
||||
void on_domain_init(const int /*var*/, const BitSpan /*domain*/)
|
||||
{
|
||||
std::cout << " Initialized domain" << std::endl;
|
||||
}
|
||||
|
||||
static void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
|
||||
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
|
||||
{
|
||||
std::cout << " Reduced domain!" << std::endl;
|
||||
}
|
||||
|
||||
static void on_domain_empty(const int var)
|
||||
void on_domain_empty(const int var)
|
||||
{
|
||||
std::cout << " FAILED! No possible values for " << var << std::endl;
|
||||
}
|
||||
|
||||
static void on_variable_state_changed(BitGroupVector<> &variable_domains)
|
||||
{
|
||||
for (const int i : variable_domains.index_range()) {
|
||||
std::string s = std::to_string(i) + ": ";
|
||||
for (const int d : IndexRange(variable_domains.group_size())) {
|
||||
if (variable_domains[i][d]) {
|
||||
s += std::to_string(d) + ", ";
|
||||
}
|
||||
else {
|
||||
s += " ";
|
||||
}
|
||||
}
|
||||
std::cout << s << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
static void on_solve_end() {}
|
||||
};
|
||||
|
||||
|
@ -853,7 +895,34 @@ struct JSONLogger {
|
|||
this->stream << "}" << std::endl;
|
||||
}
|
||||
|
||||
void set_variable_names(const int num_vars, FunctionRef<std::string(int)> names_fn)
|
||||
template<typename Fn> void on_event(StringRef type, Fn fn)
|
||||
{
|
||||
if (has_events) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
this->stream << "{\"type\": \"" << type << "\", ";
|
||||
fn();
|
||||
this->stream << "}" << std::endl;
|
||||
has_events = true;
|
||||
}
|
||||
|
||||
std::string stringify(StringRef s)
|
||||
{
|
||||
return "\"" + s + "\"";
|
||||
}
|
||||
|
||||
int domain_as_int(const BitSpan domain)
|
||||
{
|
||||
int d = 0;
|
||||
for (const int k : domain.index_range()) {
|
||||
if (domain[k]) {
|
||||
d |= (1 << k);
|
||||
}
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
void declare_variables(const int num_vars, FunctionRef<std::string(int)> names_fn)
|
||||
{
|
||||
variable_names.reinitialize(num_vars);
|
||||
this->stream << "\"variables\": [";
|
||||
|
@ -862,7 +931,29 @@ struct JSONLogger {
|
|||
if (i > 0) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
this->stream << "{\"name\": \"" << variable_names[i] << "\"}" << std::endl;
|
||||
this->stream << "{\"name\": " << stringify(variable_names[i]) << "}" << std::endl;
|
||||
}
|
||||
this->stream << "]" << std::endl;
|
||||
}
|
||||
|
||||
void declare_constraints(const ConstraintSet &constraints)
|
||||
{
|
||||
bool has_constraints = false;
|
||||
this->stream << ", \"constraints\": [";
|
||||
for (const auto &item : constraints.binary_constraints_by_source().items()) {
|
||||
const int src = item.key;
|
||||
for (const ConstraintSet::Target &target : item.value) {
|
||||
const int dst = target.variable;
|
||||
if (has_constraints) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
const std::string src_name = variable_names[src];
|
||||
const std::string dst_name = variable_names[dst];
|
||||
this->stream << "{\"name\": " << stringify(src_name + "--" + dst_name)
|
||||
<< ", \"source\": " << stringify(src_name)
|
||||
<< ", \"target\": " << stringify(dst_name) << "}" << std::endl;
|
||||
has_constraints = true;
|
||||
}
|
||||
}
|
||||
this->stream << "]" << std::endl;
|
||||
}
|
||||
|
@ -876,54 +967,40 @@ struct JSONLogger {
|
|||
|
||||
void on_worklist_extended(const int src, const int dst)
|
||||
{
|
||||
std::cout << " Worklist extended: " << src << ", " << dst << std::endl;
|
||||
on_event("worklist added", [&]() {
|
||||
this->stream << "\"source\": " << stringify(variable_names[src]) << ", \"target\": \""
|
||||
<< variable_names[dst] << "\"";
|
||||
});
|
||||
}
|
||||
|
||||
void on_binary_constraint_applied(const int src, const int dst)
|
||||
{
|
||||
std::cout << " Applying " << src << ", " << dst << std::endl;
|
||||
on_event("constraint applied", [&]() {
|
||||
this->stream << "\"source\": " << stringify(variable_names[src])
|
||||
<< ", \"target\": " << stringify(variable_names[dst]);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_init(const int var, const BitSpan domain)
|
||||
{
|
||||
if (has_events) {
|
||||
this->stream << ", ";
|
||||
}
|
||||
int d = 0;
|
||||
for (const int k : domain.index_range()) {
|
||||
if (domain[k]) {
|
||||
d |= (1 << k);
|
||||
}
|
||||
}
|
||||
this->stream << "{\"type\": \"domain init\", \"variable\": \"" << variable_names[var]
|
||||
<< "\", \"domain\": " << d << "}" << std::endl;
|
||||
has_events = true;
|
||||
on_event("domain init", [&]() {
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var])
|
||||
<< ", \"domain\": " << domain_as_int(domain);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_reduced(const int /*var*/, const BitSpan /*domain*/)
|
||||
void on_domain_reduced(const int var, const BitSpan domain)
|
||||
{
|
||||
std::cout << " Reduced domain!" << std::endl;
|
||||
on_event("domain reduced", [&]() {
|
||||
this->stream << "\"variable\": " << stringify(variable_names[var])
|
||||
<< ", \"domain\": " << domain_as_int(domain);
|
||||
});
|
||||
}
|
||||
|
||||
void on_domain_empty(const int var)
|
||||
{
|
||||
std::cout << " FAILED! No possible values for " << var << std::endl;
|
||||
}
|
||||
|
||||
void on_variable_state_changed(BitGroupVector<> &variable_domains)
|
||||
{
|
||||
for (const int i : variable_domains.index_range()) {
|
||||
std::string s = std::to_string(i) + ": ";
|
||||
for (const int d : IndexRange(variable_domains.group_size())) {
|
||||
if (variable_domains[i][d]) {
|
||||
s += std::to_string(d) + ", ";
|
||||
}
|
||||
else {
|
||||
s += " ";
|
||||
}
|
||||
}
|
||||
std::cout << s << std::endl;
|
||||
}
|
||||
on_event("domain empty",
|
||||
[&]() { this->stream << "\"variable\": " << stringify(variable_names[var]); });
|
||||
}
|
||||
|
||||
void on_solve_end()
|
||||
|
@ -932,56 +1009,6 @@ struct JSONLogger {
|
|||
}
|
||||
};
|
||||
|
||||
class ConstraintSet {
|
||||
public:
|
||||
struct Target {
|
||||
int variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
struct Source {
|
||||
int variable;
|
||||
BinaryConstraintFn constraint;
|
||||
};
|
||||
|
||||
private:
|
||||
MultiValueMap<int, UnaryConstraintFn> unary_;
|
||||
MultiValueMap<int, Target> forward_;
|
||||
MultiValueMap<int, Source> reverse_;
|
||||
|
||||
public:
|
||||
Span<UnaryConstraintFn> get_unary_constraints(const int source_key) const
|
||||
{
|
||||
return unary_.lookup(source_key);
|
||||
}
|
||||
Span<Target> get_target_constraints(const int source_key) const
|
||||
{
|
||||
return forward_.lookup(source_key);
|
||||
}
|
||||
Span<Source> get_source_constraints(const int target_key) const
|
||||
{
|
||||
return reverse_.lookup(target_key);
|
||||
}
|
||||
BinaryConstraintFn get_binary_constraint(const int source_key, const int target_key) const
|
||||
{
|
||||
for (const Target &target : forward_.lookup(source_key)) {
|
||||
if (target.variable == target_key) {
|
||||
return target.constraint;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void add(const int variable, UnaryConstraintFn constraint)
|
||||
{
|
||||
unary_.add(variable, constraint);
|
||||
}
|
||||
void add(const int source, const int target, BinaryConstraintFn constraint)
|
||||
{
|
||||
forward_.add(source, {target, constraint});
|
||||
reverse_.add(target, {source, constraint});
|
||||
}
|
||||
};
|
||||
|
||||
/* Apply all unitary constraints. */
|
||||
template<typename Logger = NullLogger>
|
||||
static void solve_unary_constraints(const ConstraintSet &constraints,
|
||||
|
@ -1012,7 +1039,6 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
|
|||
logger.on_worklist_extended(i, target.variable);
|
||||
}
|
||||
}
|
||||
logger.on_variable_state_changed(variable_domains);
|
||||
|
||||
while (!worklist.is_empty()) {
|
||||
const int2 key = worklist.pop();
|
||||
|
@ -1027,7 +1053,6 @@ static void solve_binary_constraints(const ConstraintSet &constraints,
|
|||
logger.on_domain_empty(key[0]);
|
||||
break;
|
||||
}
|
||||
logger.on_variable_state_changed(variable_domains);
|
||||
|
||||
/* Add arcs to A from all dependant variables (except B). */
|
||||
for (const ConstraintSet::Source &source : constraints.get_source_constraints(key[0])) {
|
||||
|
@ -1058,8 +1083,6 @@ static BitGroupVector<> solve_constraints(const ConstraintSet &constraints,
|
|||
solve_binary_constraints<Logger>(constraints, variable_domains, logger);
|
||||
logger.on_solve_end();
|
||||
|
||||
logger.on_variable_state_changed(variable_domains);
|
||||
|
||||
return variable_domains;
|
||||
}
|
||||
|
||||
|
@ -1257,8 +1280,9 @@ static void test_ac3_field_inferencing(
|
|||
const IndexRange tree_input_vars = socket_vars.after(tree.interface_inputs().size());
|
||||
const IndexRange tree_output_vars = tree_input_vars.after(tree.interface_outputs().size());
|
||||
const int num_vars = tree_output_vars.one_after_last();
|
||||
|
||||
tree.ensure_topology_cache();
|
||||
logger.set_variable_names(num_vars, [&](const int var) -> std::string {
|
||||
auto variable_name = [&](const int var) -> std::string {
|
||||
if (socket_vars.contains(var)) {
|
||||
const bNodeSocket &socket = *tree.all_sockets()[var];
|
||||
const bNode &node = socket.owner_node();
|
||||
|
@ -1276,7 +1300,9 @@ static void test_ac3_field_inferencing(
|
|||
return std::string("O:") + iosocket.identifier;
|
||||
}
|
||||
return "";
|
||||
});
|
||||
};
|
||||
logger.declare_variables(num_vars,
|
||||
[&](const int var) -> std::string { return variable_name(var); });
|
||||
|
||||
const Span<const bNode *> nodes = tree.toposort_right_to_left();
|
||||
|
||||
|
@ -1377,6 +1403,7 @@ static void test_ac3_field_inferencing(
|
|||
/* Constraints for consistent field type across zones. */
|
||||
add_node_type_constraints(tree, *node, constraints);
|
||||
}
|
||||
logger.declare_constraints(constraints);
|
||||
|
||||
BitGroupVector<> result = ac3::solve_constraints(constraints, num_vars, NumDomainValues, logger);
|
||||
|
||||
|
|
Loading…
Reference in New Issue