题解:AT_arc176_d [ARC176D] Swap Permutation

· · 题解

不知道为什么题解都是清一色的矩阵乘法,m 不是才 2e5 吗?

设初始序列为 a

考虑拆贡献,我们先计算 P_i=a_x,P_{i+1}=a_y 对答案贡献了多少次,分三种情况:

  1. x\in \{i,i+1\},y\in \{i,i+1\}
  2. x\notin \{i,i+1\},y\notin \{i,i+1\}

可以发现,把 x,y 看成 1,其他看成 0,可以发现,对于同一种情况的 x,y 对答案的贡献次数都是一样的。定义 dp_{a,b} 表示经过 a 次操作后 \{i, i+1\} 中有 b 个 1 的操作方案数。可以发现,这个 dp 是与 i 无关的,预处理即可,转移显然。(其他题解难道是对这个 dp 矩阵乘法?)

计算了每种情况对答案的贡献次数后,根据乘法分配律,接下来需要再算上 \sum |a_x-a_y| 即可。

对于情况一,只有一对 x,y,直接计算即可。

对于情况二和情况三,这里我直接不想动脑子了,对 x\notin \{i,i+1\}a_x 建立了一颗权值线段树,暴力计算贡献即可,不知道还有没有更简单的方法。

然后不断从 i, i+1 推到 i+1, i+2 就做完了。

复杂度 O(n\log n+m)

#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN = 2e5, mod = 998244353, inv2 = (mod+1)/2;
int n, m, a[MAXN+5];
ll dp[MAXN+5][3], all;
ll dfs(int i, int u) {
    if (i==m+1) return (u==2);
    if (!u&&n<=3) return 0;
    if (u==1&&n<=2) return 0;
    if (dp[i][u]!=-1) return dp[i][u];
    if (!u) return dp[i][u]=(dfs(i+1, 1)*4+dfs(i+1, 0)*((all-4)%mod))%mod;
    else if (u==1) return dp[i][u]=(dfs(i+1, 2)+dfs(i+1, 0)*(n-3)+dfs(i+1, 1)*((all-(n-2))%mod))%mod;
    else return dp[i][u]=(dfs(i+1, 1)*2*(n-2)+dfs(i+1, 2)*((all-2*(n-2))%mod))%mod;
}
ll sum[4*MAXN+5];
int num[4*MAXN+5];
void insert(int i, int l, int r, int q, int x) {
    if (l==r) {
        num[i]+=x;
        sum[i]+=x*l;
        return ;
    }
    int mid = l+r>>1;
    if (q<=mid) insert(i*2, l, mid, q, x);
    else insert(i*2+1, mid+1, r, q, x);
    sum[i] = sum[i*2]+sum[i*2+1];
    num[i] = num[i*2]+num[i*2+1];
    return ;
}
ll found1(int i, int l, int r, int q, int w) {
    if (q>w) return 0;
    if (l>=q&&r<=w) return sum[i];
    int mid = l+r>>1;
    ll ans = 0;
    if (q<=mid) ans+=found1(i*2, l, mid, q, w);
    if (w>mid) ans+=found1(i*2+1, mid+1, r, q, w);
    return ans;
}
int found2(int i, int l, int r, int q, int w) {
    if (q>w) return 0;
    if (l>=q&&r<=w) return num[i];
    int mid = l+r>>1, ans = 0;
    if (q<=mid) ans+=found2(i*2, l, mid, q, w);
    if (w>mid) ans+=found2(i*2+1, mid+1, r, q, w);
    return ans;
}
ll getans(int x) {
    return ((1ll*x*(2*found2(1, 1, n, 1, x)-(n-2))+found1(1, 1, n, x+1, n)-found1(1, 1, n, 1, x))%mod+mod)%mod;
}
int main() {
    memset(dp, -1, sizeof(dp));
    scanf("%d %d", &n, &m);
    all = 1ll*n*(n-1)/2;
    for (int p=1;p<=n;p++) {
        scanf("%d", &a[p]);
    }
    ll u = 0;
    for (int p=3;p<=n;p++) {
        insert(1, 1, n, a[p], 1);
    }
    for (int p=3;p<=n;p++) {
        u = (u+getans(a[p]))%mod;
    }
    u = u*inv2%mod;
    ll ans = 0;
    for (int p=2;p<=n;p++) {
        ans = (ans+dfs(1, 1)*(getans(a[p])+getans(a[p-1]))+dfs(1, 0)*u+dfs(1, 2)*abs(a[p]-a[p-1]))%mod;
        if (p!=n) {
            u = (u-getans(a[p+1])+mod)%mod;
            insert(1, 1, n, a[p+1], -1);
            insert(1, 1, n, a[p-1], 1);
            u = (u+getans(a[p-1]))%mod;
        }
    }
    cout<<ans;
    return 0;
}