题解 CF755G 【PolandBall and Many Other Balls】
猪脑子
·
·
题解
原来的题解只讲了基本思路,并没有讲太多,虽然我还是秒懂了(逃)
因此我就来讲一下这个题怎么做吧~
首先你得看出这是个DP。
我们可以设一个状态:dp(i,j) 表示前 i 个球,分出 j 组的方案数。
我们可以考虑两种转移方式:
第一种方式:我们考虑第 i 个球如何分组,那么有三种方法,分别是 1、单独一组 2、和前一个分到一组 3、扔掉不管。
那么有如下递推式:
dp(i,j)=dp(i-1,j-1)+dp(i-2,j-1)+dp(i-1,j)
其中, dp(i-1,j) 表示不选,剩下两个表示选。
也就是说,我们可以从 dp(i-1),dp(i-2) 推出 dp(i)。
第二种方式:我们考虑将两段球“拼起来”。
假如说我们已经算出了 dp(a),dp(b),我们将 a 个球放在前面,b 个球放在后面,那么我们考虑两种情形:
1、从前 a 个球中分若干组,从后 b 个球中分若干组,两者互不干扰,那么转移到 dp(a+b,j) 的有:
\sum_{k=0}^{j} dp(a,k) * dp(b,j-k)
2、第 a,a+1 个分到一组,前 a-1 与后 b-1 个分若干组,那么转移到 dp(a+b,j) 的有:
\sum_{k=0}^{j-1} dp(a-1,k) * dp(b-1,j-k-1)
观察这两个式子,我们发现如果我们将 dp(i) 看做一个多项式,其中 dp(i,j) 是它的 j 次项系数,那么刚刚两个式子是不是就是多项式的乘法?
那么思路就来了~
假如说我们知道了 dp(x),dp(x-1),dp(x-2)(这里的 dp(x) 指的是多项式),那么我们就可以用上面的第二种方式推出
$dp(x*2-1)=dp(x+(x-1))$,
$dp(x*2-2)=dp((x-1)+(x-1))$。
其中多项式乘法我们可以用NTT在 $O(k \log_2 k)$ 的时间内求出,其中 $k$ 就是题目中的 $k$。
我们考虑将 $n$ 二进制拆分,并且从高位向低位枚举每一位。每一次都先按照上面的方法用已经求出的 $dp(x)$ 求出 $dp(2*x)$,然后如果该位为 $1$ ,那么就用最开始讲的第一种方式在 $O(k)$ 的时间内用 $x,x-1$ 推出 $x+1$。
思路讲的差不多了,那么就上代码吧:
~~在此献上本蒟蒻又丑又慢的代码……~~
```cpp
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define Mod 998244353
using namespace std;
long long fpow(long long a,long long b)
{
long long ans=1,t=a;
while(b)
{
if(b&1)ans=ans*t%Mod;
t=t*t%Mod;
b>>=1;
}
return ans;
}
long long inv(long long a)
{
return fpow(a,Mod-2);
}
int rev[150010];
void getrev(int bit)
{
for(register int i=1;i<(1<<bit);i++)
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
return ;
}
void ntt(long long* a,int n,int dft)//就是一个ntt
{
for(register int i=0;i<n;i++)
if(i<rev[i])swap(a[i],a[rev[i]]);
for(register int step=1;step<n;step<<=1)
{
long long wn=fpow(3,(Mod-1)/(step<<1));
if(dft==-1)wn=inv(wn);
for(register int i=0;i<n;i+=(step<<1))
{
long long wnk=1;
for(register int j=i;j<i+step;j++)
{
long long x=a[j];
long long y=wnk*a[j+step]%Mod;
a[j]=(x+y)%Mod;
a[j+step]=(x-y+Mod)%Mod;
wnk=wnk*wn%Mod;
}
}
}
if(dft==-1)
{
long long invn=inv(n);
for(int i=0;i<n;i++)
a[i]=a[i]*invn%Mod;
}
return ;
}
void cpy(long long* b,long long* a,int n)//复制多项式……
{
for(int i=0;i<n;i++)
b[i]=a[i];
return ;
}
//ans表示当前的答案,ans(i)即dp(x-i)
long long ans[3][150010];
long long t[3][150010];
long long tmp[3][150010];
int n,k;
//I wanna 最好玩了
void IwannaMUL()//用x推出2*x
{
int s=1,bit=0;
while(s<(k<<1))s<<=1,bit++;
getrev(bit);
for(int i=0;i<=2;i++)
ntt(ans[i],s,1);
for(int i=0;i<=2;i++)
{
//这里就是计算乘积的地方,tmp即a*b,t即(a-1)*(b-1)
for(int j=0;j<s;j++)
tmp[i][j]=ans[i/2][j]*ans[i-i/2][j]%Mod;
for(int j=0;j<s;j++)
t[i][j]=ans[i/2+1][j]*ans[i-i/2+1][j]%Mod;
}
for(int i=0;i<=2;i++)
{
ntt(tmp[i],s,-1);
ntt(t[i],s,-1);
for(int j=1;j<s;j++)
tmp[i][j]=(tmp[i][j]+t[i][j-1])%Mod;
}
for(int i=0;i<=2;i++)
{
cpy(ans[i],tmp[i],k);
for(int j=k;j<s;j++)
ans[i][j]=0;
}
return ;
}
void IwannaADD()//用x推出x+1
{
cpy(ans[2],ans[1],k);
cpy(ans[1],ans[0],k);
for(int i=0;i<k;i++)
{
//那个递推式
ans[0][i]=ans[1][i];
if(i>0)ans[0][i]=(ans[0][i]+ans[1][i-1]+ans[2][i-1])%Mod;
}
return ;
}
int main()
{
scanf("%d %d",&n,&k);
k++;
ans[0][0]=1;
int s=0;
for(int i=1<<30;i;i>>=1)
{
if(s)IwannaMUL(),s<<=1;
if(n&i)IwannaADD(),s++;
}
for(int i=1;i<k;i++)
printf("%d ",ans[0][i]);
return ~~(0-0);//点个赞呗~\(≧▽≦)/~
}
```