「题解」ARC167C MST on Line++

AtomAlpaca

Table of contents

题意

给定正整数 n,Kn,K 和一个长度为 nn 的序列 AA。对于一个 1n1\sim n 的排列 PP,我们定义 f(P)f(P) 为以下问题的答案:

给一个 nn 个点的无向带权图,对于两点 i<ji<j,当且仅当 jiKj-i\le K 时,它们之间有边,边权为 max(APi,APj)\max(A_{P_i},A_{P_j})
求这个图的最小生成树边权和。

对于所有可能的排列 PP,求出它们的 f(P)f(P) 之和,答案对 998244353998\,244\,353 取模。

1K<N50001\le K< N\le 50001Ai1091\le A_i \le 10^9

奇妙深刻数数题。

题解

首先发现因为要枚举排列,AA 的顺序是无关紧要的,可以升序排序处理。

考虑拆贡献,考虑对于每一个 AiA_i 计算它在所有情况中被选的次数 fif_i。答案就是 infiAi\sum_{i}^{n}{f_i A_i}

但是这样还是不好求,我们考虑用小于等于 AiA_i 的边权被选的次数减去小于等于 Ai1A_{i - 1} 的边权的被选次数。这相当于构造了这样一个问题:

对于两点 i<ji<j,当且仅当 jiKj-i\le Kmax(APi,APj)Ax\max(A_{P_i},A_{P_j}) \le A_x 时,它们之间有边。求最多选择多少条边使得构成的图没有环,对于所有排列 PP 求和。

因为取 max\max,所有 APi>xA_{P_i} > xii 都不会被选,被选的只有小于 AxA_x 的所有位置。

再考虑 jiKj - i \le K 这个条件怎么做。我们这里不妨对于一个所有选出来的数的集合有序 QQ,考虑钦定 QjQj1=KQ_j - Q_{j - 1} = K,这样的方案一共有 (nkx1)\binom{n-k}{x-1} 种,小于等于 KK 的情况就有 i=1K(nkx1)\sum_{i = 1}^{K}{\binom{n-k}{x-1}} 种。一共有 x1x-1 个这样的位置,那么贡献总和就是 (x1)i=1K(nix1)(x - 1)\sum_{i = 1}^{K}{\binom{n-i}{x-1}}

又考虑和一个 QQ 对应的排列的排列方式应该是选出来的 xx 个随便排,剩下的 nxn - x 个随便排并在一起,方案数应该是 x!(nx)!x!(n - x)!

所以上述问题的答案就是 x!(nx)!(x1)i=1K(nix1)x!(n - x)!(x - 1)\sum_{i = 1}^{K}{\binom{n-i}{x-1}}。相邻的两项相减就能得到 ff

代码

#include <bits/stdc++.h>

typedef long long ll;
const int MAX = 5005;
const int MOD = 998244353;

int n, k; ll ans;
ll frc[MAX], ifrc[MAX], g[MAX], a[MAX];

ll qp(ll a, ll x)
{
    ll res = 1;
    while (x) { if (x & 1) { res = res * a % MOD; } x >>= 1; a = a * a % MOD; }
    return res;
}

void init(int x)
{
    frc[0] = ifrc[0] = 1;
    for (int i = 1; i <= x; ++i) { frc[i] = frc[i - 1] * i % MOD; } ifrc[x] = qp(frc[x], MOD - 2);
    for (int i = x - 1; i >= 1; --i) { ifrc[i] = ifrc[i + 1] * (i + 1) % MOD; }
}

ll C(ll x, ll y) { if (x < y) { return 0; } return frc[x] * ifrc[y] % MOD * ifrc[x - y] % MOD; }

int main()
{
    scanf("%d%d", &n, &k); init(5000);
    for (int i = 1; i <= n; ++i) { scanf("%lld", &a[i]); }
    std::sort(a + 1, a + n + 1);
    for (int i = 1; i <= n; ++i)
    {
        ll tmp = 0;
        g[i] = frc[i] * frc[n - i] % MOD * (i - 1) % MOD;
        for (int j = 1; j <= k; ++j)
        {
            tmp = (tmp + C(n - j, i - 1)) % MOD;
        }
        g[i] = g[i] * tmp % MOD;
    }
    for (int i = 1; i <= n; ++i) { ans = (ans + (g[i] - g[i - 1] + MOD) % MOD * a[i] % MOD) % MOD; }
    printf("%lld", ans);
    return 0;
}