题解:P14224 [ICPC 2024 Kunming I] 子数组
考虑
优化的话,令
实现的时候注意边界,式子有些地方可能需要
#include<bits/stdc++.h>
using namespace std;
const int M=998244353;
int T,n,a[400005],b[400005];
vector<int> G[400005];
int val[2000005],val2[2000005],val3[2000005],ans[2000005],len;
namespace NTT {
int pow(int x, int y) {
int res=1;
while(y) {
if(y&1) res=1ll*res*x%M;
x=1ll*x*x%M;
y>>=1;
}
return res;
}
int N,K,p[2000005];
void ntt(int* x, int inv) {
for(int i=0; i<N; i++) if(p[i]<i) swap(x[p[i]],x[i]);
for(int h=2; h<=N; h<<=1) {
int gn=pow(3,(M-1)/h);
for(int i=0; i<N; i+=h) {
int g=1;
for(int j=i; j<i+h/2; j++, g=1ll*g*gn%M) {
int u=x[j], v=1ll*x[j+h/2]*g%M;
x[j]=(u+v)%M, x[j+h/2]=(u-v+M)%M;
}
}
}
if(inv) {
reverse(x+1,x+N);
int invn=pow(N,M-2);
for(int i=0; i<N; i++) x[i]=1ll*x[i]*invn%M;
}
}
void calc() {
if(len<=10) {
for(int i=0; i<len*2; i++) val3[i]=0;
for(int i=0; i<len; i++)
for(int j=0; j<len; j++)
val3[i+j]=(val3[i+j]+1ll*val[i]*val2[j])%M;
return;
}
N=1, K=0;
while(N<len*2) N<<=1, K++;
for(int i=1; i<=N; i++) p[i]=(p[i>>1]>>1)+((i&1)<<(K-1));
for(int i=len; i<N; i++) val[i]=val2[i]=0;
ntt(val,0), ntt(val2,0);
for(int i=0; i<N; i++) val3[i]=1ll*val[i]*val2[i]%M;
ntt(val3,1);
}
}
int st[1000005][19];
void initst() {
for(int i=1; i<=n; i++) st[i][0]=a[i];
for(int i=1; i<=18; i++)
for(int j=1; j<=n; j++)
st[j][i]=max(st[j][i-1], st[j+(1<<i-1)][i-1]);
}
int getst(int x, int y) {
int len=y-x+1, lg2=31-__builtin_clz(len);
return max(st[x][lg2], st[y-(1<<lg2)+1][lg2]);
}
void solve(int l, int r) {
if(l>r) return;
int mx=getst(l,r);
len=0;
int posl=lower_bound(G[mx].begin(), G[mx].end(), l)-G[mx].begin();
int posr=upper_bound(G[mx].begin(), G[mx].end(), r)-G[mx].begin()-1;
val[len++]=G[mx][posl]-l+1;
for(int i=posl; i<posr; i++) val[len++]=G[mx][i+1]-G[mx][i];
val[len++]=r-G[mx][posr]+1;
// printf("l=%d r=%d\n",l,r);
// for(int i=0; i<len; i++) cout<<val[i]<<' ';
// cout<<'\n';
for(int i=0; i<len; i++) val2[i]=val[len-1-i];
NTT::calc();
for(int i=1; i<len; i++) ans[i]=(ans[i]+val3[len-1-i])%M;
solve(l,G[mx][posl]-1);
for(int i=posl; i<posr; i++) solve(G[mx][i]+1,G[mx][i+1]-1);
solve(G[mx][posr]+1,r);
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
// freopen("in","r",stdin);
cin>>T;
while(T--) {
cin>>n;
for(int i=1; i<=n; i++) ans[i]=0;
for(int i=1; i<=n; i++) cin>>a[i], b[i]=a[i];
sort(b+1,b+1+n);
int m=unique(b+1,b+1+n)-b-1;
for(int i=1; i<=n; i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
for(int i=1; i<=m; i++) G[i].clear();
for(int i=1; i<=n; i++) G[a[i]].push_back(i);
initst();
solve(1,n);
int anss=0;
for(int i=1; i<=n; i++) anss+=1ll*i*ans[i]%M*ans[i]%M, anss%=M;
cout<<anss<<'\n';
}
return 0;
}