lg9483 [NOI2023] 合并书本

· · 题解

考虑对合并过程建一棵树。

对于一个点 x,定义 a_x 表示它向上合并的时候,对答案造成的重量贡献的系数。

定义一个点的层级 d_x 为它的两个儿子层级的较大值 +1。我们称 d 更小的层级为更深的层级。

那么层级为 i 的非根非叶子节点会对答案造成 2^i-1 的磨损值贡献。由于非根非叶子节点共有 n-2 个,可以当成是造成 2^i 的贡献,最后给答案减掉 n-2 即可。

考虑这样一个过程:从浅往深(从大往小)扫每一层 i,找出所有 d_x<id_{fa_x}\geq i 的点 x,那么我们只关心可重集 S=\{a_x\} 以及层数 \geq i 的点的磨损值贡献之和(记作 c)。

考虑从 i 层转移到第 i-1 层时会发生什么,这相当于选择一些叶子把它们分裂成两个叶子。也就是选择 S 的一个子集 T,对每个 t\in T,将 t+1 加入到 S,然后令 c=2(c+|T|)

然后我们发现两个性质。一个是 T 一定是 S 的一个前缀,这是显然的。另一个是对于一个方案,|T| 从浅到深单调不降,否则我们可以把某个点留到更深的层级再分裂。

于是可以考虑搜索,每次暴力枚举 |T|,一共有划分数种方案,可以获得 75 分。

然后我们发现,对于 S 相同的状态,只有 c 最小的那个是有用的,可以 bfs 搜出所有状态。实测 n\leq 100 的时候只有 47575 个状态,可以通过。

int lim;
int t,_,n[15],m,i,j,k,a[15][105],vis[500005],lst[500005];
i64 ans,dis[500005];
int ch[500005][105],cnt;
vector<int> seq[500005];
vector<int> f[105];
int qid(vector<int> v)
{
    int x=0,i,len=0;
    vector<int> cur;
    ff(v,it){
        cur.push_back(*it);
        if(!ch[x][*it]){
            ch[x][*it]=++cnt;
            f[cur.size()].push_back(cnt);
            seq[cnt]=cur;
            dis[cnt]=1e18;
        }
        x=ch[x][*it];
    }
    return x;
}
void upd(vector<int> v,i64 c,int l)
{
    int x=qid(v);
    if(dis[x]>c){
        dis[x]=c;lst[x]=l;
    }
    if(dis[x]==c){
        lst[x]=max(lst[x],l);
    }
}
int main()
{
    //cerr<<sizeof(ch)/1048576<<endl;
    read(t);fz1(i,t){read(n[i]);fz1(j,n[i])read(a[i][j]);lim=max(lim,n[i]);}
    upd({0},0,0);
    fz1(_,lim)for(int x:f[_]){
        i64 cur=dis[x]*2;
        vector<int> v=seq[x];
        vector<int> nv=v;
        fz0k(i,v.size()){
            if(x!=1) cur+=2;nv.push_back(v[i]+1);int j=nv.size()-1;
            if(nv.size()>lim) break;
            while(j&&nv[j]<nv[j-1])swap(nv[j],nv[j-1]),j--;
            if(i>=lst[x]) upd(nv,cur,i);
        }
    }
    //cerr<<cnt<<endl;
    fz1(k,t){
        ans=1e18;sort(a[k]+1,a[k]+n[k]+1);
        if(n[k]==1){puts("0");continue;}
        for(int i:f[n[k]]){
            i64 sum=dis[i];
//          cerr<<dis[i]<<endl;ff(seq[i],it) cerr<<*it<<' ';cerr<<endl;
            fz1(j,n[k]) sum+=1ll*a[k][j]*seq[i][n[k]-j];
            ans=min(ans,sum);
        }
        printf("%lld\n",ans-(n[k]-2));
    }
    return 0;
}