题解 P5771 【[JSOI2016]反质数序列】

· · 题解

考虑按奇偶性分类,一个质数一定是一个奇数加一个偶数(2除外),那么这道题就是一个二分图最大独立集的问题,考虑用网络流做。

如果 v_a+v_b \in p 那么我们就连一条 (a,b,\inf) 的边,表示 a, b 不能共存。s向每一个偶数连一条流量为 1 的边,奇数向 t 连一条流量为 1 的边,因为最大流等于最小割,所以答案就是 n-maxflow

然后我们要特殊考虑 1 这个数字怎么处理,我们发现,在一个合法的反素数序列中最多存在一个 1,那么我们在原序列中直接去掉多余的 1 即可。

#include <bits/stdc++.h>
using namespace std;

#define rep(i, x, y) for(int i = (x); i <= (y); ++ i)
#define rop(i, x, y) for(int i = (x); i <  (y); ++ i)
#define per(i, x, y) for(int i = (x); i >= (y); -- i)
#define por(i, x, y) for(int i = (x); i >  (y); -- i)

const int inf = 0x3f3f3f3f;

int prime[200050];
bool vis[200050];
int ptot = 0;
void get_prime(int n){
    rep(i, 2, n) vis[i] = 1;
    rep(i, 2, n) {
        if(vis[i] == 1) prime[++ ptot] = i;
        for(int j = 1 ; j <= ptot && i * prime[j] <= n ; j++){
            vis[i * prime[j]] = 0;
            if(i % prime[j] == 0) break;
        }
    }
}
int s, t;
const int N = 3005;
int head[N * 2], cur[N * 2], nxt[N * N * 4], to[N * N * 4], tot = 1, val[N * N * 4];
int a[N];
void add(int x, int y, int w) { 
    to[++ tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
    val[tot] = w;
    to[++ tot] = x;
    nxt[tot] = head[y];
    head[y] = tot;
    val[tot] = 0;
}
int dep[N * 2];
bool bfs() {
    memset(dep, 0, sizeof dep);
    queue<int> q;
    q.push(s);
    dep[s] = 1;
    while(q.size()) {
        int now = q.front();
        q.pop();
        for(int i = head[now]; i; i = nxt[i]) {
            if(dep[to[i]] || val[i] == 0) continue;
            q.push(to[i]);
            dep[to[i]] = dep[now] + 1;
            if(to[i] == t) return 1;
        }
    }
    return 0;
}
int dfs(int now, int flow) {
    int res = flow;
    if(now == t) return flow;
    for(int &i = cur[now]; i; i = nxt[i]) {
        if(val[i] == 0 || dep[to[i]] != dep[now] + 1) continue;
        int usd = dfs(to[i], min(res, val[i]));
        val[i] -= usd;
        val[i ^ 1] += usd;
        res -= usd;
        if(res == 0) return flow;
    }
    if(res == flow) dep[now] = -1;
    return flow - res;
}
int main() {
    get_prime(200000);
    int n; scanf("%d", &n);
    int kk = 0, cnt = 0;
    rep(i, 1, n) scanf("%d", &a[i]);
    rep(i, 1, n) if(a[i] == 1) kk = i, ++ cnt;
    s = 2 * n + 1, t = s + 1;
    rep(i, 1, n) add(s, i, 1), add(i + n, t, 1);
    rep(i, 1, n) rep(j, 1, n) if(a[i] % 2 == 0 && a[j] % 2 == 1 && vis[a[i] + a[j]] && a[j] != 1) add(i, j + n, inf);
    rep(i, 1, n) if(kk && a[i] % 2 == 0 && vis[a[i] + 1]) add(i, kk + n, inf);
    int ans = 0;
    while(bfs()) {
        memcpy(cur, head, sizeof cur);
        ans += dfs(s, inf);
    }
    printf("%d\n", n - ans - max(cnt - 1, 0));
}