「题解」AGC023E Inversions

AtomAlpaca

Table of contents

题意

给定一个长度为 nn 的数组 aa,问对于所有满足 i,piai\forall i, p_i \le a_i 的排列 pp 的逆序对个数和。

题解

首先令 bbaa 从小到大排序的结果,rkirk_i 表示 aia_i 的排名,idiid_i 表示排名为 ii 的数在 aa 中的下标。

那么所有符合条件的排列个数为 cnt=i=1nairki+1cnt = \prod_{i=1}^{n}{a_i-rk_i+1}。因为我们从小到大考虑每个排名 xx,那么它前面已经被选了 x1x - 1aidxa_{id_x} 范围内的数,那么这个位置能选择的数就有 aidxx+1a_{id_x} - x + 1 个。

那么接下来我们依然从小到大考虑排名。假设我们考虑到排名 ii,那么对于每个 j<ij < i,我们分情况讨论:

首先如果 idi>idjid_i > id_j,那么我们 aidia_{id_i}aidja_{id_j} 大的部分并不会产生贡献,我们直接扔掉不考虑。那么我们把 aidia_{id_i} 削到和 aidja_{id_j} 相等,此时两个位置上选择方案有 (aidjj+1)(aidjj)(a_{id_j} - j + 1)(a_{id_j} - j) 种,其中只有一半是形成逆序对的。

然后我们发现如果强行让它的范围缩小,那么对所有满足 j<k<ij < k < ikk,我们可以选择的数其实是减少了一个。于是我们得到这一对点的贡献是:

(aidjj+1)(aidjj)2cnt(aidii+1)(aidjj+1)k=j+1i1aidkkaidkk+1\frac{(a_{id_j} - j + 1)(a_{id_j} - j)}{2} \frac{cnt}{(a_{id_i} - i + 1)(a_{id_j} - j + 1)} \prod_{k = j + 1}^{i - 1}{\frac{a_{id_k} - k}{a_{id_k} - k + 1}}

如果 idi<idjid_i < id_j,我们考虑直接求产生顺序对的数量,这和上面的式子是一样的。只需要用总排列数减去即可。

考虑维护最后一项的前缀积是可以轻松做到 O(n2)O(n^2) 的,但是这并不足够通过此题。我们考虑把上面的式子化简下。

(aidjj+1)(aidjj)2cnt(aidii+1)(aidjj+1)k=j+1i1aidkkaidkk+1=cnt2(aidii+1)(aidjj)k=j+1i1aidkkaidkk+1\begin{aligned} &\frac{(a_{id_j} - j + 1)(a_{id_j} - j)}{2} \frac{cnt}{(a_{id_i} - i + 1)(a_{id_j} - j + 1)} \prod_{k = j + 1}^{i - 1}{\frac{a_{id_k} - k}{a_{id_k} - k + 1}} \\ =&\frac{cnt}{2(a_{id_i} - i + 1)} (a_{id_j} - j)\prod_{k = j + 1}^{i - 1}{\frac{a_{id_k} - k}{a_{id_k} - k + 1}} \end{aligned} 最前面的一项是只和 ii 有关的,我们只需要维护后面的两项。

考虑我们每次枚举 ii 之后对后面的影响其实就是把最后一项乘了个 aidiiaidii+1\dfrac{a_{id_i} - i}{a_{id_i} - i + 1},然后要多考虑一个 ii。那么我们考虑用一个支持全局乘、单点加、区间和的线段树,每个位置 ii 记录 idiid_i 的后面两项的值。

这样我们的答案就是 qry(1,idi)+p×cntqry(idi,n)qry(1, id_i) + p \times cnt - qry(id_i, n),其中 ppj>ij > i 的点数。每次统计完 ii 的答案全局乘 aidiiaidii+1\dfrac{a_{id_i} - i}{a_{id_i} - i + 1} 单点加 aidi=ja_{id_i} = j 即可。

另外,我们用线段树只能得到所有 j>ij > i 的答案,并不知道有多少个,因此要类似二维数点地开一个树状数组维护。

