题解:CF963E Circles of Waiting

· · 题解

分析

第一篇黑题题解寄。

提个醒,输入的是 R,a_1,a_2,a_3,a_4,而 p_1,p_2,p_3,p_4 要分别求出。

首先考虑列出方程。设 f_{i,j} 表示在 (i,j) 这个点移动至离原点的距离大于 R 的点的期望步数,则可以得到 f_{i,j}=f_{i-1,j}\times p_1+f_{i,j-1}\times p_2+f_{i+1,j}\times p_3+f_{i,j+1}\times p_4+1(i^2+j^2\le R^2),也即 f_{i,j}-p_1\times f_{i-1,j}-p_2\times f_{i,j-1}-p_3\times f_{i+1,j}-p_4\times f_{i,j+1},由于这个转移有后效性,所以不能用动规,考虑使用高斯消元。

在平面直角坐标系中,若某个点对 (x,y)|x|>R 或者 |y|>R,则这个点对与原点的距离一定大于 R。所以,只需在满足 |x|,|y|\le R 的数对 (x,y) 中找出满足题意的数对即可,设 tot 为满足题意数对的个数。输出极端情况可以发现,当 R=50tot7845

将每个满足题意的点对按照从左往右,从上往下的顺序依次标号,那么得到的方程的个数也是 tot,使用普通的高斯消元法,复杂度是立方级别的,显然会超时,所以考虑优化。

由于每个 f_{i,j} 只与 f_{i-1,j},f_{i,j-1},f_{i+1,j},f_{i,j+1} 有关,而 (i-1,j),(i,j-1),(i+1,j),(i,j+1) 这四个点所对应的编号与 (i,j) 的编号之差分别不会超过 2\times R,1,2\times R,1,也就是在方程组的系数矩阵中,每一行最多会有连续的 2\times R+1 个数有值。这便成了带状高斯消元。

那么如何快速的将系数矩阵化为上三角呢?可以按照从左到右的顺序将每一列的系数消去,假设目前要消去第 k 列的系数,则用第 k 行消元。因为第 k 行第 k 列所对应的数位于主对角线上,所以它一定是有值的;而第 k 行前 k 列的数都在此前被消去了,所以这样做,既能将第 k 列的系数消去,又不会影响前 k 列。

下面是代码(有些卡常):

#include <bits/stdc++.h>
using namespace std;
int mod = 1e9 + 7;
int o = 1;

int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1)
            ans = (int)(1LL * ans * a % mod);
        a = (int)(1LL * a * a % mod);
        b >>= 1;
    }
    return ans;
}
int bi[105][105];
int ii[8005], jj[8005];
int a[7846][7846];
int f[7846];

signed main() {
    int r, a1, a2, a3, a4;
    scanf("%d", &r);
    scanf("%d", &a1);
    scanf("%d", &a2);
    scanf("%d", &a3);
    scanf("%d", &a4);
    int x = a1 + a2 + a3 + a4;
    x = qpow(x, mod - 2);
    a1 = (int)(1LL * a1 * x % mod);
    a2 = (int)(1LL * a2 * x % mod);
    a3 = (int)(1LL * a3 * x % mod);
    a4 = (int)(1LL * a4 * x % mod);
    int tot = 0;
    for (int i = -r; i <= r; i++) {
        for (int j = -r; j <= r; j++) {
            if (i * i + j * j <= r * r) {
                bi[i + r][j + r] = ++tot;
                ii[tot] = i, jj[tot] = j;
            }
        }
    }
    //tot最大7845
    for (int i = 1; i <= tot; i++) {
        a[i][i] = 1;
        a[i][0] = 1;
        int ni = ii[i], nj = jj[i];
        if (bi[ni - 1 + r][nj + r])
            a[i][bi[ni - 1 + r][nj + r]] = ((-a1) % mod + mod) % mod;
        if (bi[ni + r][nj - 1 + r])
            a[i][bi[ni + r][nj - 1 + r]] = ((-a2) % mod + mod) % mod;
        if (bi[ni + 1 + r][nj + r])
            a[i][bi[ni + 1 + r][nj + r]] = ((-a3) % mod + mod) % mod;
        if (bi[ni + r][nj + 1 + r])
            a[i][bi[ni + r][nj + 1 + r]] = ((-a4) % mod + mod) % mod;
    }
    for (int i = 1; i <= tot; i++) {
        int y = min(tot, i + 2 * r);
        for (int j = i + 1; j <= y; j++) {//要满足 j-2*r<=i,也即 j<=i+2*r
            int x = (int)(1LL * a[j][i] * qpow(a[i][i], mod - 2) % mod);
            int u = max(o, j - 2 * r), v = min(tot, j + 2 * r);
            for (int k = u; k <= v; k++) {
                a[j][k] = (a[j][k] - 1LL * x * a[i][k]) % mod;
                if (a[j][k] < 0)
                    a[j][k] += mod;
            }
            a[j][0] = (a[j][0] - 1LL * x * a[i][0]) % mod;
            if (a[j][0] < mod)
                a[j][0] += mod;
        }
    }
    for (int i = tot; i >= 1; i--) {
        for (int j = i + 1; j <= tot; j++)
            a[i][0] = (int)((a[i][0] - 1LL * a[i][j] * f[j] % mod)) % mod;
        f[i] = (int)(1LL * a[i][0] * qpow(a[i][i], mod - 2) % mod);
        f[i] = (f[i] % mod + mod) % mod;
    }
    cout << f[bi[r][r]];
}