Sum in the tree

給你一棵樹,每個節點上都有一個數值$a_i$,但是input卻只給你奇數層節點到根節點的prefix sum,要你求出最小的$\sum\limits_{i=1}^n a_i$

題解

對於每一個偶數層的節點$u$,如果他的parent數值是$a_p$,prefix sum是$s_p$,自己的數值是$a_u$,prefix sum是$s_u$,其中一個小孩是$a_v$,prefix sum是$s_v$,那麼我們可以發現如果把$a_u$弄成最小的$s_v - s_p$,每個小孩的$a_v$直都可以是最小!

可以拿這tree想想看,給定節點的$a_i$如下

1
2
3
4
5
     1
     |
     3
   / | \
  1  2  5

這樣的$\sum\limits_{i=1}^n a_i = 12$

prefix sum算出來會是

1
2
3
4
5
     1
     |
     4(令他為a)
   / | \
  5  6  9

假設我們塗掉$a$的數值,動手試試看把各種$s_v - s_p$填入試試看吧!

1
2
3
4
5
     1
     |
     ?
   / | \
  5  6  9

如果填入最小的$s_v - s_p$,我們可以把樹的每個$a_i$還原成

1
2
3
4
5
     1
     |
     4
   / | \
  0  1  4

這樣的$\sum\limits_{i=1}^n a_i = 10$,比起12好,對吧!因為你如果放$<4$的數值,所有的小孩都要加上一個共同的數值$x$,才能彌補回不足的prefix sum數值,不划算啊!

AC Code

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int, int> ii;

const int N = 100010;
vector<int> g[N];
ll sum[N];
ll val[N];
int par[N];

void dfs(int u, int depth)
{
    if (depth % 2 == 0)
        val[u] = INT_MAX;
    for (auto v : g[u]) {
        if (depth % 2 == 0) {
            val[u] = min(val[u], sum[v] - sum[par[u]]);
        }
        dfs(v, depth + 1);
    }
    if (val[u] == INT_MAX) // even level leaf
        val[u] = 0;
}

bool error;
ll ans;
void dfs2(int u, ll tot)
{
    if (val[u] == -1) {
        val[u] = sum[u] - tot;
    }

    if (val[u] < 0)
        error = true;
    ans += val[u];

    for (auto v : g[u]) {
        dfs2(v, tot + val[u]);
    }
}

int main()
{
    int n;
    scanf("%d", &n);

    for (int i = 1; i < n; i++) {
        int p;
        scanf("%d", &p);
        p--;

        par[i] = p;
        g[p].push_back(i);
    }

    for (int i = 0; i < n; i++) {
        scanf("%lld", &sum[i]);
        val[i] = -1;
    }

    val[0] = sum[0];
    // get all ai for all even level nodes
    dfs(0, 1);
    error = false;
    ans = 0;
    dfs2(0, 0);

    if (error)
        printf("-1\n");
    else
        printf("%lld\n", ans);

    return 0;
}