「题解」P3320 [SDOI2015] 寻宝游戏

AtomAlpaca

Table of contents

啥你问我为啥题解满了还要写。因为题解里看起来都是感性理解没有证明。而且我好久没写题解了。

题意

link

给定一棵树,和一个点集,每次往点集里加一个点或删除一个点,并求经过点集内所有点的最短回路中最短的一条。

题解

我们考虑走两个节点最短路是什么样子的,应该是 从 uu 走到 lcalca 上,然后再从 lcalca 走到 vv 上。

然后我们把关键点中两两的路径抽出来取并集,它就构成了原图的一个联通子图,也就是一棵树。我们显然可以把一段只有头尾是关键点或 lcalca 的链上的边都缩成一个边,因为我们在这些边上没有其它走法;于是这棵树又变成了对于关键点的一个虚树,而且它包含且仅包含我们必经的点和边(或者和原来的一串边等价的一个新边),这太令人开心了。

我们考虑怎么走最近。当然是从某个节点出发 dfs 这棵树最后走回来,这样我们每条必经边恰好都经过两次,而我们要在原图构成一个回路也要求我们每条必经边至少经过两次,因为我们到达之后还要走回去。所以 dfs 也就是最短的走法了。

那我们只看关键点,大概就是,从第 11 个到达的关键点走到第 22 个,第 22 个走到第 33 个,……,第 nn 个走回第 11 个。那我们直接把这些点按照原图的 dfndfn 排序,然后 ans=dis(un,u1)+i=1n1dis(ui,ui+1)ans = dis(u_n, u_1) + \sum_{i=1}^{n-1}{ dis(u_i, u_{i+1}) } 这样。这样做一定是对的因为对它的一个子图 dfs 其实是 原图 dfs 的一个子过程。

然后我们开一个 set 无脑维护关键点,然后加入删除的时候维护一下答案就好了。复杂度 O(nlogn)O(n\log n)

代码

#include <bits/stdc++.h>

const int MAX = 1e5 + 5;
typedef long long ll;
typedef std::set<int>::iterator IT;
int n, m, t, u, v, tot, dfc; ll w, ans;
int h[MAX], dfn[MAX], dep[MAX], pos[MAX], fth[MAX][21]; ll dis[MAX];
std::bitset <MAX> vis;
std::set <int> st;
struct E { int v, x; ll w; } e[MAX << 2];
void add(int u, int v, ll w)
{
    e[++tot] = {v, h[u], w}; h[u] = tot;
    e[++tot] = {u, h[v], w}; h[v] = tot;
}

void dfs(int u, int fa)
{
    dep[u] = dep[fa] + 1; fth[u][0] = fa; dfn[u] = ++dfc; pos[dfc] = u;
    for (int i = 1; i <= 18; ++i) { fth[u][i] = fth[fth[u][i - 1]][i - 1]; }
    for (int i = h[u]; i; i = e[i].x)
    {
        int v = e[i].v; if (v == fa) { continue; }
        dis[v] = dis[u] + e[i].w; dfs(v, u);
    }
}

int lca(int u, int v)
{
    if (dep[u] < dep[v]) { std::swap(u, v); }
    for (int i = 18; i >= 0; --i) { if (dep[fth[u][i]] >= dep[v]) { u = fth[u][i]; } }
    if (u == v) { return u; }
    for (int i = 18; i >= 0; --i) { if (fth[u][i] != fth[v][i]) { u = fth[u][i]; v = fth[v][i]; } }
    return fth[u][0];
}

ll ds(int u, int v) { return dis[u] + dis[v] - 2ll * dis[lca(u, v)]; }

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <  n; ++i) { scanf("%d%d%lld", &u, &v, &w); add(u, v, w); }
    dfs(1, 0);
    for (int i = 1; i <= m; ++i)
    {
        scanf("%d", &t); IT it, lst, nxt;
        if (!vis[t]) { it = st.insert(dfn[t]).first; } else { it = st.find(dfn[t]); }
        if (it == st.begin()) { lst = --st.end(); } else { lst = std::prev(it, 1); }
        if (it == --st.end()) { nxt = st.begin(); } else { nxt = std::next(it, 1); }
        ans += (vis[t] ? -1ll : 1ll) * (ds(pos[*lst], t) + ds(t, pos[*nxt]) - ds(pos[*lst], pos[*nxt]));
        if (vis[t]) { st.erase(it); } vis[t] = (vis[t] ? 0 : 1);
        printf("%lld\n", ans);   
    }
    return 0;
}