这样我们就可以以 O(nlogn)O(n \log n) 的时间复杂度解决这个问题。

代码

#include <bits/stdc++.h>

typedef long long ll;
const int MAX = 2e5 + 5;
const int MOD = 1e9 + 7;

ll n, tot, ans;
ll a[MAX], rk[MAX], id[MAX];

struct N { ll x, y; } b[MAX];
bool cmp(N a, N b) { return (a.x == b.x ? a.y < b.y : a.x < b.x); }

ll qp(ll a, ll x)
{
    ll res = 1;
    while (x) { if (x & 1) { res = res * a % MOD; } a = a * a % MOD; x >>= 1; }
    return res;
}

ll inv(ll x) { return qp(x, MOD - 2); }

struct BIT
{
    int t[MAX];
    inline int lbt(int x) { return x & -x; }
    void mdf(int x, int c) { while (x <= n) { t[x] += c; x += lbt(x); } }
    int qry(int x) { int res = 0; while (x) { res += t[x]; x -= lbt(x); } return res; }
} bt;

struct SGT
{
    ll st[MAX << 2], tg[MAX << 2];
    void init() { for (int i = 1; i <= 4 * n; ++i) { tg[i] = 1; } }
    inline void pd(int x)
    {
        st[x << 1] = st[x << 1] * tg[x] % MOD; st[x << 1 | 1] = st[x << 1 | 1] * tg[x] % MOD;
        tg[x << 1] = tg[x << 1] * tg[x] % MOD; tg[x << 1 | 1] = tg[x << 1 | 1] * tg[x] % MOD;
        tg[x] = 1;
    }
    inline void pu(int x) { st[x] = (st[x << 1] + st[x << 1 | 1]) % MOD; }
    void mdf0(ll v) { st[1] = st[1] * v % MOD; tg[1] = tg[1] * v % MOD; }
    void mdf1(int l, int r, int s, ll c, int x)
    {
        if (l == r and l == s) { st[x] += c; return ; }
        pd(x); int k = l + ((r - l) >> 1);
        if (k >= s) { mdf1(l, k, s, c, x << 1); }
        else { mdf1(k + 1, r, s, c, x << 1 | 1); }
        pu(x);
    }
    ll qry(int l, int r, int s, int t, int x)
    {
        if (l >= s and r <= t) { return st[x]; }
        pd(x); int k = l + ((r - l) >> 1); ll res = 0;
        if (k >= s) { res = (res + qry(l, k, s, t, x << 1)) % MOD; }
        if (k <  t) { res = (res + qry(k + 1, r, s, t, x << 1 | 1)) % MOD; }
        return res;
    }
} st;

int main()
{
    scanf("%lld", &n); tot = 1; st.init();
    for (int i = 1; i <= n; ++i) { scanf("%lld", &a[i]); b[i] = {a[i], i}; }
    std::sort(b + 1, b + n + 1, cmp);
    for (int i = 1; i <= n; ++i) { id[i] = b[i].y; rk[id[i]] = i; }
    for (int i = 1; i <= n; ++i)
    {
        if (b[i].x - i + 1 <= 0) { printf("0"); return 0; }
        tot = tot * (b[i].x - i + 1) % MOD;
    }
    for (int i = 1; i <= n; ++i)
    {
        ll res1 = st.qry(1, n, 1, id[i], 1), res2 = st.qry(1, n, id[i], n, 1), cnt = bt.qry(n) - bt.qry(id[i]);
        ll tmp = inv(2) * inv(a[id[i]] - i + 1) % MOD * tot % MOD;
        ans = (ans + tmp * (res1 - res2 + MOD) % MOD) % MOD;
        ans = (ans + cnt * tot % MOD) % MOD;
        st.mdf0((a[id[i]] - i) * inv(a[id[i]] - i + 1) % MOD);
        st.mdf1(1, n, id[i], a[id[i]] - i, 1);
        bt.mdf(id[i], 1);
    }
    printf("%lld", ans);
    return 0;
}