题解:AT_arc176_d [ARC176D] Swap Permutation
StarPatrick · · 题解
不知道为什么题解都是清一色的矩阵乘法,
设初始序列为
考虑拆贡献,我们先计算
-
x\in \{i,i+1\},y\in \{i,i+1\} -
-
x\notin \{i,i+1\},y\notin \{i,i+1\}
可以发现,把
计算了每种情况对答案的贡献次数后,根据乘法分配律,接下来需要再算上
对于情况一,只有一对
对于情况二和情况三,这里我直接不想动脑子了,对
然后不断从
复杂度
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int MAXN = 2e5, mod = 998244353, inv2 = (mod+1)/2;
int n, m, a[MAXN+5];
ll dp[MAXN+5][3], all;
ll dfs(int i, int u) {
if (i==m+1) return (u==2);
if (!u&&n<=3) return 0;
if (u==1&&n<=2) return 0;
if (dp[i][u]!=-1) return dp[i][u];
if (!u) return dp[i][u]=(dfs(i+1, 1)*4+dfs(i+1, 0)*((all-4)%mod))%mod;
else if (u==1) return dp[i][u]=(dfs(i+1, 2)+dfs(i+1, 0)*(n-3)+dfs(i+1, 1)*((all-(n-2))%mod))%mod;
else return dp[i][u]=(dfs(i+1, 1)*2*(n-2)+dfs(i+1, 2)*((all-2*(n-2))%mod))%mod;
}
ll sum[4*MAXN+5];
int num[4*MAXN+5];
void insert(int i, int l, int r, int q, int x) {
if (l==r) {
num[i]+=x;
sum[i]+=x*l;
return ;
}
int mid = l+r>>1;
if (q<=mid) insert(i*2, l, mid, q, x);
else insert(i*2+1, mid+1, r, q, x);
sum[i] = sum[i*2]+sum[i*2+1];
num[i] = num[i*2]+num[i*2+1];
return ;
}
ll found1(int i, int l, int r, int q, int w) {
if (q>w) return 0;
if (l>=q&&r<=w) return sum[i];
int mid = l+r>>1;
ll ans = 0;
if (q<=mid) ans+=found1(i*2, l, mid, q, w);
if (w>mid) ans+=found1(i*2+1, mid+1, r, q, w);
return ans;
}
int found2(int i, int l, int r, int q, int w) {
if (q>w) return 0;
if (l>=q&&r<=w) return num[i];
int mid = l+r>>1, ans = 0;
if (q<=mid) ans+=found2(i*2, l, mid, q, w);
if (w>mid) ans+=found2(i*2+1, mid+1, r, q, w);
return ans;
}
ll getans(int x) {
return ((1ll*x*(2*found2(1, 1, n, 1, x)-(n-2))+found1(1, 1, n, x+1, n)-found1(1, 1, n, 1, x))%mod+mod)%mod;
}
int main() {
memset(dp, -1, sizeof(dp));
scanf("%d %d", &n, &m);
all = 1ll*n*(n-1)/2;
for (int p=1;p<=n;p++) {
scanf("%d", &a[p]);
}
ll u = 0;
for (int p=3;p<=n;p++) {
insert(1, 1, n, a[p], 1);
}
for (int p=3;p<=n;p++) {
u = (u+getans(a[p]))%mod;
}
u = u*inv2%mod;
ll ans = 0;
for (int p=2;p<=n;p++) {
ans = (ans+dfs(1, 1)*(getans(a[p])+getans(a[p-1]))+dfs(1, 0)*u+dfs(1, 2)*abs(a[p]-a[p-1]))%mod;
if (p!=n) {
u = (u-getans(a[p+1])+mod)%mod;
insert(1, 1, n, a[p+1], -1);
insert(1, 1, n, a[p-1], 1);
u = (u+getans(a[p-1]))%mod;
}
}
cout<<ans;
return 0;
}