「题解」 P7348 「MCOI-04」重型管制巡航机

AtomAlpaca

Table of contents

题意

link

给定一棵树,所有叶子节点向下引出一条射线,根节点向上引出一条射线,所有边视作一条线段,保证所有线段都尽在节点处和其它线段相交。

多次询问,每次给定 u,vu, v,求从 uu 走到 vv 最少经过多少条线。

询问强制在线。n5×105,q5×106n \le 5\times 10^5, q \le 5 \times 10^6

题解

首先我们提出一个结论:最优解经过的所有边都在 uuvv 的路径,及这条路径所挂的所有边中。

证明比较感性。首先考虑仅经过上述边集的所有路径中最优的一条。我们令 uu 是较左的一个点。如果最优路径是从 uu 向左走,然后从上面绕到右边,然后向左走到 vv。考虑随着深度增加,绕过去需要经过的线一定不会减少,所有最优策略一定是走到 lca(u,v)\operatorname{lca}(u, v) 再绕;否则修改最优路径,强制走一段不在边集上的路径,对于第一条不在边集中的边,如果我们跨这条边走到它的某一边,假如最优的路径也需要走到这一边,我们一定可以用边集中的某条边替代这条边,而且答案是不会变劣的;否则,我们一定要走到条边后折返回目标的这一边,那么我们可以用一条边替换这若干条边,答案是更优的。

形式化证明还不会。会了回来补。

所以我们考虑分别求 ulca(u,v)u \rightarrow \operatorname{lca}(u, v)vlca(u,v)v \rightarrow \operatorname{lca}(u, v) 两条路径的答案,然后进行合并。

然而我们发现题目不允许经过点,因此我们考虑分别算每个节点的“左边”和“右边”。令 fu,k,0/1,0/1f_{u, k, 0/1, 0/1} 表示从 uu 的左/右边,走到它的第 2k2^k 级祖先的左/右边最少经过多少边。考虑转移,我们有:

f_{u,k,0,0}=min{fu,k1,0,0+ffak1,k1,0,0,fu,k1,0,0+ffak1,k1,1,0+1,fu,k1,0,1+ffak1,k1,0,0+1,fu,k1,0,1+ffak1,k1,1,0\begin{aligned} f\_\{u, k, 0, 0\} = \min\left\{ \begin{aligned} &f_{u, k - 1, 0, 0} + f_{fa^{k - 1}, k - 1, 0, 0}, \\ &f_{u, k - 1, 0, 0} + f_{fa^{k - 1}, k - 1, 1, 0} + 1, \\ &f_{u, k - 1, 0, 1} + f_{fa^{k - 1}, k - 1, 0, 0} + 1, \\ &f_{u, k - 1, 0, 1} + f_{fa^{k - 1}, k - 1, 1, 0} \end{aligned} \right. \end{aligned}

其它三种情况是类似的。

然后考虑边界条件。依然用 fu,0,0,0f_{u, 0, 0, 0} 的情况举例,我们考虑建图的时候处理出每个节点左侧儿子数 lul_u 和右侧儿子数 rur_u,那我们要么穿过左边的所有边,要么穿过右侧的所有边,再从上面走一条边绕到左边。因此:

fu,0,0,0=min{lu,ru+2\begin{aligned} f_{u, 0, 0, 0} = \min\left\{ \begin{aligned} &l_u,\\ &r_u + 2 \end{aligned} \right. \end{aligned}

其它三种情况依然是类似的。最终合并答案也是类似地处理,要么穿过中间的所有边,要么从两边绕。令 a,ba, b 分别是 u,vu, v 走向 lca(u,v)\operatorname{lca}(u, v) 中经过的最后一个节点。考虑两个都最终停留在左侧,那么最后的合并要走的边数为 min(abs(l[u]l[v]),deg[fa]+1abs(l[u]l[v]))\min( \operatorname{abs}(l[u] - l[v]), \deg[fa] + 1 - \operatorname{abs}(l[u] - l[v]))。其它情况依然是类似的。

综上我们得到了一个 O(qnlogn)O(qn\log n) 的做法。然后我们可以长链剖分配合一些 O(1)O(1) 查询的数据结构(比如猫树)优化到 O(qn)O(qn)。但是 log\log 做法实现得足够精细是可以过的!

代码

感谢 Piggy424008 教我卡常。

#include <algorithm>
#include <iostream>
#include <cstdio>

const int MAX = 5e5 + 5;
const int LG = 18;

using std::min;

int n, q, s, u, v, lst, xrs, tot, dfc;
long long sum;
int deg[MAX], dep[MAX], lft[MAX], rht[MAX], h[MAX], dfn[MAX], pos[MAX], lg2[MAX];
int fth[MAX][LG + 5], st[LG + 5][MAX];
inline int rd()
{
    char c=getchar();int x=0;bool f=0;
    for(;!isdigit(c);c=getchar())f^=!(c^45);
    for(;isdigit(c);c=getchar())x=(x<<1)+(x<<3)+(c^48);
    if(f) { x=-x; }
    return x;
}

int abs(int x) { return x > 0 ? x : -x; }
int min(int a, int b, int c, int d)
{
  a < b ? b = a : 0; c < d ? d = c : 0;
  return b < d ? b : d;
}

namespace GenHelper
{
    unsigned z1,z2,z3,z4,b;
    unsigned rand_()
    {
      b=((z1<<6)^z1)>>13;
      z1=((z1&4294967294U)<<18)^b;
      b=((z2<<2)^z2)>>27;
      z2=((z2&4294967288U)<<2)^b;
      b=((z3<<13)^z3)>>21;
      z3=((z3&4294967280U)<<7)^b;
      b=((z4<<3)^z4)>>12;
      z4=((z4&4294967168U)<<13)^b;
      return (z1^z2^z3^z4);
    }
}
void srand(unsigned x)
{using namespace GenHelper;
z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51;}
int read()
{
    using namespace GenHelper;
    int a=rand_()&32767;
    int b=rand_()&32767;
    return a*32768+b;
}

struct E { int v, x; } e[MAX];
void add(const int u, const int v)
{
  lft[v] = deg[u]; e[++tot] = {v, h[u]}; h[u] = tot; ++deg[u];
}

struct N
{
  int f[2][2] = {{0, 0}, {0, 0}};
  N operator + (const N & y) const
  {
    N res;
    res.f[0][0] = min(f[0][0] + y.f[0][0], f[0][0] + 1 + y.f[1][0], 
                      f[0][1] + y.f[1][0], f[0][1] + 1 + y.f[0][0]);
    res.f[0][1] = min(f[0][0] + y.f[0][1], f[0][0] + 1 + y.f[1][1],
                      f[0][1] + y.f[1][1], f[0][1] + 1 + y.f[0][1]);
    res.f[1][0] = min(f[1][0] + y.f[0][0], f[1][0] + 1 + y.f[1][0], 
                      f[1][1] + y.f[1][0], f[1][1] + 1 + y.f[0][0]);
    res.f[1][1] = min(f[1][0] + y.f[0][1], f[1][0] + 1 + y.f[1][1],
                      f[1][1] + y.f[1][1], f[1][1] + 1 + y.f[0][1]);
    return res;
  }
} f[MAX][LG];

inline int get(const int x, const int y) { return dfn[x] < dfn[y] ? x : y; }

void init()
{
  for (int i = 1; (1 << i) <= n; ++i)
  {
    for (int j = 1; j + (1 << i) <= n + 1; ++j)
    {
      st[i][j] = get(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
    }
  }
}

int lca(int u, int v)
{
  if (dfn[u] > dfn[v]) { std::swap(u, v); }
  const int lg = lg2[dfn[v] - dfn[u]];
  return get(st[lg][dfn[u] + 1], st[lg][dfn[v] - (1 << lg) + 1]);
}

void dfs(const int u)
{
  dfn[u] = ++dfc; pos[dfc] = u;
  st[0][dfc] = fth[u][0];
  dep[u] = dep[fth[u][0]] + 1;
  for (int i = 1; (1 << i) <= dep[u]; ++i) { fth[u][i] = fth[fth[u][i - 1]][i - 1]; }
  for (int i = h[u]; i; i = e[i].x) { int v = e[i].v; dfs(v); }
}

N qry(int & u, const int fa)
{
  N res;
  int d = dep[u] - dep[fa] - 1;
  if (!d) { return res; }
  while (d)
  {
    int t = lg2[(d & -d)];
    res = res + f[u][t];
    u = fth[u][t];
    d ^= (1 << t);
  }
  return res;
}

int solve(int u, int v)
{
  if (u == v) { return 0; }
  if (dep[u] < dep[v]) { std::swap(u, v); }
  int fa = lca(u, v);
  if (fa == v)
  {
    N ru = qry(u, v);
    int ul = min(ru.f[0][0], ru.f[1][0]), ur = min(ru.f[0][1], ru.f[1][1]);
    return min(ul, ur);
  }
  N ru = qry(u, fa), rv = qry(v, fa);
  int ul = min(ru.f[0][0], ru.f[1][0]), ur = min(ru.f[0][1], ru.f[1][1]),
      vl = min(rv.f[0][0], rv.f[1][0]), vr = min(rv.f[0][1], rv.f[1][1]);
  return min(ul + vl + min(abs(lft[u] - lft[v]), deg[fa] + 1 - abs(lft[u] - lft[v])),
             ur + vr + min(abs(lft[u] - lft[v]), deg[fa] + 1 - abs(lft[u] - lft[v])),
             ul + vr + min(abs(lft[u] - (lft[v] + 1)), deg[fa] + 1 - abs(lft[u] - (lft[v] + 1))),
             ur + vl + min(abs(lft[v] - (lft[u] + 1)), deg[fa] + 1 - abs(lft[v] - (lft[u] + 1))));
}

void solve0()
{
  while (q--)
  {
    u = rd(); u ^= lst; 
    v = rd(); v ^= lst;
    lst = solve(u, v);
    printf("%d\n", lst);
  }
}

void solve1()
{
  srand((unsigned)s);
  while (q--)
  {
    u = (read() ^ lst) % n + 1, v = (read() ^ lst) % n + 1;
    lst = solve(u, v);
    sum += lst; xrs ^= lst;
  }
  printf("%d %lld", xrs, sum);
}

int main()
{
  n = rd(); q = rd(); s = rd();
  for (int i = 2; i <= n; ++i) { fth[i][0] = rd(); add(fth[i][0], i); }
  for (int i = 2; i <= n; ++i) { rht[i] = deg[fth[i][0]] - lft[i] - 1; }
  for (int i = 2; i <= n; ++i) { lg2[i] = lg2[i >> 1] + 1; }
  dfs(1); init();
  for (int i = 2; i <= n; ++i)
  {
    f[i][0].f[0][0] = min(lft[i], rht[i] + 2);
    f[i][0].f[0][1] = min(lft[i], rht[i]) + 1;
    f[i][0].f[1][0] = min(lft[i], rht[i]) + 1;
    f[i][0].f[1][1] = min(rht[i], lft[i] + 2);
  }
  for (int i = 2; i <= n; ++i)
  {
    for (int j = 1; (1 << j) <= dep[i]; ++j)
    {
      f[i][j] = f[i][j - 1] + f[fth[i][j - 1]][j - 1];
    }
  }
  s == -1 ? solve0() : solve1();
  return 0;
}