Rerooting tutorial
Basic example
Sometimes a problem involves computing something about the subtree of a node based on some information you have about each child’s subtree. A simple example of this is the problem Tree Distances II from CSES.
In this problem, you are asked to compute for each node, the sum of tree distances from that node to every other node. As you might have guessed, since this is about rerooting, we will turn this into a subtree problem, and then find a way to reroot the tree.
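Let’s root the tree at an arbitrary vertex $r$. For every vertex $x$, let $\mathrm{size}[x]$ be the number of vertices in the subtree of $x$, and let $\mathrm{dp}[x]$ be the sum of the distances from $x$ to every vertex in its subtree.
Then, after some calculations, we find that
$$\mathrm{dp}[x] = \sum_{y \in \mathrm{children}(x)} \bigl(\mathrm{dp}[y] + \mathrm{size}[y]\bigr),$$
where the $+\,\mathrm{size}[y]$ term appears because every vertex in the subtree of $y$ is exactly one edge farther from $x$ than from $y$. Both arrays can be filled in a single DFS from $r$.
Now, clearly $\mathrm{ans}[r] = \mathrm{dp}[r]$. For a child $y$ of a vertex $x$, the quantity $\mathrm{ans}[x] - \mathrm{dp}[y] - \mathrm{size}[y]$ is the sum of the distances from $x$ to every vertex outside the subtree of $y$. Each of those $n - \mathrm{size}[y]$ vertices is one edge farther from $y$ than from $x$, so
$$\mathrm{ans}[y] = \mathrm{dp}[y] + \bigl(\mathrm{ans}[x] - \mathrm{dp}[y] - \mathrm{size}[y]\bigr) + \bigl(n - \mathrm{size}[y]\bigr).$$
The DFS below carries that middle term down the tree as `sum` (start it at the root with `sum = 0`):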
auto dfs = [&](auto &&self, int x, int par, i64 sum) -> void {
ans[x] = dp[x] + sum + (n - size[x]);
for (int y: adj[x]) if (y != par) {
self(self, y, x, ans[x] - dp[y] - size[y]);
}
};
Another example
Let’s solve this problem.
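Short statement: You are given a tree with $n$ vertices, rooted at vertex $1$ (the input lists the parent of each of the vertices $2, \dots, n$), and $q$ queries. Each query consists of two vertices $a$ and $b$. It asks for the number of connected sets of vertices that contain at least one vertex of the path between $a$ and $b$, modulo $\mathrm{MOD}$.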
There are probably many ways to solve this problem, so I will show you how I did it.
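Well, to make sure every connected set is counted exactly once, we count it at its topmost vertex (the one closest to the root). Let
$$\mathrm{dp}[v] = \prod_{c \in \mathrm{children}(v)} \bigl(\mathrm{dp}[c] + 1\bigr)$$
be the number of connected sets whose topmost vertex is $v$: for each child $c$, we either take nothing from its subtree or a connected set whose topmost vertex is $c$.
Let’s consider the LCA of the two query vertices $a$ and $b$, and call it $l$.
Then it’s not hard to see that the answer is
$$\sum_{v \in P(a,l)\setminus\{l\}} \mathrm{dp}[v] \;+\; \sum_{v \in P(b,l)\setminus\{l\}} \mathrm{dp}[v] \;+\; \mathrm{cnt}[l],$$
where $P(u,l)$ is the set of vertices on the path from $u$ up to $l$, and $\mathrm{cnt}[l]$ is the number of connected sets (in the whole tree) that contain $l$. The reasoning is as follows:
- A connected set that meets the path but does not contain $l$ lies entirely below $l$.
- Its topmost vertex is therefore a path vertex other than $l$, so it is counted by exactly one of the two sums.
- Every other connected set that meets the path contains $l$, so it is counted by $\mathrm{cnt}[l]$.

For each vertex $v$, let $S[v] = \sum_{w \in P(v,\,\mathrm{root})} \mathrm{dp}[w]$. Then the two sums are just $S[a] - S[l]$ and $S[b] - S[l]$, so the answer is $S[a] + S[b] - 2S[l] + \mathrm{cnt}[l]$.
To compute $\mathrm{cnt}[v]$, notice that this is the same as calculating $\mathrm{dp}[v]\,(\mathrm{up}[v] + 1)$. Here $\mathrm{up}[v]$ is the number of connected sets that contain the parent of $v$ but no vertex of the subtree of $v$, with $\mathrm{up}[\mathrm{root}] = 0$. If $y$ is a child of $x$, then
$$\mathrm{up}[y] = (\mathrm{up}[x] + 1)\prod_{s} \bigl(\mathrm{dp}[s] + 1\bigr),$$
where the product runs over the siblings $s$ of $y$. The sibling products can be computed with prefix and suffix products. A single DFS from the root computes both $S$ and $\mathrm{up}$; in the code below, they are `sum` and `acc`.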
Here is the core of my implementation for this problem.
void solve() {
int n, q;
cin >> n >> q;
Tree t(n);
for (int i = 1; i < n; i++) {
int p;
cin >> p;
--p;
t.addEdge(p, i);
}
t.build();
TreeMove tmv(t);
VI ans(q);
VVI qs(n), qp(n);
for (int i = 0; i < q; i++) {
int a, b;
cin >> a >> b;
a--,b--;
qs[tmv.lca(a, b)].push_back(i);
qp[a].push_back(i);
qp[b].push_back(i);
}
vector<i64> dp(n);
{
auto dfs = [&](auto &&self, int x) -> void {
dp[x] = 1;
for (int y: t.adj[x]) if (y != t.par[x]) {
self(self, y);
dp[x] = dp[x] * (dp[y] + 1) % MOD;
}
};
dfs(dfs, 0);
}
{
auto dfs = [&](auto &&self, int x, i64 acc, i64 sum) -> void {
sum = (sum + dp[x]) % MOD;
for (int i: qs[x]) {
ans[i] -= 2 * sum;
ans[i] %= MOD;
if (ans[i] < 0) ans[i] += MOD;
ans[i] += dp[x] * (acc + 1) % MOD;
ans[i] %= MOD;
}
for (int i: qp[x]) {
ans[i] = (ans[i] + sum) % MOD;
}
auto adj = t.adj[x];
if (t.par[x] != x) adj.erase(find(adj.begin(), adj.end(), t.par[x]));
int d = adj.size();
vector<i64> mul(d, 1);
for (int i = 0, a = 1; i < d; i++) {
mul[i] = mul[i] * a % MOD;
a = a * (dp[adj[i]] + 1) % MOD;
}
for (int i = d - 1, a = 1; i >= 0; i--) {
mul[i] = mul[i] * a % MOD;
a = a * (dp[adj[i]] + 1) % MOD;
}
for (int i = 0; i < d; i++) {
int y = adj[i];
self(self, y, mul[i] * (acc + 1) % MOD, sum);
}
};
dfs(dfs, 0, 0, 0);
}
for (int i = 0; i < q; i++) {
cout << ans[i] << endl;
}
}