P5080题解

· · 题解

正在切数论的我乍一看还以为是数论,结果思考了一下后才发现原来是建图和搜索。

通过将第一个样例画图之后,可以发现输出最后的结果就是图中的最长链,所以肯定是要建图的。

但我们要不要搜索?这数据范围显然是会爆炸,但是搜索可以剪枝啊,而且这个题可以剪大枝,这样就可以优化很多。

剪枝

只有一处剪枝,就是把当前位置的最大深度存一下,如果有别的父亲节点走到自己并且自己这个节点的已知最大深度小于当前这个新的路径的自己的当前深度,那就可以直接将这个边跳过。

剪枝如下: 这里是手写了一个栈用来存走过的路径,v 数组存当前节点最大深度,栈顶 top 就是当前路径该节点的深度 最后统计一下即可

void dfs(int x, int fa)
{
    st[++top] = x;
    if(top > v[x])
    {
        v[x] = top;
        for (int i = head[x]; i; i = e[i].next)
        {
            int v = e[i].to;
            dfs(v);
        }
        if(top > maxn) 
        {
            maxn = top;
            for (int i = 1; i <= top; i++)
            {
                ans[i] = st[i];
            }
        }
        top--;
    }
    else top--;
}

二分建边

如果我们一个一个建边,那会是平方的复杂度,所以可以用二分来降低一个对数级别的复杂度。 建边如下:

struct edge{
    int next, to;
}e[MAXN];
void add(int x, int y)
{
    e[++cnt].to = y;
    e[cnt].next = head[x];
    head[x] = cnt;
}
for (int i = 1; i <= n; i++)
{
    if(b[i] % 3 == 0) 
    {
        int it = lower_bound(b + 1, b + n + 1, b[i] / 3) - b;
        if(b[it] == b[i] / 3) add(i, it), ru[it]++;
    }
    int it = lower_bound(b + 1, b + n + 1, b[i] * 2) - b;
    if(b[it] == b[i] * 2) add(i, it), ru[it]++;
}

此时我们的复杂度已经降得比较低了,直接就可以过啦!

AC code

#include <bits/stdc++.h>
#define int long long
const int MAXN = 2e6 + 10;
using namespace std;
int n, a[MAXN], b[MAXN], ru[MAXN], ans[MAXN], st[MAXN], maxn, top, cnt, head[MAXN];
struct edge{
    int next, to;
}e[MAXN];
void add(int x, int y)
{
    e[++cnt].to = y;
    e[cnt].next = head[x];
    head[x] = cnt;
}
void dfs(int x, int fa)
{
    st[++top] = x;
    if(top > v[x])
    {
        v[x] = top;
        for (int i = head[x]; i; i = e[i].next)
        {
            int v = e[i].to;
            dfs(v);
        }
        if(top > maxn) 
        {
            maxn = top;
            for (int i = 1; i <= top; i++)
            {
                ans[i] = st[i];
            }
        }
        top--;
    }
    else top--;
}
signed main()
{
    scanf("%lld", &n);
    for (int i = 1; i <= n; i++)
    {
        scanf("%lld", &b[i]);
    }
    sort(b + 1, b + n + 1);
    for (int i = 1; i <= n; i++)
    {
        if(b[i] % 3 == 0) {
            int it = lower_bound(b + 1, b + n + 1, b[i] / 3) - b;
            if(b[it] == b[i] / 3) add(i, it), ru[it]++;
        }
        int it = lower_bound(b + 1, b + n + 1, b[i] * 2) - b;
        if(b[it] == b[i] * 2) add(i, it), ru[it]++;
    }
    for (int i = 1; i <= n; i++)
    {
        if(ru[i] == 0) dfs(i, 0);
    }
    for (int i = 1; i <= n; i++) if(dep[i] > maxn) maxn = dep[i];
    printf("%lld\n", maxn);
    for (int i = 1; i <= maxn; i++) printf("%lld ", b[ans[i]]);
    return 0;
}