题解 【第五届传智杯 初赛】

· · 个人记录

A

题解

签到题。首先用 \text{if} 语句判断 b 的符号,然后加在 a 的绝对值上即可。

时间复杂度为 \mathcal O(1)

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int main(){
    int a, b;
    cin >> a >> b;
    cout << fixed << setprecision(0) << copysignl(a, b) + 1e-9 << endl;
    return 0;
}

版本 2

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int main(){
    i64 a, b; cin >> a >> b;
    cout << (b < 0 ? -llabs(a) : llabs(a));
    return 0;
}

版本 3

#include <bits/stdc++.h>
using namespace std;
int main()
{
    int a,b;
    cin >> a >> b;
    if (b>0 && a==INT_MIN)
        cout << 2147483648ll << endl;
    else
    {
        a=abs(a);
        if (b<0)
            a*=-1;
        cout << a << endl;
    }
    return 0;
}

版本 4

import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        long a = scanner.nextLong(), b = scanner.nextLong();
        System.out.println((Math.abs(a) * (b > 0 ? 1 : -1)));
    }
}

B

题解

模拟题。按照题目要求输入整数 a,b,模拟这个奇怪的进位规则即可。

时间复杂度为 \mathcal O(n+m)

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int qread(){
    int w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
const int MAXN = 2e5 + 3;
int A[MAXN], B[MAXN];
int main(){
    int n = qread(), m = qread(), l = max(n, m);
    dn(n, 1, i) A[i] = qread();
    dn(m, 1, i) B[i] = qread();
    up(1, l, i) A[i] += B[i], A[i + 1] += A[i] / (i + 1), A[i] %= i + 1;
    if(A[l + 1]) ++ l;
    dn(l, 1, i) printf("%d%c", A[i], " \n"[i == 1]);
    return 0;
}

版本 2

#include <bits/stdc++.h>
using namespace std;
int a[200050],b[200050];
int main()
{
    auto read=([&]{
        int x;cin >> x;
        return x;
    });
    int n=read(),m=read();
    int len=max(n,m)+1;
    generate_n(a+1,n,read);
    generate_n(b+1,m,read);
    reverse(a+1,a+n+1);
    reverse(b+1,b+m+1);
    for (int i=1;i<=len;i++)
    {
        a[i]+=b[i];
        a[i+1]+=(a[i]/(i+1));
        a[i]%=(i+1);
    }
    while (a[len]==0 && len>1)
        len--;
    reverse(a+1,a+len+1);
    for (int i=1;i<=len;i++)
        cout << a[i] << " ";
    return 0;
}

版本 3

import java.util.Scanner;

public class Main {

    public static int[] a = new int[200005];
    public static int[] b = new int[200005];
    public static int[] c = new int[200005];

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt(), m = scanner.nextInt();
        int maxLength = Math.max(n, m);
        for (int i = (maxLength - n) + 1; i <= maxLength; ++i)
            a[i] = scanner.nextInt();
        for (int i = (maxLength - m) + 1; i <= maxLength; ++i)
            b[i] = scanner.nextInt();
        for (int i = maxLength, cnt = 2; i > 0; --i, ++cnt) {
            c[i] += a[i] + b[i];
            if (c[i] >= cnt) {
                c[i] -= cnt;
                c[i - 1] += 1;
            }
        }
        if (c[0] > 0) {
            System.out.printf("%d ", c[0]);
        }
        for (int i = 1; i <= maxLength; ++i) {
            System.out.printf("%d ", c[i]);
        }
        System.out.println();
    }
}

C

题解

读入题。暴风吸入输入数据里给定的所有字符,存到数组里,统计有多少个换行符,确定输入文件的总行数 m。由此计算出最后一个行号的长度 s=\lfloor\lg m+1\rfloor(数学库里可以直接调用 \lg,当然你也可以随便用什么途径算出每个数的长度)。

然后就是模拟了。对于第 i 行,

时间复杂度为 \mathcal O(|S|),其中 |S| 是输入的所有字符的个数。

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
const int MAXN= 2e4 + 3;
char S[MAXN], c; int l, m;
int main(){
    m = count(S + 1 , S + 1 + fread(S + 1, 1, MAXN, stdin), '\n');
    int s = log10(m) + 1 + 1e-9, p = 0;
    up(1, m, i){
        int t = log10(i) + 1 + 1e-9;
        for(int j = 1;j <= s - t;++ j) putchar( ' '); printf("%d ", i);
        for(p = p + 1;S[p] != 10;++ p) putchar(S[p]); putchar('\n');
    }
    return 0;
}

版本 2

#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
char buf[200050];
vector <char> s[200050];
int cnt;
int get_digit(int x)
{
    int digit=1,ret=1;
    while (ret<=x)
    {
        digit++;
        ret*=10;
    }
    return digit;
}
int main()
{
    while(fgets(buf,200000,stdin)!=NULL)
    {
        cnt++;
        for (int i=0;buf[i]!='\n';i++)
            s[cnt].push_back(buf[i]);
    }
    int cnt_digit=get_digit(cnt);
    for (int i=1;i<=cnt;i++)
    {
        for (int j=1;j<=cnt_digit-get_digit(i);j++)
            putchar(' ');
        cout << i << ' ';
        for (int j=0;j<s[i].size();j++)
            putchar(s[i][j]);
        putchar('\n');
    }
    return 0;
}

版本 3

import java.util.Scanner;
import java.util.List;
import java.util.ArrayList;

public class Main {

    public static List<String> list = new ArrayList<>();

