题解:P9848 [ICPC2021 Nanjing R] Cloud Retainer's Game

· · 题解

简要题意

平面直角坐标系上有两个无限长的直线镜子,一个是 y=0,一个是 y=H

m 个硬币,第 i 个硬币位于 (x_i^{'},y_i^{'})。有 n 个可调节盒子,第 i 个盒子位于 (x_i,y_i)。你可以将一个可调节盒子调节为镜子模式或默认模式。

有一束光从 (0,0) 处以 \frac{1}{4}\pi 弧度射出,光会穿过硬币和默认模式的盒子,碰到镜子或镜子模式的盒子后,会发生反射(遵循反射定律)。

你需要最大化光线碰到的硬币数量。

T$ 组数据。$1\leq n,m\leq 10^5,1\leq \sum n,\sum m \leq 5 \times 10^5,2\leq H\leq 10^9,1\leq x_i,x^{'}_i\leq 10^9,1\leq y_i,y^{'}_i\leq H

代码

理清题意后其实是一个简单题。

首先有一个平凡的 dp,即 f(i,j,k) 表示点 (i,j) 开始(包含这个点),光线方向为 k\frac{1}{4}\pi\frac{7}{4}\pi)。转移就是如果这个点是硬币,那么找到后继节点然后加上 1,如果这个点是盒子,枚举这个盒子是哪一种盒子分别找后继即可。

这个转移没有问题,关键在于如何找后继。

我们可以让 y 轴上的某一个点 (0,i)\frac{1}{4}\pi\frac{7}{4}\pi 弧度发射的光线,为标准光线(考虑镜子,不考虑镜子模式的盒子),则发现无论如何反射,得到的光线都是由若干段标准光线得到的。

同样我们可以求出经过一个点以一定弧度的光线的标准光线是什么。所以我们可以对标准光线开一个 map,里面存 vector,按照 x 升序存储这条光线可以经过的点的坐标,找后继的时候在上面二分即可。

时间复杂度单次 O(n\log n) 可以通过本题。

代码

由于一开始代码没有构思好,写得比较丑,而且常数不小,反正能过就行了。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N = 1e5 + 5;
int H, n, m;

struct position {
    int x, y;
    bool operator<(const position& rhs) const { return x == rhs.x ? y < rhs.y : x < rhs.x; }
} p[N][2];

struct line {
    int k, b;
    bool operator<(const line& rhs) const { return k == rhs.k ? b < rhs.b : k < rhs.k; }
    bool operator==(const line& rhs) const { return k == rhs.k && b == rhs.b; }
};

map<line,vector<pair<int,int> > > mp;
int f[N][2][2];

auto cmp = [](const pair<int,int>& lhs, const pair<int,int>& rhs){ return p[lhs.first][lhs.second] < p[rhs.first][rhs.second]; };

line line1(position pos){
    if(pos.x < pos.y) return line{1, pos.y - pos.x};
    int x = (pos.x - pos.y) % (H << 1);
    if(x == 0) return line{1, 0};
    if(x <= H) return line{0, x};
    else return line{1, ((H << 1) - x)};
}

line line2(position pos){
    int x = (pos.x + pos.y) % (H << 1);
    if(x == 0) return line{1, 0};
    if(x <= H) return line{0, x};
    else return line{1, ((H << 1) - x)};
}

int dp(int i, int j, int k){// p[i][j], direction: k: line1(1) / line2(0)
    if(f[i][j][k] != -1) return f[i][j][k];
    if(j){// coin
        line l;
        if(k) l = line1(p[i][j]);
        else l = line2(p[i][j]);
        auto ite = upper_bound(mp[l].begin(), mp[l].end(), make_pair(i, j), cmp);
        if(ite == mp[l].end()) return f[i][j][k] = 1;
        if(line1(p[ite->first][ite->second]) == l) return f[i][j][k] = dp(ite->first, ite->second, 1) + 1;
        else return f[i][j][k] = dp(ite->first, ite->second, 0) + 1;
    }
    else{// board
        line l1, l2;
        if(k) l1 = line1(p[i][j]), l2 = line2(p[i][j]);
        else l1 = line2(p[i][j]), l2 = line1(p[i][j]);
        auto ite1 = upper_bound(mp[l1].begin(), mp[l1].end(), make_pair(i, j), cmp);
        auto ite2 = upper_bound(mp[l2].begin(), mp[l2].end(), make_pair(i, j), cmp);
        f[i][j][k] = 0;
        if(ite1 != mp[l1].end()){
            if(line1(p[ite1->first][ite1->second]) == l1) f[i][j][k] = max(f[i][j][k], dp(ite1->first, ite1->second, 1));
            else f[i][j][k] = max(f[i][j][k], dp(ite1->first, ite1->second, 0));
        }
        if(ite2 != mp[l2].end()){
            if(line1(p[ite2->first][ite2->second]) == l2) f[i][j][k] = max(f[i][j][k], dp(ite2->first, ite2->second, 1));
            else f[i][j][k] = max(f[i][j][k], dp(ite2->first, ite2->second, 0));
        }
        return f[i][j][k];
    }
}

void solve(){
    cin >> H >> n;
    p[0][1] = {0, 0}; mp[{1, 0}].push_back(make_pair(0, 1));
    for(int i=1;i<=n;i++){
        cin >> p[i][0].x >> p[i][0].y;
        mp[line1(p[i][0])].emplace_back(make_pair(i, 0));
        mp[line2(p[i][0])].emplace_back(make_pair(i, 0));
    }
    cin >> m;
    for(int i=1;i<=m;i++){
        cin >> p[i][1].x >> p[i][1].y;
        mp[line1(p[i][1])].emplace_back(make_pair(i, 1));
        mp[line2(p[i][1])].emplace_back(make_pair(i, 1));
    }
    for(auto& i : mp) sort(i.second.begin(), i.second.end(), cmp);
    for(int i=0;i<=max(n, m);i++){
        for(int j=0;j<=1;j++) f[i][j][0] = f[i][j][1] = -1;
    }
    cout << (dp(0, 1, 1) - 1) << '\n';
    mp.clear();
}

signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    int t; cin >> t;
    while(t--) solve();
    return 0;
}

// Written by xiezheyuan