「题解」P5305 [GXOI/GZOI2019] 旧词

AtomAlpaca

Table of contents

题目

给一棵树和常数 kk,每次询问给出 l,r,ul, r, uv[l,r]dep(lca(u,v))k\sum_{v \in [l, r]}{\operatorname{dep}(\operatorname{lca}(u, v))^k}

正文

很小清新的一道题目,之前给学弟学妹讲课讲过,很喜欢。休学了找点事情写一下。

首先我们发现这个 kk 次方很难处理。考虑 k=1k = 1 的情况。发现答案有可减性,考虑可以固定 l=1l = 1 差分。

然后我们可以对编号在 [1,r][1, r] 的每个节点,都把它到树根的路径点值 +1+1,然后对 uu 到根的路径求和即可。

考虑拓展。思考这里我们发现我们点值 +1+1 其实是可以看成对每个节点加 (d+1)1d1(d+1)^1 - d^1,这样一条链求和就是 d1(d1)1+(d1)1(d2)1=d1d^1 - (d - 1)^1 + (d - 1) ^1 - (d - 2) ^ 1 \cdots = d^1

现在要拓展到任意 kk,我们只需要修改指数,让每次对 uu 节点点值增加 duk(du1)kd_{u}^{k} - (d_u - 1) ^ k 即可。树剖转换为一个简单的数据结构问题:给两个序列 a,ba, b,每次给一个区间 [l,r][l, r] 使得 i[l,r],aiai+bi\forall i \in[l, r], a_i \leftarrow a_i + b_i,线段树解决即可。最终复杂度 O(nlog2n)O(n\log^2n)。可以上科技摘一只 log\log 但是没必要。

据说这个技巧也叫树上差分。

代码

#include <bits/stdc++.h>

using std::cin;
using std::cout;
using std::vector;

const int MAX = 5e5 + 5;
const int MOD = 998244353;;

typedef long long ll;

struct Q { int p, x, id; } q[MAX];
vector <int> e[MAX];
ll n, m, l, r, x, v, tot, sz, p, k;
ll dfn[MAX], fth[MAX], son[MAX], siz[MAX], st[MAX << 2], tag[MAX << 2], ans[MAX], top[MAX], dep[MAX], w[MAX], wht[MAX << 2];

ll qp(ll a, ll x)
{
    ll res = a; --x;
    while (x)
    {
        if (x & 1) { res = res * a % MOD; }
        x >>= 1; a = a * a % MOD;
    }
    return res;
}

void dfs0(int u, int fa)
{
    siz[u] = 1; fth[u] = fa; dep[u] = dep[fa] + 1; ll mx = -1;
    for (int v : e[u])
    {
        if (v == fa) { continue; }
        dfs0(v, u);
        siz[u] += siz[v];
        if (siz[v] > mx) { mx = siz[v]; son[u] = v; }
    }
}

void dfs1(int u, int fa)
{
    dfn[u] = ++tot; w[tot] = ((qp(dep[u], k) - qp(dep[u] - 1, k)) % MOD + MOD) % MOD;
    if (son[fth[u]] == u) { top[u] = top[fa]; } else { top[u] = u; }
    if (!son[u]) { return ; } dfs1(son[u], u);
    for (int v : e[u])
    {
        if (v == fa or v == son[u]) { continue; }
        dfs1(v, u);
    }
}
void pd(int x)
{
    if (!tag[x]) { return ; }
    ll k = l + ((r - l) >> 1);
    tag[x << 1] += tag[x]; tag[x << 1 | 1] += tag[x];
    st[x << 1] += wht[x << 1] * tag[x]; st[x << 1 | 1] += wht[x << 1 | 1] * tag[x];
    tag[x] = 0;
}

void pu(int x) { st[x] = (st[x << 1] + st[x << 1 | 1]) % MOD; }

void build(int l, int r, int x)
{
    if (l == r) { wht[x] = w[l]; return ; }
    ll k = l + ((r - l) >> 1);
    build(l, k, x << 1); build(k + 1, r, x << 1 | 1);
    wht[x] = wht[x << 1] + wht[x << 1 | 1];
}

void add(int l, int r, int s, int t, int c, int x)
{
    if (l >= s and r <= t) { tag[x] += c; st[x] += wht[x] * c; st[x] %= MOD; return ; }
    pd(x); ll k = l + ((r - l) >> 1);
    if (s <= k) { add(l, k, s, t, c, x << 1); }
    if (t >  k) { add(k + 1, r, s, t, c, x << 1 | 1); }
    pu(x);
}

int sum(int l, int r, int s, int t, int x)
{
    if (l >= s and r <= t) { return st[x] % MOD; }
    pd(x); ll k = l + ((r - l) >> 1), res = 0;
    if (s <= k) { res += sum(l, k, s, t, x << 1); res %= MOD; }
    if (t >  k) { res += sum(k + 1, r, s, t, x << 1 | 1); res %= MOD; }
    return res % MOD;
}

void ins(int x)
{
    while (x) { add(1, n, dfn[top[x]], dfn[x], 1, 1); x = fth[top[x]]; }
}

int que(int x)
{
    ll res = 0;
    while (x) { res += sum(1, n, dfn[top[x]], dfn[x], 1); x = fth[top[x]]; res %= MOD; }
    return res;
}

bool cmp(Q a, Q b) { return a.p < b.p; }

void add(int u, int v) { e[u].push_back(v); e[v].push_back(u); }

void solve()
{
    int I = 0;
    for (int i = 1; i <= m; ++i)
    {
        while (I < q[i].p) { ins(++I); }
        ans[q[i].id] = que(q[i].x);
    }
}

int main()
{
    cin.tie(NULL);
    cout.tie(NULL);
    std::ios::sync_with_stdio(false);
    cin >> n >> m >> k;
    for (int i = 2; i <= n; ++i) { cin >> v; add(i, v); }
    dfs0(1, 0); dfs1(1, 0);
    build(1, n, 1);
    for (int i = 1; i <= m; ++i)
    {
        cin >> p >> x; q[i].p = p; q[i].x = x; q[i].id = i;
    }
    std::sort(q + 1, q + m + 1, cmp);
    solve();
    for (int i = 1; i <= m; ++i) { cout << ((ans[i] % MOD) + MOD) % MOD << '\n'; } 
}

另解

如果你是一个不那么小清新的人,我们还可以暴力分块解决这个问题。

首先 O(1)O(1)lca\operatorname{lca} 的方法是众所周知的。然后对序列分块,设 fi,u=vblockidep(lca(u,v))kf_{i, u} = \sum_{v \in block_i}{\operatorname{dep}(\operatorname{lca}(u, v))^k},处理方式是和上面相同的链加,不过同时处理多个节点我们可以一次 dfs 把块内点的祖先全标记上,再进行一次 dfs 在过程中累加, O(n)O(n) 求出一个块对所有点的贡献,然后对每个块间做前缀和,查询时散块逐个块暴力查,整块前缀和查询就好了。

复杂度是 O(n2B+qB)O(\frac{n^2}{B} + qB),取块长 B=nqB = \frac{n}{\sqrt{q}} 能做到 O(nq)O(n\sqrt{q})。跑得很慢。