题解 P4221 【[WC2018]州区划分】

Great_Influence

2018-05-21 09:43:53

Solution

来一发FMT的题解。 这道题目其实并不难,但是考的考点非常偏。。。(松松松:这次没有卡常了) 可以发现,这道题显然可以用dp来解决。设$dp[S]$表示已经分了$S$这个子集的答案,$sum[S]$为$S$中元素的权值之和,$g[S]$为$sum[S]$乘以该子集是否没有欧拉回路,则存在以下转移式: $\displaystyle dp[S]=\frac{1}{sum[S]^p}\sum_{T\in S}g[T]*dp[S-T]$ 发现满足子集卷积形式。直接枚举子集转移为$3^nn^2$的,无法通过此题。但是利用$FWT$或者$FST$或者$FMT$都可以优化成$2^nn^2$,便可以通过了。 $FMT$为快速莫比乌斯变换,正操作可以将一个子集的权值转为某种点值形式(和$DFT$差不多),逆操作则是将正操作得到的点值转回子集权值(和$IDFT$差不多)。转成的点值可以加减乘除,对应子集的加减乘除。对于本题,只要打一个子集乘就可以了。 记得至少需要转移$n-1$次,因为最多可以划出$n-1$个州。 代码: ```cpp #include<bits/stdc++.h> #include<ext/pb_ds/assoc_container.hpp> #define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i) #define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i) #define For(i,a,b) for(i=(a),i<=(b);++i) #define Forward(i,a,b) for(i=(a),i>=(b);--i) template<typename T>inline void read(T &x) { T f=1;x=0;char c; for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-1; for(;isdigit(c);c=getchar())x=x*10+(c^48); x*=f; } using namespace std; static int n,m,P,w[31]; static int p[31]; inline void init() { read(n);read(m);read(P); static int u,v; Rep(i,1,m)read(u),read(v),p[u]|=1<<v-1,p[v]|=1<<u-1; Rep(i,1,n)read(w[i]); } const int mod=998244353; typedef long long ll; inline int power(int a,int b) { static int sm; for(sm=1;b;b>>=1,a=(ll)a*a%mod)if(b&1) sm=(ll)sm*a%mod; return sm; } static int Len; inline int ad(int u,int v){return (u+=v)>=mod?u-mod:u;} inline void FMT(int *a) { for(register int z=1;z<Len;z<<=1) Rep(j,0,Len-1)if(z&j)a[j]=ad(a[j],a[j^z]); } inline void IFMT(int *a) { for(register int z=1;z<Len;z<<=1) Rep(j,0,Len-1)if(z&j)a[j]=ad(a[j],mod-a[j^z]); } static int bit[1<<21],S[1<<21],fa[23],inv[1<<21]; int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);} static int dp[22][1<<21],g[22][1<<21]; static int sta[23],tp; inline int check(int x) { if(bit[x]==1)return 0;tp=0; static int now,u,v; Rep(i,1,n)if(x&(1<<i-1)) { sta[++tp]=i; now=x&p[i]; if(!now||bit[now]&1)return 1; } Rep(i,1,tp)fa[sta[i]]=sta[i]; Rep(i,1,tp)Rep(j,1,tp)if(p[sta[i]]&(1<<sta[j]-1)) if((u=find(sta[i]))^(v=find(sta[j])))fa[u]=v; Rep(i,1,tp)if(find(sta[i])^fa[sta[1]])return 1; return 0; } inline int calc(int x) { if(!P)return 1; static int sm;sm=0; Rep(i,1,n)if(x&(1<<i-1))sm=ad(sm,w[i]); if(P==1)return sm; return sm*sm; } inline void solve() { static int sm,flag,u,v; Len=1<<n; Rep(i,1,Len-1)bit[i]=bit[i>>1]+(i&1); Rep(i,1,Len-1) { S[i]=calc(i); g[bit[i]][i]=check(i)?S[i]:0; inv[i]=power(S[i],mod-2); } dp[0][0]=1;FMT(dp[0]); Rep(i,1,n)FMT(g[i]); Rep(i,1,n) { int *f=dp[i]; Rep(j,0,i-1) { int *a=dp[j],*b=g[i-j]; Rep(k,0,Len-1)f[k]=ad((ll)a[k]*b[k]%mod,f[k]); } IFMT(f); Rep(k,0,Len-1)f[k]=i==bit[k]?(ll)f[k]*inv[k]%mod:0; if(i^n)FMT(f); } printf("%d\n",dp[n][Len-1]); } int main() { init(); solve(); return 0; } ```