Skip to content

rooted_tree_isomorphism.hpp

#include "noya/rooted_tree_isomorphism.hpp"

View on GitHub

#ifndef NOYA_ROOTED_TREE_ISOMORPHISM_HPP
#define NOYA_ROOTED_TREE_ISOMORPHISM_HPP 1

#include <algorithm>
#include <map>
#include <vector>

namespace noya {

/// @brief Assign canonical hash labels to rooted subtrees for isomorphism testing.
/// @brief Assign canonical labels to rooted subtrees for isomorphism testing.
///
/// A node's label is determined solely by the sorted multiset of its
/// children's labels, so two rooted subtrees get the same label exactly when
/// they have the same shape. `mp` and `cnt` persist across calls to solve(),
/// so labels produced by the same instance for different trees are directly
/// comparable.
struct tree_isomorphism {
  /// Sorted list of child labels -> label assigned to that subtree shape.
  std::map<std::vector<int>, int> mp;
  /// Next unused label (equals the number of distinct shapes seen so far).
  int cnt = 0;

  /// @brief Compute canonical labels for all nodes of a rooted tree.
  /// @param g    Adjacency list of a tree. The edge back to a node's parent,
  ///             if present, is skipped. g must not contain a cycle reachable
  ///             from @p root (only the immediate parent edge is filtered).
  /// @param root Node at which the tree is rooted; must be a valid index when
  ///             g is non-empty.
  /// @return Vector mapping each node to its canonical subtree label. Nodes
  ///         not reachable from @p root keep the value 0. An empty graph
  ///         yields an empty vector.
  std::vector<int> solve(const std::vector<std::vector<int>> &g,
                         const int root = 0) {
    const int N = int(g.size());
    std::vector<int> ans(N);
    if (N == 0)
      return ans;

    // Explicit-stack DFS instead of recursion so deep trees (e.g. long paths)
    // cannot overflow the call stack. Neighbours are visited in adjacency
    // order and a node is labelled after all its children, i.e. the same
    // post-order as a recursive DFS, so label numbering is unchanged.
    struct frame {
      int u;    // current node
      int p;    // its parent (-1 for the root)
      int next; // index into g[u] of the next neighbour to try
    };
    std::vector<frame> stk;
    stk.push_back({root, -1, 0});

    std::vector<int> sons; // reused buffer of child labels
    while (!stk.empty()) {
      frame &top = stk.back();
      const std::vector<int> &adj = g[top.u];
      if (top.next < int(adj.size())) {
        const int v = adj[top.next++];
        if (v != top.p) {
          const int u = top.u;
          stk.push_back({v, u, 0}); // may invalidate `top`; not used after
        }
        continue;
      }

      // Every child of top.u is labelled; its shape is the sorted child list.
      sons.clear();
      for (const int v : adj)
        if (v != top.p)
          sons.push_back(ans[v]);
      std::sort(sons.begin(), sons.end());

      // Single lookup: reuse the existing label or register a new one.
      auto it = mp.lower_bound(sons);
      if (it == mp.end() || sons < it->first)
        it = mp.emplace_hint(it, sons, cnt++);
      ans[top.u] = it->second;
      stk.pop_back();
    }
    return ans;
  }
};

} // namespace noya

#endif // NOYA_ROOTED_TREE_ISOMORPHISM_HPP
#include <algorithm>
#include <map>
#include <vector>

namespace noya {

/// @brief Assign canonical hash labels to rooted subtrees for isomorphism testing.
/// @brief Assign canonical labels to rooted subtrees for isomorphism testing.
///
/// A node's label is determined solely by the sorted multiset of its
/// children's labels, so two rooted subtrees get the same label exactly when
/// they have the same shape. `mp` and `cnt` persist across calls to solve(),
/// so labels produced by the same instance for different trees are directly
/// comparable.
struct tree_isomorphism {
  /// Sorted list of child labels -> label assigned to that subtree shape.
  std::map<std::vector<int>, int> mp;
  /// Next unused label (equals the number of distinct shapes seen so far).
  int cnt = 0;

  /// @brief Compute canonical labels for all nodes of a rooted tree.
  /// @param g    Adjacency list of a tree. The edge back to a node's parent,
  ///             if present, is skipped. g must not contain a cycle reachable
  ///             from @p root (only the immediate parent edge is filtered).
  /// @param root Node at which the tree is rooted; must be a valid index when
  ///             g is non-empty.
  /// @return Vector mapping each node to its canonical subtree label. Nodes
  ///         not reachable from @p root keep the value 0. An empty graph
  ///         yields an empty vector.
  std::vector<int> solve(const std::vector<std::vector<int>> &g,
                         const int root = 0) {
    const int N = int(g.size());
    std::vector<int> ans(N);
    if (N == 0)
      return ans;

    // Explicit-stack DFS instead of recursion so deep trees (e.g. long paths)
    // cannot overflow the call stack. Neighbours are visited in adjacency
    // order and a node is labelled after all its children, i.e. the same
    // post-order as a recursive DFS, so label numbering is unchanged.
    struct frame {
      int u;    // current node
      int p;    // its parent (-1 for the root)
      int next; // index into g[u] of the next neighbour to try
    };
    std::vector<frame> stk;
    stk.push_back({root, -1, 0});

    std::vector<int> sons; // reused buffer of child labels
    while (!stk.empty()) {
      frame &top = stk.back();
      const std::vector<int> &adj = g[top.u];
      if (top.next < int(adj.size())) {
        const int v = adj[top.next++];
        if (v != top.p) {
          const int u = top.u;
          stk.push_back({v, u, 0}); // may invalidate `top`; not used after
        }
        continue;
      }

      // Every child of top.u is labelled; its shape is the sorted child list.
      sons.clear();
      for (const int v : adj)
        if (v != top.p)
          sons.push_back(ans[v]);
      std::sort(sons.begin(), sons.end());

      // Single lookup: reuse the existing label or register a new one.
      auto it = mp.lower_bound(sons);
      if (it == mp.end() || sons < it->first)
        it = mp.emplace_hint(it, sons, cnt++);
      ans[top.u] = it->second;
      stk.pop_back();
    }
    return ans;
  }
};

} // namespace noya