「题解」CF1442E Black, White and Grey Tree

AtomAlpaca

发现我的 dp 方式和别人都不一样啊。当场决定发篇题解。

首先考虑没有灰点怎么做。观察样例可以很轻易得出一个错误的解法:考虑一次删掉所有白/黑点,然后依次删掉剩下的所有联通块。

但我们考虑这样一个 hack:

image

按照上述的做法我们至少需要删除至少四次,但实际上我们最少只需要三次就可以了。

顺着这个 hack 往下想。我们最优的策略应该是钦定某一个点为根之后,将所有链按照颜色分层,一条链层数增加当且仅当一条边的两个点颜色不同。然后从深层到浅层依次删除,令最长的一条链最短时就能够使得答案最小。不难发现这时的根节点就是树的“直径”的中点,层数为 $ + 1$。

然后考虑加入灰点。不难发现其实灰点可以视作白点和黑点的任意一种,我们只需要将 dp 求树直径的方法稍微改造一下即可。令 fu,x,1/2f_{u,x,1/2} 为节点 uu 在颜色为 xx 时的第一/二长链,其中第一、第二长链没有公共边。令 uu 为树上 vv 的父亲,当 uu 是灰点时,分别视作黑/白点转移即可;当 vv 是灰点时,用 vv 分别为黑/白点时的最小值更新 uu,简单分讨一下即可。具体的实现方法不是很好描述,建议参考代码。

#include <bits/stdc++.h>

const int MAX = 2e5 + 5;

int T, n, tot, ans, u, v;
int h[MAX], a[MAX], f[MAX][3][3];
struct E { int v, x; } e[MAX << 2];

void clear()
{
  tot = ans = 0;
  for (int i = 1; i <= n; ++i) { h[i] = 0; f[i][1][1] = f[i][1][2] = f[i][2][1] = f[i][2][2] = 0; }
}

void add(int u, int v)
{
  e[++tot] = {v, h[u]}; h[u] = tot;
  e[++tot] = {u, h[v]}; h[v] = tot;
}

void work(int u, int x, int v)
{
  if (v > f[u][x][1]) { f[u][x][2] = f[u][x][1]; f[u][x][1] = v; }
  else if (v > f[u][x][2]) { f[u][x][2] = v; }
}

void dfs(int u, int fa)
{
  for (int i = h[u]; i; i = e[i].x)
  {
    int v = e[i].v; if (v == fa) { continue; } dfs(v, u);
    if (a[u])
    {
      if (a[v]) { work(u, a[u], f[v][a[v]][1] + (a[u] != a[v])); }
      else { work(u, a[u], std::min(f[v][1][1] + (a[u] != 1), f[v][2][1] + (a[u] != 2))); }
    }
    else
    {
      if (a[v])
      {
        work(u, 1, f[v][a[v]][1] + (a[v] != 1));
        work(u, 2, f[v][a[v]][1] + (a[v] != 2));
      }
      else
      {
        work(u, 1, std::min(f[v][1][1], f[v][2][1] + 1));
        work(u, 2, std::min(f[v][1][1] + 1, f[v][2][1]));
      }
    }
  }
}

void solve()
{
  scanf("%d", &n);
  clear();
  for (int i = 1; i <= n; ++i) { scanf("%d", &a[i]); }
  for (int i = 1; i <  n; ++i) { scanf("%d%d", &u, &v); add(u, v); }
  dfs(1, 0);
  for (int i = 1; i <= n; ++i)
  {
    if (a[i]) { ans = std::max(ans, f[i][a[i]][1] + f[i][a[i]][2]); }
    else { ans = std::max(ans, std::min(f[i][1][1] + f[i][1][2], f[i][2][1] + f[i][2][2])); }
    }
  printf("%d\n", ((ans + 1) >> 1) + 1);
}

int main()
{
  scanf("%d", &T); while (T--) { solve(); }
}