    public static int getBit(int x) {
        int cnt = 0;
        while (x > 0) {
            x /= 10;
            ++cnt;
        }
        return cnt;
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        while (scanner.hasNextLine()) {
            list.add(scanner.nextLine());
        }
        int size = list.size();
        int len = getBit(size);
        for (int i = 0; i < size; ++i) {
            System.out.printf("%" + len + "d %s\n", i + 1, list.get(i));
        }
    }
}

D

题解

贪心题。

假设 m 次操作后,剩下来的数字的值域为 [l,r]。那么原来 a 数列里,所有严格小于 l 的数字肯定都被操作了,同时所有严格大于 r 的数字肯定也被操作了。

a 里面一共有 u 个数严格小于 l;有 v 个数严格大于 r

断言:所需要的操作次数至少为 u+v+\min(u,v),并且可以取到。

证明:如果一个位置在操作后,它的值还在 [l,r] 之外,那么后面某个时候肯定还要把它进行操作。容易发现,至少前 \min(u,v) 次操作肯定都无法让结果变到 [l,r] 内。因此执行完这至少 \min(u,v) 次操作后肯定还要再把这 u+v 个数都操作一遍,这是容易做到的(通过 \min(u,v) 次操作,你总能把此时值域的下界提升到 l 或者把上界降低到 r)。所以最优决策肯定不会小于 u+v+\min(u,v) 次。

那么这题怎么做呢?

直接将 a 数组从小到大排序。考虑枚举 l=a_i,计算出最小的 r=a_j,一定有 (i-1)+(n-j)+\min(i-1,n-j)\le m。容易发现随着 i 的增大,j 肯定是单调不减的。因此首先预处理 j=1,接着随着 i 的增大找到可以满足条件的最小的 j。显然当 i> \min(n,m+1) 时就不存在这样的 j 了。

时间复杂度为 \mathcal O(n\log n),瓶颈在于排序。

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int qread(){
    int w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
const int MAXN = 1e5 + 3;
int A[MAXN], ans = INF;
int main(){
    int n = qread(), m = qread();
    up(1, n, i) A[i] = qread();
    sort(A + 1, A + 1 + n);
    int j = 1;
    up(1, min(n, m + 1), i){
        j = max(i, j);
        while((i - 1) + (n - j) + min(i - 1, n - j) > m) ++ j;
        ans = min(ans, A[j] - A[i]);
    }
    printf("%d\n", ans);
    return 0;
}

版本 2

import java.util.Scanner;
import java.util.Arrays;

public class Main {

    public static int[] a = new int[100005];

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt(), m = scanner.nextInt();
        for (int i = 1; i <= n; ++i)
            a[i] = scanner.nextInt();
        Arrays.sort(a, 1, n + 1);
        int j = 1, ans = Integer.MAX_VALUE;
        for (int i = 1; i <= Math.min(n, m + 1); ++i) {
            j = Math.max(j, i);
            while((i - 1) + (n - j) + Math.min(i - 1, n - j) > m) 
                ++j;
            ans = Math.min(ans, a[j] - a[i]);
        }
        System.out.println(ans);
    }
}

E

题解

数学题。

题目已经贴心地给我们把数列分好了段。容易发现,第 i 段长度为 4i-3,前 n 段长度为 2n^2-n。可以二分出询问所在的段 k

计算出第 k 段两端端点 [l,r] 后,再计算它左右两侧的四等分点 l'=l+(k-1),r'=r-(k-1)。那么第 k 段又分成了三段,[l,l')[l',r'],(r',r]

分类讨论下输出啥就行了。

时间复杂度为 \mathcal O(q\log v),其中 v 是值域。

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
i64 qread(){
    i64 w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
i64 val(i64 p){return 2 * p * p - p;}
int main(){
    up(1, qread(), TTT){
        i64 k = qread(), p = 0;
        dn(30, 0, i){
            if(val(p | 1 << i) < k) p |= 1 << i;
        }
        int i = p + 1, w = i - 1;
        i64 l = val(p) + 1, ll = l + w;
        i64 r = val(i),     rr = r - w;
        if(l  <= k && k <  ll) printf("%lld\n", k - l     ); else 
        if(ll <= k && k <= rr) printf("%lld\n", w - k + ll); else 
            printf("%lld\n", k - r);
    }
    return 0;
}

版本 2

#include <iostream>
using namespace std;
int main()
{
    int q;
    cin >> q;
    while (q--)
    {
        long long k,l=1,r=2e9,ans=0;
        cin >> k;
        while (l<=r)
        {
            long long mid=(l+r)/2;
            if ((long long)mid*mid*2-mid>=k)
            {
                r=mid-1;
                ans=mid;
            }
            else
                l=mid+1;
        }
        ans--;
        long long len=ans*ans*2-ans;
        k-=len+1;
        if (k<=ans)
            cout << k << endl;
        else if (k<=3*ans)
            cout << 2*ans-k << endl;
        else
            cout << -4*ans+k << endl;

    }
    return 0;
}

版本 3

import java.io.*;
import java.util.StringTokenizer;

public class Main {

    public static long val(long p) {
        return p * 2 * p - p;
    }

    public static class Scanner {
        public BufferedReader in;
        public StringTokenizer tok;
        public String next() { hasNext(); return tok.nextToken(); }
        public String nextLine() { try { return in.readLine(); } catch (Exception e) { return null; } }
        public long nextLong() { return Long.parseLong(next()); }
        public int nextInt() { return Integer.parseInt(next()); }
        public PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));
        public boolean hasNext() {
            while (tok == null || !tok.hasMoreTokens()) try { tok = new StringTokenizer(in.readLine()); } catch (Exception e) { return false; }
            return true;
        }
        public Scanner(InputStream inputStream) { in = new BufferedReader(new InputStreamReader(inputStream)); }
    }

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int q = scanner.nextInt();
        while (q-- > 0) {
            long k = scanner.nextLong(), p = 0;
            for (int i = 30; i >= 0; --i) {
                if (val(p | 1 << i) < k)
                    p |= 1 << i;
            }
            long i = p + 1, w = i - 1;
            long l = val(p) + 1, ll = l + w;
            long r = val(i), rr = r - w;
            if (l <= k && k < ll)
                System.out.println(k - l);
            else if (ll <= k && k <= rr)
                System.out.println(w - k + ll);
            else
                System.out.println(k - r);
        }
    }
}

