题解:P12551 [UOI 2025] Simple Task

· · 题解

P12551 [UOI 2025] Simple Task

分讨题,细节比较多,单纯的恶心人,实际思维难度应该没有黑。为了理解方便,下文中忽略了若干没有难度的细节。

特判掉 k=1

部分分启发很好,先看部分分。

n 是偶数

如果 a_i 全是奇质数,那么方案是显然的,每次合并两个奇质数为偶数,如果全是偶数了那就随便合。这样每操作一次就少俩质数,显然最优。

如果加上 2 会怎么样呢。

如果奇质数的数量和 2 的数量都是偶数,那么按照上一种情况的做法也是可以的(注意这里要区分 2 和大于 2 的偶数)。

有了这个合并的想法,于是我们可以做 k<\frac{n}{2} 了,这是简单的。

如果奇质数的数量和 2 的数量都是奇数该怎么办呢。如果存在一个奇质数 a_x 满足 a_x+2 不是质数,那么将它与一个 2 配对起来,把剩下的数按照上一种情况去合并就可以了。

接下来可以解决 k=\frac{n}{2} 了,因为找不到 a_x+2 是奇合数,所以怎么划分答案都不会小于 1,直接将相邻两个匹配就可以了。

如果找不到这个数,那么考虑 p,p+2,p+4 模 3 的余数互不相同,所以当 p\not=3p+4 一定是合数。尝试将一个奇质数与两个 2 配对,配对完成后剩下偶数个奇质数,奇数个 2,将多出来那个 2 随便合并一下就可以了。

但是如果所有的奇质数都是 3,上面的做法就不管用了。有两种办法,一种是取 3 个 3 拼成一个 9,另一种是取 3 个 2 和一个 3 拼一个 9,两种方法必定会有一个可行。

如果你找不到一个奇质数 +2 是合数,此时如果 2 的个数是 1,那拼出 p+4 的做法又废掉了,不过此时我们有:奇质数的数量大于 3(小于等于 3 的前面都被解决了)。

考虑如下结论:任意 5 个奇质数中,一定能选出 3 个奇质数的和是合数。分讨他们模 3 的余数即可证明这个结论(这里奇质数等于 3 是允许的)。

在前 5 个奇质数中选出 3 个合成一个合数,然后将剩下两个与 1 个 2 合并,问题就解决了,将偶数随便合并即可。

n 是奇数

同理可以做掉 k<\frac{n}{2}

如果全是奇质数,那么在前 5 个奇质数中找 3 个合并成合数,这部分就解决了。

如果 2 的数量是奇数,那么将偶数个奇质数两两合并,然后再将多出来的 2 加到某一对上面。

如果 2 的数量是偶数,则仿照上面,如果能找到一个奇质数满足其 +2 后是合数,那么两者匹配,剩余奇数个 2 和偶数个奇质数,将奇质数合并后把多出来那一个 2 加过去。否则考虑找一个不是 3 的奇质数,将其与两个 2 合并。

如果奇质数全是 3,当 3 的数量大于 1 的时候,选择将 3 个 3 合成一个 9,否则选择 3 个 2 和一个 3 合成一个 9。

