Functions: implement common subnetwork elimination optimization
This was the last of the three network optimizations I developed in the functions branch. Common subnetwork elimination and constant folding together can get rid of most unnecessary nodes.
This commit is contained in:
		| @@ -18,10 +18,17 @@ | ||||
|  * \ingroup fn | ||||
|  */ | ||||
|  | ||||
| /* Used to check if two multi-functions have the exact same type. */ | ||||
| #include <typeinfo> | ||||
|  | ||||
| #include "FN_multi_function_builder.hh" | ||||
| #include "FN_multi_function_network_evaluation.hh" | ||||
| #include "FN_multi_function_network_optimization.hh" | ||||
|  | ||||
| #include "BLI_disjoint_set.hh" | ||||
| #include "BLI_ghash.h" | ||||
| #include "BLI_map.hh" | ||||
| #include "BLI_rand.h" | ||||
| #include "BLI_stack.hh" | ||||
|  | ||||
| namespace blender::fn::mf_network_optimization { | ||||
| @@ -292,4 +299,179 @@ void constant_folding(MFNetwork &network, ResourceCollector &resources) | ||||
|  | ||||
| /** \} */ | ||||
|  | ||||
| /* -------------------------------------------------------------------- */ | ||||
| /** \name Common Subnetwork Elimination | ||||
|  * | ||||
|  * \{ */ | ||||
|  | ||||
| static uint32_t compute_node_hash(MFFunctionNode &node, RNG *rng, Span<uint32_t> node_hashes) | ||||
| { | ||||
|   uint32_t combined_inputs_hash = 394659347u; | ||||
|   for (MFInputSocket *input_socket : node.inputs()) { | ||||
|     MFOutputSocket *origin_socket = input_socket->origin(); | ||||
|     uint32_t input_hash; | ||||
|     if (origin_socket == nullptr) { | ||||
|       input_hash = BLI_rng_get_uint(rng); | ||||
|     } | ||||
|     else { | ||||
|       input_hash = BLI_ghashutil_combine_hash(node_hashes[origin_socket->node().id()], | ||||
|                                               origin_socket->index()); | ||||
|     } | ||||
|     combined_inputs_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, input_hash); | ||||
|   } | ||||
|  | ||||
|   uint32_t function_hash = node.function().hash(); | ||||
|   uint32_t node_hash = BLI_ghashutil_combine_hash(combined_inputs_hash, function_hash); | ||||
|   return node_hash; | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * Produces a hash for every node. Two nodes with the same hash should have a high probability of | ||||
|  * outputting the same values. | ||||
|  */ | ||||
| static Array<uint32_t> compute_node_hashes(MFNetwork &network) | ||||
| { | ||||
|   RNG *rng = BLI_rng_new(0); | ||||
|   Array<uint32_t> node_hashes(network.node_id_amount()); | ||||
|   Array<bool> node_is_hashed(network.node_id_amount(), false); | ||||
|  | ||||
|   /* No dummy nodes are not assumed to output the same values. */ | ||||
|   for (MFDummyNode *node : network.dummy_nodes()) { | ||||
|     uint32_t node_hash = BLI_rng_get_uint(rng); | ||||
|     node_hashes[node->id()] = node_hash; | ||||
|     node_is_hashed[node->id()] = true; | ||||
|   } | ||||
|  | ||||
|   Stack<MFFunctionNode *> nodes_to_check; | ||||
|   nodes_to_check.push_multiple(network.function_nodes()); | ||||
|  | ||||
|   while (!nodes_to_check.is_empty()) { | ||||
|     MFFunctionNode &node = *nodes_to_check.peek(); | ||||
|     if (node_is_hashed[node.id()]) { | ||||
|       nodes_to_check.pop(); | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
|     /* Make sure that origin nodes are hashed first. */ | ||||
|     bool all_dependencies_ready = true; | ||||
|     for (MFInputSocket *input_socket : node.inputs()) { | ||||
|       MFOutputSocket *origin_socket = input_socket->origin(); | ||||
|       if (origin_socket != nullptr) { | ||||
|         MFNode &origin_node = origin_socket->node(); | ||||
|         if (!node_is_hashed[origin_node.id()]) { | ||||
|           all_dependencies_ready = false; | ||||
|           nodes_to_check.push(&origin_node.as_function()); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     if (!all_dependencies_ready) { | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
|     uint32_t node_hash = compute_node_hash(node, rng, node_hashes); | ||||
|     node_hashes[node.id()] = node_hash; | ||||
|     node_is_hashed[node.id()] = true; | ||||
|     nodes_to_check.pop(); | ||||
|   } | ||||
|  | ||||
|   BLI_rng_free(rng); | ||||
|   return node_hashes; | ||||
| } | ||||
|  | ||||
| static Map<uint32_t, Vector<MFNode *, 1>> group_nodes_by_hash(MFNetwork &network, | ||||
|                                                               Span<uint32_t> node_hashes) | ||||
| { | ||||
|   Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash; | ||||
|   for (uint id : IndexRange(network.node_id_amount())) { | ||||
|     MFNode *node = network.node_or_null_by_id(id); | ||||
|     if (node != nullptr) { | ||||
|       uint32_t node_hash = node_hashes[id]; | ||||
|       nodes_by_hash.lookup_or_add_default(node_hash).append(node); | ||||
|     } | ||||
|   } | ||||
|   return nodes_by_hash; | ||||
| } | ||||
|  | ||||
| static bool functions_are_equal(const MultiFunction &a, const MultiFunction &b) | ||||
| { | ||||
|   if (&a == &b) { | ||||
|     return true; | ||||
|   } | ||||
|   if (typeid(a) == typeid(b)) { | ||||
|     return a.equals(b); | ||||
|   } | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| static bool nodes_output_same_values(DisjointSet &cache, const MFNode &a, const MFNode &b) | ||||
| { | ||||
|   if (cache.in_same_set(a.id(), b.id())) { | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   if (a.is_dummy() || b.is_dummy()) { | ||||
|     return false; | ||||
|   } | ||||
|   if (!functions_are_equal(a.as_function().function(), b.as_function().function())) { | ||||
|     return false; | ||||
|   } | ||||
|   for (uint i : a.inputs().index_range()) { | ||||
|     const MFOutputSocket *origin_a = a.input(i).origin(); | ||||
|     const MFOutputSocket *origin_b = b.input(i).origin(); | ||||
|     if (origin_a == nullptr || origin_b == nullptr) { | ||||
|       return false; | ||||
|     } | ||||
|     if (!nodes_output_same_values(cache, origin_a->node(), origin_b->node())) { | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   cache.join(a.id(), b.id()); | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| static void relink_duplicate_nodes(MFNetwork &network, | ||||
|                                    Map<uint32_t, Vector<MFNode *, 1>> &nodes_by_hash) | ||||
| { | ||||
|   DisjointSet same_node_cache{network.node_id_amount()}; | ||||
|  | ||||
|   for (Span<MFNode *> nodes_with_same_hash : nodes_by_hash.values()) { | ||||
|     if (nodes_with_same_hash.size() <= 1) { | ||||
|       continue; | ||||
|     } | ||||
|  | ||||
|     Vector<MFNode *, 16> nodes_to_check = nodes_with_same_hash; | ||||
|     Vector<MFNode *, 16> remaining_nodes; | ||||
|     while (nodes_to_check.size() >= 2) { | ||||
|       MFNode &deduplicated_node = *nodes_to_check[0]; | ||||
|       for (MFNode *node : nodes_to_check.as_span().drop_front(1)) { | ||||
|         /* This is true with fairly high probability, but hash collisions can happen. So we have to | ||||
|          * check if the node actually output the same values. */ | ||||
|         if (nodes_output_same_values(same_node_cache, deduplicated_node, *node)) { | ||||
|           for (uint i : deduplicated_node.outputs().index_range()) { | ||||
|             network.relink(node->output(i), deduplicated_node.output(i)); | ||||
|           } | ||||
|         } | ||||
|         else { | ||||
|           remaining_nodes.append(node); | ||||
|         } | ||||
|       } | ||||
|       nodes_to_check = std::move(remaining_nodes); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| /** | ||||
|  * Tries to detect duplicate subnetworks and eliminates them. This can help quite a lot when node | ||||
|  * groups were used to create the network. | ||||
|  */ | ||||
| void common_subnetwork_elimination(MFNetwork &network) | ||||
| { | ||||
|   Array<uint32_t> node_hashes = compute_node_hashes(network); | ||||
|   Map<uint32_t, Vector<MFNode *, 1>> nodes_by_hash = group_nodes_by_hash(network, node_hashes); | ||||
|   relink_duplicate_nodes(network, nodes_by_hash); | ||||
| } | ||||
|  | ||||
| /** \} */ | ||||
|  | ||||
| }  // namespace blender::fn::mf_network_optimization | ||||
|   | ||||
		Reference in New Issue
	
	Block a user