F

题解

模拟题。没什么好说的。这里就讲几个细节:

时间复杂度为 \mathcal O(nq+nL+kq)

参考代码

版本 1

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int qread(){
    int w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
int n, m, q, maxl; bool o = 1;
const int MAXN = 100 + 3;
int C[MAXN][MAXN], A[MAXN], D[MAXN], L[MAXN], O[3], T[MAXN];
i64 M[2];
int main(){
    n = qread(), m = qread(), q = qread(), maxl = qread();
    M[0] = M[1] = m, O[0] = O[1] = 1;
    up(1, n, i) L[i] = 0, T[i] = -1;
    up(1, n, i)
        up(0, maxl - 1, j) C[i][j] = qread();
    up(1, n, i) D[i] = qread();
    for(int op, k;~scanf("%d%d", &op, &k);){
        if(op == 1){
            if(o == 1){
                up(1, n, i) if(T[i] != -1) M[T[i]] += D[i];
            }
            o ^= 1;
            up(1, k, i){
                O[o] = (O[o]) % n + 1;
                int p = O[o];
                if(T[p] ==  o) M[o] += A[p]; else 
                if(T[p] == !o){
                    M[!o] += A[p], M[o] -= A[p];
                    if(M[o] < 0)
                        puts(o ? "Merry" : "Renko"), exit(0);
                }
            }
        } else {
            int p = O[o];
            while(k && L[p] < maxl && M[o] >= C[p][L[p]] && T[p] != !o){
                A[p] += C[p][L[p]], M[o] -= C[p][L[p]];
                ++ L[p], T[p] = o, -- k;
            }
        }
    }
    up(1, n, i) if(T[i] != -1) M[T[i]] += D[i];
    printf("%lld %lld\n", M[0], M[1]);
    return 0;
}

版本 2

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cctype>
#include <queue>

using namespace std;

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
    return x*f;
}

int n,m,q,L,C[105][105],d[105],a[105],lv[105],pos[2],belong[105];

long long money[2];

const string name[2]={"Renko","Merry"};

inline void putfail(int id)
{
    cout << name[id] << endl;
    exit(0);
}

inline void forward(int id,int k)
{
    for (int i=1;i<=k;i++)
    {
        int &p=pos[id];
        p++;
        if (p>n)
            p-=n;
        if (belong[p]==id)
            money[id]+=a[p];
        else if (belong[p]==(id^1))
        {
            money[id]-=a[p];
            money[id^1]+=a[p];
        }
        if (money[id]<0)
            putfail(id);
    }
}

inline void build(int id,int k)
{
    int p=pos[id],cost=C[p][lv[p]];
    for (int i=1;(money[id]>=cost && lv[p]<L && (belong[p]==-1 || belong[p]==id) && i<=k);i++)
    {
        belong[p]=id;
        money[id]-=cost;
        a[p]+=cost;
        cost=C[p][++lv[p]];
    }
}

int main()
{
    n=read(),m=read(),q=read(),L=read();
    for (int i=1;i<=n;i++)
    {
        for (int j=0;j<L;j++)
            C[i][j]=read();
    }
    for (int i=1;i<=n;i++)
        d[i]=read();
    pos[0]=pos[1]=1;
    fill(belong+1,belong+n+1,-1);
    money[0]=money[1]=m;
    for (int i=0;i<2*q;i++)
    {
        int op=read();
        begin:
        for (int j=1;j<=n && !(i&1);j++)
        {
            if (belong[j]==0)
                money[0]+=d[j];
            else if (belong[j]==1)
                money[1]+=d[j];
        }
        int k=read();
        forward(i&1,k);
        if (i==2*q-1)
            break;
        op=read();
        if (op==1)
        {
            i++;
            goto begin;
        }
        else
        {
            k=read();
            build(i&1,k);
        }
    }

    for (int j=1;j<=n;j++)
    {
        if (belong[j]==0)
            money[0]+=d[j];
        else if (belong[j]==1)
            money[1]+=d[j];
    }

    cout << money[0] << " " << money[1] << endl;
    return 0;
}

G

题解

前缀和题。

注意到 M 矩阵可以看作 B 矩阵在行上长度为 r 的循环,在列上长度为 c 的循环,容易想到将原来的 A 矩阵也按照这两个方向上的循环进行染色。使用 r\times c 种颜色染色。

这样子有什么好处呢?我们进行一个特殊的二维前缀和。

S_{i,j}=A_{i,j}+S_{i-r,j}+S_{i,j-c}-S_{i-r,j-c}

