题解 P4221 【[WC2018]州区划分】
Great_Influence
2018-05-21 09:43:53
来一发FMT的题解。
这道题目其实并不难,但是考的考点非常偏。。。(松松松:这次没有卡常了)
可以发现,这道题显然可以用dp来解决。设$dp[S]$表示已经分了$S$这个子集的答案,$sum[S]$为$S$中元素的权值之和,$g[S]$为$sum[S]$乘以该子集是否没有欧拉回路,则存在以下转移式:
$\displaystyle dp[S]=\frac{1}{sum[S]^p}\sum_{T\in S}g[T]*dp[S-T]$
发现满足子集卷积形式。直接枚举子集转移为$3^nn^2$的,无法通过此题。但是利用$FWT$或者$FST$或者$FMT$都可以优化成$2^nn^2$,便可以通过了。
$FMT$为快速莫比乌斯变换,正操作可以将一个子集的权值转为某种点值形式(和$DFT$差不多),逆操作则是将正操作得到的点值转回子集权值(和$IDFT$差不多)。转成的点值可以加减乘除,对应子集的加减乘除。对于本题,只要打一个子集乘就可以了。
记得至少需要转移$n-1$次,因为最多可以划出$n-1$个州。
代码:
```cpp
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#define Rep(i,a,b) for(register int i=(a),i##end=(b);i<=i##end;++i)
#define Repe(i,a,b) for(register int i=(a),i##end=(b);i>=i##end;--i)
#define For(i,a,b) for(i=(a),i<=(b);++i)
#define Forward(i,a,b) for(i=(a),i>=(b);--i)
template<typename T>inline void read(T &x)
{
T f=1;x=0;char c;
for(c=getchar();!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())x=x*10+(c^48);
x*=f;
}
using namespace std;
static int n,m,P,w[31];
static int p[31];
inline void init()
{
read(n);read(m);read(P);
static int u,v;
Rep(i,1,m)read(u),read(v),p[u]|=1<<v-1,p[v]|=1<<u-1;
Rep(i,1,n)read(w[i]);
}
const int mod=998244353;
typedef long long ll;
inline int power(int a,int b)
{
static int sm;
for(sm=1;b;b>>=1,a=(ll)a*a%mod)if(b&1)
sm=(ll)sm*a%mod;
return sm;
}
static int Len;
inline int ad(int u,int v){return (u+=v)>=mod?u-mod:u;}
inline void FMT(int *a)
{
for(register int z=1;z<Len;z<<=1)
Rep(j,0,Len-1)if(z&j)a[j]=ad(a[j],a[j^z]);
}
inline void IFMT(int *a)
{
for(register int z=1;z<Len;z<<=1)
Rep(j,0,Len-1)if(z&j)a[j]=ad(a[j],mod-a[j^z]);
}
static int bit[1<<21],S[1<<21],fa[23],inv[1<<21];
int find(int x){return x==fa[x]?x:fa[x]=find(fa[x]);}
static int dp[22][1<<21],g[22][1<<21];
static int sta[23],tp;
inline int check(int x)
{
if(bit[x]==1)return 0;tp=0;
static int now,u,v;
Rep(i,1,n)if(x&(1<<i-1))
{
sta[++tp]=i;
now=x&p[i];
if(!now||bit[now]&1)return 1;
}
Rep(i,1,tp)fa[sta[i]]=sta[i];
Rep(i,1,tp)Rep(j,1,tp)if(p[sta[i]]&(1<<sta[j]-1))
if((u=find(sta[i]))^(v=find(sta[j])))fa[u]=v;
Rep(i,1,tp)if(find(sta[i])^fa[sta[1]])return 1;
return 0;
}
inline int calc(int x)
{
if(!P)return 1;
static int sm;sm=0;
Rep(i,1,n)if(x&(1<<i-1))sm=ad(sm,w[i]);
if(P==1)return sm;
return sm*sm;
}
inline void solve()
{
static int sm,flag,u,v;
Len=1<<n;
Rep(i,1,Len-1)bit[i]=bit[i>>1]+(i&1);
Rep(i,1,Len-1)
{
S[i]=calc(i);
g[bit[i]][i]=check(i)?S[i]:0;
inv[i]=power(S[i],mod-2);
}
dp[0][0]=1;FMT(dp[0]);
Rep(i,1,n)FMT(g[i]);
Rep(i,1,n)
{
int *f=dp[i];
Rep(j,0,i-1)
{
int *a=dp[j],*b=g[i-j];
Rep(k,0,Len-1)f[k]=ad((ll)a[k]*b[k]%mod,f[k]);
}
IFMT(f);
Rep(k,0,Len-1)f[k]=i==bit[k]?(ll)f[k]*inv[k]%mod:0;
if(i^n)FMT(f);
}
printf("%d\n",dp[n][Len-1]);
}
int main()
{
init();
solve();
return 0;
}
```