题解 P4516 【[JSOI2018]潜入行动】

· · 题解

不难想到一个沙雕dp,设f[i][j][0/1][0/1]表示当前点i,子树中一共放了j个,这个点是否放了,这个是否被覆盖了。
看起来直接合并是O(nk^2)的QwQ。。。。。
然后我以为是O(nk^2)的就不会做了嘤嘤嘤。
实际上是O(nk)的。。。
证明大概是这样的(为啥泥萌都不证明一下复杂度的啊):
考虑什么时候会产生O(k^2)的贡献,即一个点有两棵子树的大小大于k,而这样子合并次数不会超过O(\frac{n}{k}),所以这部分的复杂度是O(nk)的。
另外一种情况是一个子树小于k,经过合并之后变成大于k的子树了,显然对于一个点,如果它的子树小于k,在某次合并之后它的子树就会大于k,并且对于每个点而言,只会在他的某个祖先的地方经历一次这样子的合并,所以这样子均摊每个点会产生O(k)的贡献。
第三种情况是两个点的子树大小都小于k,合并完之后两者还是小于k。这个操作理解为每个两个集合中的点一一对应的产生一次贡献,那么盯着某一个特定点考虑,它每次产生的贡献是合并进来的子树大小的,因为在这一部分的过程中子树大小总是小于k,因此每个点产生的贡献也最多是O(k)的。
综上,在三种合并情况中,每种情况产生的贡献都最多是O(nk)的,所以全局的复杂度就是O(nk)

#include<iostream>
#include<cstdio>
using namespace std;
#define MAX 100100
#define MOD 1000000007
void add(int &x,int y){x+=y;if(x>=MOD)x-=MOD;}
inline int read()
{
    int x=0;bool t=false;char ch=getchar();
    while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
    if(ch=='-')t=true,ch=getchar();
    while(ch<='9'&&ch>='0')x=x*10+ch-48,ch=getchar();
    return t?-x:x;
}
struct Line{int v,next;}e[MAX<<1];
int h[MAX],cnt=1;
inline void Add(int u,int v){e[cnt]=(Line){v,h[u]};h[u]=cnt++;}
int n,K,f[MAX][101][2][2],size[MAX];
int tmp[101][2][2];
void dfs(int u,int ff)
{
    size[u]=1;f[u][0][0][0]=1;f[u][1][1][0]=1;
    for(int i=h[u];i;i=e[i].next)
    {
        int v=e[i].v;if(v==ff)continue;dfs(v,u);
        for(int j=0;j<=size[u]&&j<=K;++j)
            for(int k=0;k<=size[v]&&j+k<=K;++k)
            {
                if(f[u][j][0][0])
                {
                    add(tmp[j+k][0][0],1ll*f[u][j][0][0]*f[v][k][0][1]%MOD);
                    add(tmp[j+k][0][1],1ll*f[u][j][0][0]*f[v][k][1][1]%MOD);
                }
                if(f[u][j][0][1])
                {
                    add(tmp[j+k][0][1],1ll*f[u][j][0][1]*(f[v][k][0][1]+f[v][k][1][1])%MOD);
                }
                if(f[u][j][1][0])
                {
                    add(tmp[j+k][1][0],1ll*f[u][j][1][0]*(f[v][k][0][0]+f[v][k][0][1])%MOD);
                    add(tmp[j+k][1][1],1ll*f[u][j][1][0]*(f[v][k][1][0]+f[v][k][1][1])%MOD);
                }
                if(f[u][j][1][1])
                {
                    int s=0;
                    add(s,f[v][k][0][0]);add(s,f[v][k][0][1]);
                    add(s,f[v][k][1][0]);add(s,f[v][k][1][1]);
                    add(tmp[j+k][1][1],1ll*f[u][j][1][1]*s%MOD);
                }
            }
        size[u]+=size[v];
        for(int j=0;j<=size[u]&&j<=K;++j)
        {
            f[u][j][0][0]=tmp[j][0][0];tmp[j][0][0]=0;
            f[u][j][0][1]=tmp[j][0][1];tmp[j][0][1]=0;
            f[u][j][1][0]=tmp[j][1][0];tmp[j][1][0]=0;
            f[u][j][1][1]=tmp[j][1][1];tmp[j][1][1]=0;
        }
    }
}
int main()
{
    n=read();K=read();
    for(int i=1,u,v;i<n;++i)u=read(),v=read(),Add(u,v),Add(v,u);
    dfs(1,0);
    int ans=(f[1][K][0][1]+f[1][K][1][1])%MOD;
    printf("%d\n",ans);
    return 0;
}