于是问题就解决了,复杂度 O(n\log n)。真是一场酣畅淋漓的食雪啊。

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
int testnum,n,k,a[100003],k1,k2,k3,k4,k5,k6,k7,k8,k9,bcj[100003],cnt2,rk;
int finf(int xx){if(bcj[xx]!=xx)bcj[xx]=finf(bcj[xx]);return bcj[xx];}
vector<int>lst[100003];
bool isp(int xx){
    for(int i=2;i*i<=xx;i++)if(xx%i==0)return false;
    return true;
}
void printans(){
    for(int i=1;i<=n;i++){lst[i].clear();lst[i].shrink_to_fit();}
    for(int i=1;i<=n;i++)lst[finf(i)].emplace_back(i);
    k1=0;k3=0;
    for(int i=1;i<=n;i++){
        if(lst[i].size()!=0){
            k2=0;
            for(auto j:lst[i])k2+=a[j];
            if(isp(k2))k1++;
        }
    }
    cout<<k1<<'\n';
    for(int i=1;i<=n;i++){
        if(lst[i].size()!=0){
            cout<<lst[i].size()<<' ';
            for(auto j:lst[i])cout<<a[j]<<' ';
            cout<<'\n';
        }
    }
    return;
}
void merge(int xx,int yy){
    if(xx<1||xx>n||yy<1||yy>n||finf(xx)==finf(yy)||k<=0)return;
    k--;bcj[finf(xx)]=finf(yy);
    return;
}
void smpmerge(){
    for(int i=1;i<n;i+=2)merge(i,i+1);
    for(int i=1;i<n;i++)merge(i,i+1);
    printans();
    return;
}
void sol(){
    cin>>n>>k;
    for(int i=1;i<=n;i++)cin>>a[i];
    sort(a+1,a+n+1);
    cnt2=0;
    for(int i=1;i<=n;i++)cnt2+=(a[i]==2);
    for(int i=1;i<=n;i++)bcj[i]=i;
    if(k==1){
        for(int i=1;i<=n;i++)bcj[i]=1;
        printans();
        return;
    }
    k=n-k;
    rk=k;
    if(n%2==1&&cnt2%2==1){
        for(int i=cnt2+1;i<=n;i+=2)merge(i,i+1);
        for(int i=1;i+1<=cnt2;i+=2)merge(i,i+1);
        merge(cnt2,cnt2+1);
        if(cnt2==n)merge(cnt2-1,cnt2);
        for(int i=1;i<n;i++)merge(i,i+1);
        printans();
        return;
    }
    if(n%2==1&&cnt2%2==0){
        if(k<=(n/2)){
            for(int i=1;i+1<=cnt2;i+=2)merge(i,i+1);
            for(int i=cnt2+1;i+1<=n;i+=2)merge(i,i+1);
            printans();
            return;
        }
        if(cnt2>0){
            for(int i=cnt2+1;i<=n;i++)if(a[i]!=3)swap(a[i],a[cnt2+1]);
            if(a[cnt2+1]==3){
                if(cnt2+1==n){
                    for(int i=1;i+1<n-4;i+=2)merge(i,i+1);
                    merge(n,n-1);merge(n-1,n-2);merge(n-2,n-3);
                    merge(n-4,n-5);
                    for(int i=1;i<n;i++)merge(i,i+1);
                    printans();
                    return;
                }
                merge(n-2,n-1);merge(n-1,n);
                for(int i=1;i+1<n-2;i+=2)merge(i,i+1);
                for(int i=1;i<n;i++)merge(i,i+1);
                printans();
                return;
            }
            for(int i=1;i+1<=cnt2-2;i+=2)merge(i,i+1);
            for(int i=cnt2+2;i+1<=n;i+=2)merge(i,i+1);
            if(!isp(a[cnt2+1]+2)){
                merge(cnt2,cnt2+1);
                if(cnt2>2)merge(cnt2-1,cnt2-2);
                else merge(cnt2-1,cnt2+2);
                for(int i=1;i<n;i++)if(finf(i)!=finf(cnt2)&&finf(i+1)!=finf(cnt2))merge(i,i+1);
                merge(1,n);
                printans();
                return;
            }
            merge(cnt2,cnt2+1);merge(cnt2-1,cnt2);
            for(int i=1;i<n;i++)if(finf(i)!=finf(cnt2)&&finf(i+1)!=finf(cnt2))merge(i,i+1);
            merge(1,n);
            printans();
        }
        else{
            int flg=0;
            for(int i=1;i<=5&&flg==0;i++){
                for(int j=i+1;j<=5&&flg==0;j++){
                    for(int u=j+1;u<=5&&flg==0;u++){
                        if((a[i]+a[j]+a[u])%3==0){
                            swap(a[n],a[u]);
                            swap(a[n-1],a[j]);
                            swap(a[n-2],a[i]);
                            merge(n,n-1);merge(n-1,n-2);
                            flg=1;
                        }
                    }
                }
            }
            for(int i=1;i+1<=n-3;i+=2)merge(i,i+1);
            for(int i=1;i<n;i++)merge(i,i+1);
            printans();
            return;
        }
    }
    if(n%2==0&&cnt2%2==0){smpmerge();return;}
    if(n%2==0&&cnt2%2==1){
        if(k<(n/2)){
            for(int i=1;i<n;i+=2)if(a[i]%2==a[i+1]%2)merge(i,i+1);
            printans();
            return;
        }
        for(int i=cnt2+1;i<=n;i++){
            if(!isp(a[i]+2)){
                swap(a[i],a[n]);
                swap(a[cnt2],a[n-1]);
                smpmerge();
                return;
            }
        }
        if(k==(n/2)){smpmerge();return;}
        if(cnt2>1){
            for(int i=cnt2+1;i<=n;i++)if(a[i]!=3)swap(a[i],a[cnt2+1]);
            if(a[cnt2+1]==3){
                if(cnt2+1==n){
                    for(int i=1;i+1<=n;i+=2)merge(i,i+1);
                    merge(cnt2-1,cnt2);
                    smpmerge();
                    return;
                }
                if(cnt2>1){
                    swap(a[n-1],a[cnt2]);
                    swap(a[n-2],a[cnt2-1]);
                    swap(a[n-3],a[cnt2-2]);
                    for(int i=n-3;i<n;i++)merge(i,i+1);
                    for(int i=1;i+1<n-3;i+=2)merge(i,i+1);
                    for(int i=1;i<n;i++)merge(i,i+1);
                    printans();
                    return;
                }
                for(int i=2;i+1<n-2;i+=2)merge(i,i+1);
                merge(n-2,n-1);merge(n-1,n);
                for(int i=1;i<n;i++)merge(i,i+1);
                printans();
                return;
            }
            merge(cnt2,cnt2+1);merge(cnt2-1,cnt2);
            if(cnt2>3)merge(cnt2-2,cnt2-3);
            else merge(cnt2-2,cnt2+2);
            for(int i=1;i<cnt2-3;i+=2)merge(i,i+1);
            for(int i=cnt2+2;i<=n;i+=2)merge(i,i+1);
            for(int i=1;i<n;i++)if(finf(i)!=finf(cnt2)&&finf(i+1)!=finf(cnt2))merge(i,i+1);
            merge(1,n);
            printans();
        }
        else{
            int flg=0;
            for(int i=2;i<=6&&flg==0;i++){
                for(int j=i+1;j<=6&&flg==0;j++){
                    for(int u=j+1;u<=6&&flg==0;u++){
                        if((a[i]+a[j]+a[u])%3==0){
                            swap(a[n],a[u]);
                            swap(a[n-1],a[j]);
                            swap(a[n-2],a[i]);
                            flg=1;
                        }
                    }
                }
            }
            merge(n-2,n-1);merge(n-1,n);
            for(int i=2;i+1<n-2;i+=2)merge(i,i+1);
            for(int i=1;i<n;i++)merge(i,i+1);
            printans();
        }
        return;
    }
    return;
}
signed main(){
    ios::sync_with_stdio(false);
    cin>>testnum;
    while(testnum--)sol();
    return 0;
}