AtomAlpaca
啥你问我为啥题解满了还要写。因为题解里看起来都是感性理解没有证明。而且我好久没写题解了。
给定一棵树,和一个点集,每次往点集里加一个点或删除一个点,并求经过点集内所有点的最短回路中最短的一条。
我们考虑走两个节点最短路是什么样子的,应该是 从 走到 上,然后再从 走到 上。
然后我们把关键点中两两的路径抽出来取并集,它就构成了原图的一个联通子图,也就是一棵树。我们显然可以把一段只有头尾是关键点或 的链上的边都缩成一个边,因为我们在这些边上没有其它走法;于是这棵树又变成了对于关键点的一个虚树,而且它包含且仅包含我们必经的点和边(或者和原来的一串边等价的一个新边),这太令人开心了。
我们考虑怎么走最近。当然是从某个节点出发 dfs 这棵树最后走回来,这样我们每条必经边恰好都经过两次,而我们要在原图构成一个回路也要求我们每条必经边至少经过两次,因为我们到达之后还要走回去。所以 dfs 也就是最短的走法了。
那我们只看关键点,大概就是,从第 个到达的关键点走到第 个,第 个走到第 个,……,第 个走回第 个。那我们直接把这些点按照原图的 排序,然后 这样。这样做一定是对的因为对它的一个子图 dfs 其实是 原图 dfs 的一个子过程。
然后我们开一个 set 无脑维护关键点,然后加入删除的时候维护一下答案就好了。复杂度 。
#include <bits/stdc++.h>
const int MAX = 1e5 + 5;
typedef long long ll;
typedef std::set<int>::iterator IT;
int n, m, t, u, v, tot, dfc; ll w, ans;
int h[MAX], dfn[MAX], dep[MAX], pos[MAX], fth[MAX][21]; ll dis[MAX];
std::bitset <MAX> vis;
std::set <int> st;
struct E { int v, x; ll w; } e[MAX << 2];
void add(int u, int v, ll w)
{
e[++tot] = {v, h[u], w}; h[u] = tot;
e[++tot] = {u, h[v], w}; h[v] = tot;
}
void dfs(int u, int fa)
{
dep[u] = dep[fa] + 1; fth[u][0] = fa; dfn[u] = ++dfc; pos[dfc] = u;
for (int i = 1; i <= 18; ++i) { fth[u][i] = fth[fth[u][i - 1]][i - 1]; }
for (int i = h[u]; i; i = e[i].x)
{
int v = e[i].v; if (v == fa) { continue; }
dis[v] = dis[u] + e[i].w; dfs(v, u);
}
}
int lca(int u, int v)
{
if (dep[u] < dep[v]) { std::swap(u, v); }
for (int i = 18; i >= 0; --i) { if (dep[fth[u][i]] >= dep[v]) { u = fth[u][i]; } }
if (u == v) { return u; }
for (int i = 18; i >= 0; --i) { if (fth[u][i] != fth[v][i]) { u = fth[u][i]; v = fth[v][i]; } }
return fth[u][0];
}
ll ds(int u, int v) { return dis[u] + dis[v] - 2ll * dis[lca(u, v)]; }
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i < n; ++i) { scanf("%d%d%lld", &u, &v, &w); add(u, v, w); }
dfs(1, 0);
for (int i = 1; i <= m; ++i)
{
scanf("%d", &t); IT it, lst, nxt;
if (!vis[t]) { it = st.insert(dfn[t]).first; } else { it = st.find(dfn[t]); }
if (it == st.begin()) { lst = --st.end(); } else { lst = std::prev(it, 1); }
if (it == --st.end()) { nxt = st.begin(); } else { nxt = std::next(it, 1); }
ans += (vis[t] ? -1ll : 1ll) * (ds(pos[*lst], t) + ds(t, pos[*nxt]) - ds(pos[*lst], pos[*nxt]));
if (vis[t]) { st.erase(it); } vis[t] = (vis[t] ? 0 : 1);
printf("%lld\n", ans);
}
return 0;
}