那比如说 (4,7) 位置。S_{4,7} 的值就是 a_{2,1}+a_{2,4}+a_{2,7}+a_{4,1}+a_{4,4}+a_{4,7}。换言之,我们对每种颜色都做了一次二维前缀和

比如,现在需要计算左上角、右下角分别为 (3,4),(5,7) 的子矩阵里,所有绿色元素的和。那么答案就是,

S_{5,7}-S_{1,7}-S_{5,1}+S_{1,1}

更一般地,如果我们希望计算左上角、右下角分别为 (x_1,y_1),(x_2,y_2) 的子矩阵((x_1,y_1)(x_2,y_2) 两个位置的颜色相同,设为 t)里,所有颜色为 t 的元素之和,答案就是:

S_{x_2,y_2}-S_{x_2,y_1-c}-S_{x_1-r,y_2}+S_{x_1-r,y_1-c}

现在考察一次询问。

容易发现,我们选取询问矩阵左上角这个 r\times c 的小矩阵,那么这个小矩阵里面应该每种颜色都恰好出现了一次。当然这不是重点,重点是矩阵里所有颜色都会在这个小矩阵出现一次。并且,我们可以根据 B 矩阵算出,哪些颜色对应的 A_{i,j} 值是需要被计算的。

容易计算出小矩阵里的每种颜色,在大矩阵(询问的那个矩阵)里对应的矩阵的左上角、右下角坐标。对于每种颜色,都做一次二维前缀和即可。

时间复杂度为 \mathcal O(nm+qrc)

