AtomAlpaca
给定一棵树,所有叶子节点向下引出一条射线,根节点向上引出一条射线,所有边视作一条线段,保证所有线段都尽在节点处和其它线段相交。
多次询问,每次给定 ,求从 走到 最少经过多少条线。
询问强制在线。。
首先我们提出一个结论:最优解经过的所有边都在 到 的路径,及这条路径所挂的所有边中。
证明比较感性。首先考虑仅经过上述边集的所有路径中最优的一条。我们令 是较左的一个点。如果最优路径是从 向左走,然后从上面绕到右边,然后向左走到 。考虑随着深度增加,绕过去需要经过的线一定不会减少,所有最优策略一定是走到 再绕;否则修改最优路径,强制走一段不在边集上的路径,对于第一条不在边集中的边,如果我们跨这条边走到它的某一边,假如最优的路径也需要走到这一边,我们一定可以用边集中的某条边替代这条边,而且答案是不会变劣的;否则,我们一定要走到条边后折返回目标的这一边,那么我们可以用一条边替换这若干条边,答案是更优的。
形式化证明还不会。会了回来补。
所以我们考虑分别求 和 两条路径的答案,然后进行合并。
然而我们发现题目不允许经过点,因此我们考虑分别算每个节点的“左边”和“右边”。令 表示从 的左/右边,走到它的第 级祖先的左/右边最少经过多少边。考虑转移,我们有:
其它三种情况是类似的。
然后考虑边界条件。依然用 的情况举例,我们考虑建图的时候处理出每个节点左侧儿子数 和右侧儿子数 ,那我们要么穿过左边的所有边,要么穿过右侧的所有边,再从上面走一条边绕到左边。因此:
其它三种情况依然是类似的。最终合并答案也是类似地处理,要么穿过中间的所有边,要么从两边绕。令 分别是 走向 中经过的最后一个节点。考虑两个都最终停留在左侧,那么最后的合并要走的边数为 。其它情况依然是类似的。
综上我们得到了一个 的做法。然后我们可以长链剖分配合一些 查询的数据结构(比如猫树)优化到 。但是 做法实现得足够精细是可以过的!
感谢 Piggy424008 教我卡常。
#include <algorithm>
#include <iostream>
#include <cstdio>
const int MAX = 5e5 + 5;
const int LG = 18;
using std::min;
int n, q, s, u, v, lst, xrs, tot, dfc;
long long sum;
int deg[MAX], dep[MAX], lft[MAX], rht[MAX], h[MAX], dfn[MAX], pos[MAX], lg2[MAX];
int fth[MAX][LG + 5], st[LG + 5][MAX];
inline int rd()
{
char c=getchar();int x=0;bool f=0;
for(;!isdigit(c);c=getchar())f^=!(c^45);
for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);
if(f) { x=-x; }
return x;
}
int abs(int x) { return x > 0 ? x : -x; }
int min(int a, int b, int c, int d)
{
a < b ? b = a : 0; c < d ? d = c : 0;
return b < d ? b : d;
}
namespace GenHelper
{
unsigned z1,z2,z3,z4,b;
unsigned rand_()
{
b=((z1<<6)^z1)>>13;
z1=((z1&4294967294U)<<18)^b;
b=((z2<<2)^z2)>>27;
z2=((z2&4294967288U)<<2)^b;
b=((z3<<13)^z3)>>21;
z3=((z3&4294967280U)<<7)^b;
b=((z4<<3)^z4)>>12;
z4=((z4&4294967168U)<<13)^b;
return (z1^z2^z3^z4);
}
}
void srand(unsigned x)
{using namespace GenHelper;
z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51;}
int read()
{
using namespace GenHelper;
int a=rand_()&32767;
int b=rand_()&32767;
return a*32768+b;
}
struct E { int v, x; } e[MAX];
void add(const int u, const int v)
{
lft[v] = deg[u]; e[++tot] = {v, h[u]}; h[u] = tot; ++deg[u];
}
struct N
{
int f[2][2] = {{0, 0}, {0, 0}};
N operator + (const N & y) const
{
N res;
res.f[0][0] = min(f[0][0] + y.f[0][0], f[0][0] + 1 + y.f[1][0],
f[0][1] + y.f[1][0], f[0][1] + 1 + y.f[0][0]);
res.f[0][1] = min(f[0][0] + y.f[0][1], f[0][0] + 1 + y.f[1][1],
f[0][1] + y.f[1][1], f[0][1] + 1 + y.f[0][1]);
res.f[1][0] = min(f[1][0] + y.f[0][0], f[1][0] + 1 + y.f[1][0],
f[1][1] + y.f[1][0], f[1][1] + 1 + y.f[0][0]);
res.f[1][1] = min(f[1][0] + y.f[0][1], f[1][0] + 1 + y.f[1][1],
f[1][1] + y.f[1][1], f[1][1] + 1 + y.f[0][1]);
return res;
}
} f[MAX][LG];
inline int get(const int x, const int y) { return dfn[x] < dfn[y] ? x : y; }
void init()
{
for (int i = 1; (1 << i) <= n; ++i)
{
for (int j = 1; j + (1 << i) <= n + 1; ++j)
{
st[i][j] = get(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
}
}
}
int lca(int u, int v)
{
if (dfn[u] > dfn[v]) { std::swap(u, v); }
const int lg = lg2[dfn[v] - dfn[u]];
return get(st[lg][dfn[u] + 1], st[lg][dfn[v] - (1 << lg) + 1]);
}
void dfs(const int u)
{
dfn[u] = ++dfc; pos[dfc] = u;
st[0][dfc] = fth[u][0];
dep[u] = dep[fth[u][0]] + 1;
for (int i = 1; (1 << i) <= dep[u]; ++i) { fth[u][i] = fth[fth[u][i - 1]][i - 1]; }
for (int i = h[u]; i; i = e[i].x) { int v = e[i].v; dfs(v); }
}
N qry(int & u, const int fa)
{
N res;
int d = dep[u] - dep[fa] - 1;
if (!d) { return res; }
while (d)
{
int t = lg2[(d & -d)];
res = res + f[u][t];
u = fth[u][t];
d ^= (1 << t);
}
return res;
}
int solve(int u, int v)
{
if (u == v) { return 0; }
if (dep[u] < dep[v]) { std::swap(u, v); }
int fa = lca(u, v);
if (fa == v)
{
N ru = qry(u, v);
int ul = min(ru.f[0][0], ru.f[1][0]), ur = min(ru.f[0][1], ru.f[1][1]);
return min(ul, ur);
}
N ru = qry(u, fa), rv = qry(v, fa);
int ul = min(ru.f[0][0], ru.f[1][0]), ur = min(ru.f[0][1], ru.f[1][1]),
vl = min(rv.f[0][0], rv.f[1][0]), vr = min(rv.f[0][1], rv.f[1][1]);
return min(ul + vl + min(abs(lft[u] - lft[v]), deg[fa] + 1 - abs(lft[u] - lft[v])),
ur + vr + min(abs(lft[u] - lft[v]), deg[fa] + 1 - abs(lft[u] - lft[v])),
ul + vr + min(abs(lft[u] - (lft[v] + 1)), deg[fa] + 1 - abs(lft[u] - (lft[v] + 1))),
ur + vl + min(abs(lft[v] - (lft[u] + 1)), deg[fa] + 1 - abs(lft[v] - (lft[u] + 1))));
}
void solve0()
{
while (q--)
{
u = rd(); u ^= lst;
v = rd(); v ^= lst;
lst = solve(u, v);
printf("%d\n", lst);
}
}
void solve1()
{
srand((unsigned)s);
while (q--)
{
u = (read() ^ lst) % n + 1, v = (read() ^ lst) % n + 1;
lst = solve(u, v);
sum += lst; xrs ^= lst;
}
printf("%d %lld", xrs, sum);
}
int main()
{
n = rd(); q = rd(); s = rd();
for (int i = 2; i <= n; ++i) { fth[i][0] = rd(); add(fth[i][0], i); }
for (int i = 2; i <= n; ++i) { rht[i] = deg[fth[i][0]] - lft[i] - 1; }
for (int i = 2; i <= n; ++i) { lg2[i] = lg2[i >> 1] + 1; }
dfs(1); init();
for (int i = 2; i <= n; ++i)
{
f[i][0].f[0][0] = min(lft[i], rht[i] + 2);
f[i][0].f[0][1] = min(lft[i], rht[i]) + 1;
f[i][0].f[1][0] = min(lft[i], rht[i]) + 1;
f[i][0].f[1][1] = min(rht[i], lft[i] + 2);
}
for (int i = 2; i <= n; ++i)
{
for (int j = 1; (1 << j) <= dep[i]; ++j)
{
f[i][j] = f[i][j - 1] + f[fth[i][j - 1]][j - 1];
}
}
s == -1 ? solve0() : solve1();
return 0;
}