題目連結

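我們先看看如果只要求以 $0$ 為 root 的 tree 下的 distance 總和：對於每個 node $v$，都記錄以它為根的 subtree 的 node 數量 $\text{count}(v)$（including self），以及 subtree 內所有 node 到 $v$ 的 total path sum $\text{sum}(v)$。可以發現，$0$ 的答案就是它所有子樹的 node count 加上所有子樹的 total path sum，也就是 $\text{ans}(0) = \sum_{v \in \text{children}(0)} \left(\text{count}(v) + \text{sum}(v)\right)$。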

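對於所有其他非 0 的 node $u$，我們如果把它當 root 看，誒，不就是從 0 開始 dfs 到他時，把他的 parent 那邊當做 subtree 看了嗎！設 $p$ 是 $u$ 的 parent、$\text{count}(u)$ 是以 $0$ 為 root 時 $u$ 的 subtree node 數量（including self）、$n$ 是總 node 數：從 $p$ 換到 $u$ 當 root 時，$u$ 的 subtree 裡的 $\text{count}(u)$ 個 node 距離各少 1，其餘 $n - \text{count}(u)$ 個 node 距離各多 1，所以 $\text{ans}(u) = \text{ans}(p) + n - 2 \cdot \text{count}(u)$。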

AC Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#ifdef LOCAL
#include <bits/stdc++.h>
using namespace std;

// tree node stuff here...
#endif

// One-time stream setup, executed while this translation unit's static
// objects are initialized: the lambda is invoked immediately and its return
// value (always 0) is stored only to give the call a place to run.
//
// - sync_with_stdio(false) decouples the C++ standard streams from C stdio.
// - std::cin.tie(nullptr) stops std::cin from flushing std::cout before reads.
//
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation; the name is kept as-is here so that no other reference
// to it breaks, but it should be renamed when convenient.
static int __initialSetup = []()
{
    std::ios::sync_with_stdio(false);
    // Qualified with std:: so this compiles even when `using namespace std;`
    // (only declared under LOCAL in this file) is not in effect.
    std::cin.tie(nullptr);
    return 0;
}
();

// handle special cases first
// [], "", ...
// range of input?

/**
 * Sum of distances in a tree (re-rooting technique).
 *
 * For every node of an undirected tree with nodes 0..N-1, computes the sum of
 * the distances (in edges) from that node to every other node.
 *
 * The work is done without recursion so that the call stack does not grow
 * with the height of the tree.
 */
class Solution
{
public:
    /**
     * @param N      Number of nodes; nodes are labelled 0..N-1.
     * @param edges  Undirected edges, each read as {edge[0], edge[1]}. Taken
     *               by non-const reference to keep the original signature;
     *               it is not modified.
     * @return       A vector `ans` of size N where ans[i] is the sum of
     *               distances from node i to all other nodes. Empty when
     *               N <= 0.
     *
     * Results are stored as int: the largest possible value, N*(N-1)/2 (an
     * end of a path-shaped tree), fits for N <= 65536.
     *
     * The input is expected to be a valid tree (connected, acyclic, edge
     * endpoints within [0, N)); this is not validated.
     */
    std::vector<int> sumOfDistancesInTree(int N, std::vector<std::vector<int>> &edges)
    {
        // Nothing to index into for an empty tree.
        if (N <= 0)
            return {};

        // Adjacency list; iterate edges by reference to avoid copying each one.
        std::vector<std::vector<int>> g(N);
        for (const auto &edge : edges) {
            g[edge[0]].push_back(edge[1]);
            g[edge[1]].push_back(edge[0]);
        }

        // Preorder traversal from node 0 with an explicit stack. In `order`
        // every node appears after its parent and before its descendants.
        std::vector<int> parent(N, -1);
        std::vector<int> order;
        order.reserve(N);
        std::vector<int> pending{0};
        while (!pending.empty()) {
            const int u = pending.back();
            pending.pop_back();
            order.push_back(u);
            for (const int v : g[u]) {
                if (v == parent[u])
                    continue;
                parent[v] = u;
                pending.push_back(v);
            }
        }

        // Pass 1 (children before parents): with node 0 as the root,
        //   subtreeNodeCount[u] = number of nodes in u's subtree (incl. u)
        //   subtreeSum[u]       = sum of distances from u to nodes in its subtree
        std::vector<int> subtreeNodeCount(N, 0), subtreeSum(N, 0);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int u = *it;
            subtreeNodeCount[u] += 1;
            const int p = parent[u];
            if (p != -1) {
                subtreeNodeCount[p] += subtreeNodeCount[u];
                // Every node in u's subtree is one edge farther from p than from u.
                subtreeSum[p] += subtreeNodeCount[u] + subtreeSum[u];
            }
        }

        // Pass 2 (parents before children): subtreeSum[0] is already the full
        // answer for the root. Moving the root from p to its child u brings
        // the subtreeNodeCount[u] nodes below u one edge closer and pushes the
        // remaining N - subtreeNodeCount[u] nodes one edge farther away.
        for (const int u : order) {
            const int p = parent[u];
            if (p != -1)
                subtreeSum[u] = subtreeSum[p] + N - 2 * subtreeNodeCount[u];
        }

        return subtreeSum;
    }
};

// Local-build entry point: compiled only when LOCAL is defined, so the file
// can be built as a standalone program. It runs nothing and exits with
// status 0.
#ifdef LOCAL
int main()
{
    return 0;
}
#endif