参考代码

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const int INF = 2147483647;
int n, m, r, c, q;
int qread(){
    int w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
const int MAXN = 2e3 + 3;
const int MAXM =  50 + 3;
const int MOD  = 998244353;
int A[MAXN][MAXN], S[MAXN][MAXN]; bool B[MAXN][MAXN];
int calc(int a1, int b1, int a2, int b2){
    int ret = S[a2][b2];
    if(a1 > r) ret = (ret - S[a1 - r][b2] + MOD) % MOD;
    if(b1 > c) ret = (ret - S[a2][b1 - c] + MOD) % MOD;
    if(a1 > r && b1 > c) ret = (ret + S[a1 - r][b1 - c]) % MOD;
    return ret;
}
int main(){
    n = qread(), m = qread();
    up(1, n, i) up(1, m, j) A[i][j] = qread();
    r = qread(), c = qread();
    up(1, r, i) up(1, c, j) B[i][j] = qread();
    up(1, n, i) up(1, m, j){
        S[i][j] = A[i][j];
        if(i > r) S[i][j] = (S[i][j] + S[i - r][j]) % MOD;
        if(j > c) S[i][j] = (S[i][j] + S[i][j - c]) % MOD;
        if(i > r && j > c)
            S[i][j] = (S[i][j] - S[i - r][j - c] + MOD) % MOD;
    }
    q = qread();
    up(1, q, i){
        int _x1 = qread(), _y1 = qread();
        int _x2 = qread(), _y2 = qread();
        int ans = 0;
        up(1, min(r, _x2 - _x1 + 1), a)
            up(1, min(c, _y2 - _y1 + 1), b) if(B[a][b] == 0){
                int a1 = _x1 + a - 1, a2 = a1 + (_x2 - a1) / r * r;
                int b1 = _y1 + b - 1, b2 = b1 + (_y2 - b1) / c * c;
                ans = (ans + calc(a1, b1, a2, b2)) % MOD;
        }
        printf("%d\n", ans);
    }

    return 0;
}

H

题解

搜索题。

注意一个重要性质:水流之间可视为互不干扰的。虽然确实有强度更大的水流可以覆盖强度更小的水流这样的设定,但容易发现强度更大的水流,可以流到的空间,包含了强度更小的水流。

(感性理解一下)

于是可以考虑,从高到低计算每个高度有哪些位置是有水流的。下面定义结构体 \text{Pos2} 用来存储二维坐标,结构体 \text{Pos3} 用来存储三维坐标。

对于输入进来的每个实体方块 (x,y,h),都塞到 B 里。B(x,y,h)\gets\text{true};对于起始水方块的位置 (x_0,y_0),塞到 W 里,W(x_0,y_0)\gets\text{true}

首先将所有实体方块按照 h 值的大小由大到小排序,枚举每个高度 h。记高度为 h 的方块组成的集合为 B_h,那么 W 中的水柱可能会有一些流到了 B_h 里的某些方块上,发生了扩散。\bm {B_h} 出发,算出这些会发生扩散的二维坐标位置,放到队列 P 里。当然,如果 (x,y) 位置会发生扩散,那就代表扩散完后 (x,y,h) 位置肯定没有水方块,于是 W 里要删除 (x,y) 位置。

为什么不枚举 W 内的坐标来确定有哪些位置会发生扩散?因为这么做复杂度是 \mathcal O(|W|\log |W|) 的,而枚举 B_h 内的坐标复杂度是 \mathcal O(|B_h|\log |B|) 的。前者容易构造出一个 |W| 较大,并且不同的高度够多的数据,将时间复杂度卡到 \mathcal O(n^2),是不可以的。后者则是正确的复杂度。

现在我们要对 P 里的水流进行扩散了。为了扩散,我们需要知道第 h 层每个点到达目标位置的最短距离。对于 P 里的每个位置都算一次这个距离,复杂度达到了 \mathcal O(|P|\log|P|\cdot |B_h|\log|B|),这是不可接受的。

但是可以发现,这一层实际有用的目标位置(紧挨在一个实体方块旁边)是不多的,个数是 \mathcal O(|B_h|) 级别。考虑找到这些有用的目标位置,放到队列 Q 里。怎么找目标位置呢?还是要枚举 B_h 内的坐标,检查一下它四周是不是没有实体方块。如果没有实体方块那就丢 Q 里。可以去重,不去重应该也没啥问题。如果想要去重,那还要开一个 \text{map}\lang \text{Pos2},\text{bool}\rang(记为 V)存一下有那些位置已经放进 Q 里了。

Q 初始值求好后,就可以宽度优先搜索,计算出 B_h 内每个点到达目标结点的最短长度。这样时间复杂度降为了 \mathcal O(|B_h|\log |B|)

接着从 P 里的位置开始进行宽度优先搜索。从 (x,y) 位置可以到达 (x',y') 位置,当且仅当 (x',y',h+1) 位置没有实体方块,并且 D(x',y')=D(x,y)-1,并且 K(x,y)>1

当然,如果 (x,y) 位置已经是目标位置,那就令 W(x,y)\gets \text{true}

最后是时间复杂度分析。上面出现的三个过程时间复杂度全部都是 \mathcal O(|B_h|\log |B|),直接求和,得到总时间复杂度为 \mathcal O(n\log n)

参考代码

#include<bits/stdc++.h>
#define up(l,r,i) for(int i=l,END##i=r;i<=END##i;++i)
#define dn(r,l,i) for(int i=r,END##i=l;i>=END##i;--i)
using namespace std;
typedef long long i64;
const int INF =2147483647;
struct Pos2{
    int x, y;
    Pos2(int _x = 0, int _y = 0):x(_x), y(_y){}
    const bool operator < (const Pos2 &t) const {
        if(x != t.x) return x < t.x;
        return y < t.y;
    }
    const bool operator > (const Pos2 &t) const {
        if(x != t.x) return x > t.x;
        return y > t.y;
    }
    const bool operator ==(const Pos2 &t) const {
        return x == t.x && y == t.y;
    }
};
struct Pos3{
    int x, y, z;
    Pos3(int _x = 0, int _y = 0, int _z = 0):
        x(_x), y(_y), z(_z){}
    const bool operator < (const Pos3 &t) const {
        if(x != t.x) return x < t.x;
        if(y != t.y) return y < t.y;
        return z < t.z;
    }
    const bool operator > (const Pos3 &t) const {
        if(x != t.x) return x > t.x;
        if(y != t.y) return y > t.y;
        return z > t.z;
    }
    const bool operator ==(const Pos3 &t) const {
        return x == t.x && y == t.y && z == t.z;
    }
};
const int BASE = 13331;
struct Hash{
    unsigned operator ()(const Pos2 t) const{
        return t.x * BASE + t.y;
    }
    unsigned operator ()(const Pos3 t) const{
        return (t.x * BASE + t.y) * BASE + t.z;
    }
};
unordered_map<Pos3, bool, Hash> B;   // 存 (x, y, z) 是否有方块
unordered_map<Pos2, bool, Hash> V;   // 存 (x, y, h + 1) 有没有使用过
unordered_map<Pos2, int , Hash> D;   // 存 (x, y) 的最短路程
unordered_map<Pos2, bool, Hash> W;   // 存 (x, y, h + 1) 位置有没有水方块
unordered_map<Pos2, int , Hash> K;   // 存 (x, y, h + 1) 位置水方块的强度
const int DIR[4][2] = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
const int MAXN = 2e5 + 3;
int n, p, X[MAXN], Y[MAXN], Z[MAXN], I[MAXN];
int qread(){
    int w = 1, c, ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
bool cmp(int a, int b){ return Z[a] > Z[b]; }
int _x0, _y0;
int main(){
    n = qread(), p = qread(), _x0 = qread(), _y0 = qread();
    W[Pos2(_x0, _y0)] = true;
    up(1, n, i){
        X[i] = qread(), Y[i] = qread(), Z[i] = qread(), I[i] = i;
        B[Pos3(X[i], Y[i], Z[i])] = true;
    }
    sort(I + 1, I + 1 + n, cmp);
    up(1, n, i){
        int h = Z[I[i]], j;
        queue <Pos2> P, Q;
        for(j = i;j <= n && Z[I[j]] == h;++ j){
            int o = I[j], x = X[o], y = Y[o];
            Pos2 u(x, y);
            if(W.count(u))
                P.push(u), K[u] = p, W.erase(u);
            up(0, 3, k){
                int nx = x + DIR[k][0];
                int ny = y + DIR[k][1];
                Pos2 v(nx, ny);
                if(!V.count(v) && !B.count(Pos3(nx, ny, h))
                    && !B.count(Pos3(nx, ny, h + 1)))
                    V[v] = true, D[v] = 0, Q.push(v);
            }
        }
        while(!Q.empty()){
            Pos2 u = Q.front(); Q.pop(); int x = u.x, y = u.y;
            up(0, 3, k){
                int nx = x + DIR[k][0];
                int ny = y + DIR[k][1];
                Pos2 v(nx, ny);
                if(!D.count(v) && B.count(Pos3(nx, ny, h))
                    && !B.count(Pos3(nx, ny, h + 1)))
                    D[v] = D[u] + 1, Q.push(v);
            }
        }
        while(!P.empty()){
            Pos2 u = P.front(); P.pop(); int x = u.x, y = u.y;
            int d = D[u], s = K[u];
            if(!B.count(Pos3{x, y, h})){
                W[u] = true; continue;
            }
            if(s == 1) continue;
            up(0, 3, k){
                int nx = x + DIR[k][0];
                int ny = y + DIR[k][1];
                Pos2 v(nx, ny);
                if( D[v] == d - 1)
                if(!K.count(v) && !B.count(Pos3(nx, ny, h + 1)))
                    K[v] = s - 1, P.push(v);
            }
        }
        i = j - 1, D.clear(), K.clear(), V.clear();
    }
    printf("%u\n", W.size());
    return 0;
}

参考代码 2

import java.io.*;
import java.util.*;

public class Main {
    public static class Vec2d {
        public int x, y;

        public Vec2d(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(new int[] {x, y});
        }

        public boolean equals(Vec2d vec2d) {
            return this.x == vec2d.x && this.y == vec2d.y;
        }
        @Override
        public boolean equals(Object vec2d) {
            if (!(vec2d instanceof Vec2d))
                return false;
            return this.x == ((Vec2d) vec2d).x && this.y == ((Vec2d) vec2d).y;
        }
    }

    public static class Vec3d {
        public int x, y, z;

        public Vec3d(int x, int y, int z) {
            this.x = x;
            this.y = y;
            this.z = z;
        }
        @Override
        public int hashCode() {
            return Arrays.hashCode(new int[] {x, y, z});
        }
        public boolean equals(Vec3d vec2d) {
            return this.x == vec2d.x && this.y == vec2d.y && this.z == vec2d.z;
        }
        @Override
        public boolean equals(Object vec2d) {
            if (!(vec2d instanceof Vec3d))
                return false;
            return this.x == ((Vec3d) vec2d).x && this.y == ((Vec3d) vec2d).y && this.z == ((Vec3d) vec2d).z;
        }
    }

    public static class Scanner {
        public BufferedReader in;
        public StringTokenizer tok;

        public String next() {
            hasNext();
            return tok.nextToken();
        }

        public String nextLine() {
            try {
                return in.readLine();
            } catch (Exception e) {
                return null;
            }
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public PrintWriter out = new PrintWriter(new OutputStreamWriter(System.out));

        public boolean hasNext() {
            while (tok == null || !tok.hasMoreTokens()) try {
                tok = new StringTokenizer(in.readLine());
            } catch (Exception e) {
                return false;
            }
            return true;
        }

        public Scanner(InputStream inputStream) {
            in = new BufferedReader(new InputStreamReader(inputStream));
        }
    }

    public static Map<Vec3d, Boolean> isblock = new HashMap<>();
    public static Map<Vec2d, Boolean> isused = new HashMap<>();
    public static Map<Vec2d, Integer> dist = new HashMap<>();
    public static Map<Vec2d, Boolean> iswater = new HashMap<>();
    public static Map<Vec2d, Integer> strwater = new HashMap<>();

    public static final int[] dx = {1, -1, 0, 0}, dy = {0, 0, 1, -1};

    public static int n, k, _x0, _y0;
    public static int[] x = new int[100050], y = new int[100050], z = new int[100050];
    public static List<Integer> var_id = new ArrayList<>();

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        n = scanner.nextInt();
        k = scanner.nextInt();
        _x0 = scanner.nextInt();
        _y0 = scanner.nextInt();
        iswater.put(new Vec2d(_x0, _y0), true);
        for (int i = 1; i <= n; i++) {
            x[i] = scanner.nextInt();
            y[i] = scanner.nextInt();
            z[i] = scanner.nextInt();
            isblock.put(new Vec3d(x[i], y[i], z[i]), true);
            var_id.add(i);
        }
        var_id.sort((x, y) -> z[y] - z[x]);
        List<Integer> id = new ArrayList<>();
        id.add(0);
        for (int i = 0; i < n; ++i)
            id.add(var_id.get(i));
        for (int i = 0; i < 5; ++i)
            id.add(0);
        for (int i = 1; i <= n; i++) {
            int height = z[id.get(i)];
            Queue<Vec2d> p = new LinkedList<>(), q = new LinkedList<>();
            // spread at the same height
            for (int nid = id.get(i); i <= n && z[nid] == height; ) {
                int nx = x[nid], ny = y[nid];
                Vec2d u = new Vec2d(nx, ny);
                if (iswater.getOrDefault(u, false)) {
                    iswater.put(u, false);
                    p.add(u);
                    strwater.put(u, k);
                }
                for (int j = 0; j < 4; j++) {
                    int nx1 = nx + dx[j], ny1 = ny + dy[j];
                    Vec2d v = new Vec2d(nx1, ny1);
                    Vec3d v1 = new Vec3d(nx1, ny1, height);
                    Vec3d v2 = new Vec3d(nx1, ny1, height + 1);
                    if (!isused.getOrDefault(v, false) && !isblock.getOrDefault(v1, false) && !isblock.getOrDefault(v2, false)) {
                        isused.put(v, true);
                        dist.put(v, 0);
                        q.add(v);
                    }
                }
                i++;
                nid = id.get(i);
            }
            i--;
            // spread water in Q
            while (!q.isEmpty()) {
                Vec2d var1 = q.element();
                q.remove();
                int x = var1.x, y = var1.y;
                Vec2d u = new Vec2d(x, y);
                for (int j = 0; j < 4; j++) {
                    int nx = x + dx[j], ny = y + dy[j];
                    Vec2d v = new Vec2d(nx, ny);
                    Vec3d v1 = new Vec3d(nx, ny, height);
                    Vec3d v2 = new Vec3d(nx, ny, height + 1);
                    if (dist.getOrDefault(v, 0) == 0 && isblock.getOrDefault(v1, false) && !isblock.getOrDefault(v2, false)) {
                        dist.put(v, dist.get(u) + 1);
                        q.add(v);
                    }
                }
            }
            //spread water in P
            while (!p.isEmpty()) {
                Vec2d var1 = p.element();
                p.remove();
                int x = var1.x, y = var1.y;
                Vec2d u = new Vec2d(x, y);
                Vec3d u1 = new Vec3d(x, y, height);
                int d = dist.getOrDefault(u, 0), s = strwater.getOrDefault(u, 0);
                if (!isblock.getOrDefault(u1, false)) {
                    iswater.put(u, true);
                    continue;
                }
                if (s == 1)
                    continue;
                for (int j = 0; j < 4; j++) {
                    int nx = x + dx[j], ny = y + dy[j];
                    Vec2d v = new Vec2d(nx, ny);
                    Vec3d v1 = new Vec3d(nx, ny, height + 1);
                    if (dist.getOrDefault(v, 0) == d - 1 && strwater.getOrDefault(v, 0) == 0 && !isblock.getOrDefault(v1, false)) {
                        strwater.put(v, s - 1);
                        p.add(v);
                    }
                }
            }
            isused.clear();
            dist.clear();
            strwater.clear();
        }
        int cnt = 0;
        for (boolean i : iswater.values()) {
            cnt += i ? 1 : 0;
        }
        System.out.println(cnt);
    }
}

I

题解

压轴题。

容易发现这是一棵树。具体可以用归纳法。设前 i 行组成一棵树,那么第 i+1 行内每个点都向第 i 行某个点连了边,那肯定前 i+1 行也是一棵树了。

容易发现这树很特殊。它的点数达到了 n^2 级别,但是第 i 列整个就是串在一起的形成链的样子。可以证明,整棵树可以被划分为 \mathcal O(n) 个点和 \mathcal O(n) 条链。

定义:我们选出树上一些点作为关键点。这些点包括第 1 层的所有点、第 n 层的所有点、其他所有度数3 的点。

图中标上红星的即为关键点。

断言:关键点的个数是 \mathcal O(n) 级别的。

证明:考虑前 i 行,i\ge 2。容易发现,第 i 行所有点的度数应该恰好为 1。当加入第 i+1 行的所有点以及边时,第 i+1 行前 i 个点连的边使得第 i 行所有点度数都变成了 2,而第 i+1 行第 i+1 个点连的边使得第 i+1 行恰好有一个点度数变成了 3,这个点就成了关键点。那么,第 2\sim (n-1) 行应该均恰好有一个特殊点。再加上第 1 行与第 n 行的点,特殊点的总个数就是 \mathcal O(n) 级别的了。

除了关键点,其他所有点的度数均为 2(读者自证不难)。于是别的点肯定会在某条链上。沿着链走,走到的两端肯定都是关键点。于是,我们可以把一条链抽象成连接两个关键点的「边」,这个「虚树」节点个数为 \mathcal O(n),那么它的边的个数肯定也是 \mathcal O(n)

把虚树建好后,可以用非常经典的树上最大点权独立集的动态规划做法在虚树上跑。我们关心的是两个关键点之间应该怎么转移。换言之,我们需要知道连接两个关键点的链它的贡献怎么计算。

f_u,g_u 分别表示,在选择/不选择虚树上的节点 u 的情况下,子树 u 能取得的最大点权独立集的值。

假设有两个关键点 u,v,他们之间通过链 s=\{p_1,p_2,\cdots p_t\} 连接。那么 vf_u 的贡献是 \max\{\operatorname{val}(p_2,p_3,\cdots,p_t)+g_v,\operatorname{val}(p_2,p_3,\cdots,p_{t-1})+f_v\}vg_u 的贡献是 \max\{\operatorname{val}(p_1,p_2,\cdots,p_t)+g_v,\operatorname{val}(p_1,p_2,\cdots,p_{t-1})+f_v\}。其中 \operatorname{val}(p_1,p_2,\cdots,p_t) 表示序列 p 在不能选择相邻元素的情况下可以取得的最大点权独立集。我们需要知道,s 序列在「首项/末项」「能选/不能选」共 4 种组合的情况下,分别计算出来的最大点权独立集是多少。

注意一个重要性质:

这同样容易证明。由于该树特殊的构造方法,对于 j<i,非特殊点 (i,j) 肯定与 (i+1,j)(i-1,j) 相连,那么这三个点肯定在同一列;对于 (i,i),它向上连接的点肯定是关键点,向下连接的点肯定与它在同一列。于是容易发现每条链上的非特殊点必然在同一列。

那么一条链可以被映射到 01 序列 r 上的一段区间 [a,b]。假设这个链在第 k 列。如果 c_k=0,那么这个区间内每个位置的权值就是 r_a,r_{a+1},\cdots,r_b;如果 c_k=1,那么这个区间内每个位置的权值就是 \neg r_a,\neg r_{a+1},\cdots,\neg r_b

问题转化为了,查询对 r 序列或者 \neg r 序列(这两个都是 01 序列)做区间最大点权独立集的结果。

对于 01 序列,计算最大点权独立集是可以采用贪心思想的。比如,对于长度为 2k 的全 1 段,它的最大点权独立集显然是 k;对于长度为 2k+1 的全 1 段,它的最大点权独立集显然是 k+1。对于含有 0 的序列,总是可以以 0 作为分隔符划分出各种全 1 段,分别求和再相加就行。

那么怎么对 01 序列做区间最大点权独立集呢?

我们预处理两个东西:

容易预处理上面的两个数组。现在要计算 [l,r] 区间的最大点权独立集,那么结果就是:

h_r-h_{\min\{r,p_l\}}+\lfloor(\min\{r,p_l\}-l)/2+1\rfloor

解释:使用做差的方法计算出 [\min\{r,p_l\}+1,r] 这一段的最大点权独立集的值,再加上 [l,\min\{r,p_l\}] 这个全 1 段的贡献。

最后讲讲怎么建虚树。维护 F_i 表示第 i 列从最下面往上走走到的关键点的编号。当 (i,i) 位置往 (i-1,a_i) 连边时,(i-1,a_i) 变成了第 a_i 列最靠下的关键点,作为虚树的一员它首先要和 F_{a_i} 连上边,然后更新 F_{a_i} 的值。因为最后一行全部都会是关键点,所以最后再将它们变成关键点与 F_i 连边即可。

于是这题就做完了。

参考代码

#include<bits/stdc++.h>
#define up(l, r, i) for(int i = l, END##i = r;i <= END##i;++ i)
#define dn(r, l, i) for(int i = r, END##i = l;i >= END##i;-- i)
using namespace std;
typedef long long i64;
const i64 INF = 1e18;
int n;
int qread(){
    int w=1,c,ret;
    while((c = getchar()) >  '9' || c <  '0') w = (c == '-' ? -1 : 1); ret = c - '0';
    while((c = getchar()) >= '0' && c <= '9') ret = ret * 10 + c - '0';
    return ret * w;
}
const int MAXN = 1e6 + 3;
const int MAXM = 2e6 + 3;
int R[MAXN], C[MAXN], A[MAXN], F[MAXN], G[MAXM], W[MAXM], o = 0;
int X[MAXM]; i64 U[MAXM][4];
int Y[MAXM]; i64 V[MAXM][4];
int P0[MAXN], Q0[MAXN], P1[MAXN], Q1[MAXN];
int value(int w){return w % 2 == 1 ? w / 2 + 1 : w / 2;}
void calc(int l, int r, i64 &w, bool t){    // [l, r] 区间
    if(r - l + 1 ==  0){w = 0   ; return;}
    if(r - l + 1 == -1){w = 0   ; return;}
    if(r - l + 1 == -2){w = -INF; return;}
    if(t == false){
        int u = min(Q0[l], r); w = P0[r] - P0[u] + ( R[l] == 1 ? value(u - l + 1) : 0);
    } else {
        int u = min(Q1[l], r); w = P1[r] - P1[u] + (!R[l] == 1 ? value(u - l + 1) : 0);
    }
}
void calc(int l, int r, i64 O[4], bool t){  // [l + (0~1), r - (0~1)] 区间
    calc(l    , r    , O[0b11], t);
    calc(l    , r - 1, O[0b10], t);
    calc(l + 1, r    , O[0b01], t);
    calc(l + 1, r - 1, O[0b00], t);
}
i64 I[MAXM], J[MAXM];   // I 是必须选上,J 是必须不选
void dfs(int u){
    if(X[u] == 0) I[u] = W[u], J[u] = 0; else {
        int l = X[u], r = Y[u]; dfs(l), dfs(r);
        I[u] = W[u]
            + max(U[u][0b00] + I[l], U[u][0b01] + J[l])
            + max(V[u][0b00] + I[r], V[u][0b01] + J[r]);
        J[u] =
            + max(U[u][0b10] + I[l], U[u][0b11] + J[l])
            + max(V[u][0b10] + I[r], V[u][0b11] + J[r]);
    }
}
int main(){ 
    n = qread();
    up(1, n, i) R[i] = qread();
    up(1, n, i) C[i] = qread();
    up(2, n, i) A[i] = qread();
    P0[1] = R[1], P1[1] = !R[1];
    up(2, n, i){
        P0[i] = max(P0[i - 1], P0[i - 2] +  R[i]);
        P1[i] = max(P1[i - 1], P1[i - 2] + !R[i]);
    }
    Q0[n] = Q1[n] = n;
    dn(n - 1, 1, i){
        if( R[i] == 0) Q0[i] = i; else
            if( R[i + 1] == 0) Q0[i] = i; else Q0[i] = Q0[i + 1];
        if(!R[i] == 0) Q1[i] = i; else
            if(!R[i + 1] == 0) Q1[i] = i; else Q1[i] = Q1[i + 1];
    }
    up(1, n - 1, i){
        int t = A[i + 1], f = F[t]; // (i, t) 是特殊点
        F[t] = F[i + 1] = ++ o;     // 给特殊点分配编号
        if(f)
            if(X[f]) Y[f] = o, calc(G[f] + 1, i - 1, V[f], C[t]);
            else     X[f] = o, calc(G[f] + 1, i - 1, U[f], C[t]);
        W[o] = R[i] ^ C[t], G[o] = i;
    }
    up(1, n, i){    // 最后一层都是特殊点
        int t = i, f = F[t]; ++ o, W[o] = R[n] ^ C[i];
        if(f)
            if(X[f]) Y[f] = o, calc(G[f] + 1, n - 1, V[f], C[t]);
            else     X[f] = o, calc(G[f] + 1, n - 1, U[f], C[t]);
    }
    dfs(1); printf("%lld\n", max(I[1], J[1]));
    return 0;
}