Functions: implement common subnetwork elimination optimization

This was the last of the three network optimizations I developed in
the functions branch. Common subnetwork elimination and constant
folding together can get rid of most unnecessary nodes.
This commit is contained in:
2020-07-08 15:10:24 +02:00
parent e3e42c00cb
commit d1f4546a59
2 changed files with 183 additions and 0 deletions

View File

@@ -18,10 +18,17 @@
* \ingroup fn
*/
/* Used to check if two multi-functions have the exact same type. */
#include <typeinfo>
#include "FN_multi_function_builder.hh"
#include "FN_multi_function_network_evaluation.hh"
#include "FN_multi_function_network_optimization.hh"
#include "BLI_disjoint_set.hh"
#include "BLI_ghash.h"
#include "BLI_map.hh"
#include "BLI_rand.h"
#include "BLI_stack.hh"
namespace blender::fn::mf_network_optimization {
@@ -292,4 +299,179 @@ void constant_folding(MFNetwork &network, ResourceCollector &resources)
/** \} */
/* -------------------------------------------------------------------- */
/** \name Common Subnetwork Elimination
*
* \{ */
static uint32_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span<uint32_t> node_hashes)
{
uint32_t combined_inputs_hash = 394659347u;
for (MFInputSocket *input_socket : node.inputs()) {
MFOutputSocket *origin_socket = input_socket->origin();
uint32_t input_hash;
if (origin_socket == nullptr) {
input_hash = BLI_rng_get_uint(rng);
}
else {
input_hash = BLI_ghashutil_combine_hash(node_hashes[origin_socket->node().id()],
origin_socket->index());
}
combined_inputs_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, input_hash);
}
uint32_t function_hash = node.function().hash();
uint32_t node_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, function_hash);
return node_hash;
}
/**
* Produces a hash for every node. Two nodes with the same hash should have a high probability of
* outputting the same values.
*/
static Array<uint32_t> compute_node_hashes(MFNetwork &network)
{
RNG *rng = BLI_rng_new(0);
Array<uint32_t> node_hashes(network.node_id_amount());
Array<bool> node_is_hashed(network.node_id_amount(), false);
/* No dummy nodes are not assumed to output the same values. */
for (MFDummyNode *node : network.dummy_nodes()) {
uint32_t node_hash = BLI_rng_get_uint(rng);
node_hashes[node->id()] = node_hash;
node_is_hashed[node->id()] = true;
}
Stack<MFFunctionNode *> nodes_to_check;
nodes_to_check.push_multiple(network.function_nodes());
while (!nodes_to_check.is_empty()) {
MFFunctionNode &node = *nodes_to_check.peek();
if (node_is_hashed[node.id()]) {
nodes_to_check.pop();
continue;
}
/* Make sure that origin nodes are hashed first. */
bool all_dependencies_ready = true;
for (MFInputSocket *input_socket : node.inputs()) {
MFOutputSocket *origin_socket = input_socket->origin();
if (origin_socket != nullptr) {
MFNode &origin_node = origin_socket->node();
if (!node_is_hashed[origin_node.id()]) {
all_dependencies_ready = false;
nodes_to_check.push(&origin_node.as_function());
}
}
}
if (!all_dependencies_ready) {
continue;
}
uint32_t node_hash = compute_node_hash(node, rng, node_hashes);
node_hashes[node.id()] = node_hash;
node_is_hashed[node.id()] = true;
nodes_to_check.pop();
}
BLI_rng_free(rng);
return node_hashes;
}
static Map<uint32_t, Vector<MFNode *, 1>> group_nodes_by_hash(MFNetwork &network,
Span<uint32_t> node_hashes)
{
Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash;
for (uint id : IndexRange(network.node_id_amount())) {
MFNode *node = network.node_or_null_by_id(id);
if (node != nullptr) {
uint32_t node_hash = node_hashes[id];
nodes_by_hash.lookup_or_add_default(node_hash).append(node);
}
}
return nodes_by_hash;
}
static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b)
{
if (&a == &b) {
return true;
}
if (typeid(a) == typeid(b)) {
return a.equals(b);
}
return false;
}
static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b)
{
if (cache.in_same_set(a.id(), b.id())) {
return true;
}
if (a.is_dummy() || b.is_dummy()) {
return false;
}
if (!functions_are_equal(a.as_function().function(), b.as_function().function())) {
return false;
}
for (uint i : a.inputs().index_range()) {
const MFOutputSocket *origin_a = a.input(i).origin();
const MFOutputSocket *origin_b = b.input(i).origin();
if (origin_a == nullptr || origin_b == nullptr) {
return false;
}
if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) {
return false;
}
}
cache.join(a.id(), b.id());
return true;
}
static void relink_duplicate_nodes(MFNetwork &network,
Map<uint32_t, Vector<MFNode *, 1>> &nodes_by_hash)
{
DisjointSet same_node_cache{network.node_id_amount()};
for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) {
if (nodes_with_same_hash.size() <= 1) {
continue;
}
Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash;
Vector<MFNode *, 16> remaining_nodes;
while (nodes_to_check.size() >= 2) {
MFNode &deduplicated_node = *nodes_to_check[0];
for (MFNode *node : nodes_to_check.as_span().drop_front(1)) {
/* This is true with fairly high probability, but hash collisions can happen. So we have to
* check if the node actually output the same values. */
if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) {
for (uint i : deduplicated_node.outputs().index_range()) {
network.relink(node->output(i), deduplicated_node.output(i));
}
}
else {
remaining_nodes.append(node);
}
}
nodes_to_check = std::move(remaining_nodes);
}
}
}
/**
* Tries to detect duplicate subnetworks and eliminates them. This can help quite a lot when node
* groups were used to create the network.
*/
void common_subnetwork_elimination(MFNetwork &network)
{
Array<uint32_t> node_hashes = compute_node_hashes(network);
Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash = group_nodes_by_hash(network, node_hashes);
relink_duplicate_nodes(network, nodes_by_hash);
}
/** \} */
} // namespace blender::fn::mf_network_optimization