[YsOI2020]归零

· · 题解

这波是标程被踩,已经没脸见人了。万人血书 @skydogli 写一份题解。

标签:dp

题意简述

给定数字 S,每次选择一位四舍五入,问数字变成 0 的操作方案数。答案取模。

{\rm{subtask\ 0}}\ (n\le6)

按题意模拟爆搜 / dp。

{\rm{subtask\ 1}}\ (n \le 15)

容易发现 1\sim3 本质相同,简并为 15\sim8 本质相同,简并为 5。所以只有可能出现 5 种数字:0,1,4,5,9。(奇妙的是它们可以组成 114514191910,这一定不是巧合!)

同时,一个数位不可能 5 种情况都存在,最多会有其中的 4 种。意思是说,对于每一个位置,至少有一种数字不可能出现。

预处理出每个数位可能被哪些数字占据。然后状压 dp 即可。

复杂度 O(4^nn),感觉大概跑不满(?)。如果不幸被卡那我谢罪。

{\rm{subtask\ 2}}\ (S_i\in[0,4])

不会产生进位,只需要找到有几个位置不是 0,求个阶乘。

{\rm{subtask\ 3}}\ (S_i=9)

发现【四舍五入】一个 ⑨ 之后,会使前面连续的 9 变成 0,再在最前面加一个 1。这个 1 是与世隔绝的,它不会使别人进位,别人也不能让他进位。

大概很容易想到 dp。设 dp_{i,j} 表示从高到低考虑到第 i 位,且前 i-1 位已经没有 9 了,且剩余 j 个与世隔绝的 1 的方案数。

考虑用记忆化搜索实现。转移每次找一个后面的 9 【四舍五入】或者消除一个与世隔绝的 1 即可。

复杂度 O(n^3),前缀和优化可以做到 O(n^2)

参考代码:

// O(n^2)
#include <bits/stdc++.h>

const int MX = 56;
const int MOD = 998244353;

int dp[MX][MX][2];
char str[233];
int dapai(int x ,int y ,int t){
    if(~dp[x][y][t]) return dp[x][y][t];
    int ret = (x == 0 && y == 0);
    if(y && t) ret = (ret + 1LL * y * dapai(x ,y - 1 ,1)) % MOD;
    if(x){
        if(x != 1) ret = (ret + dapai(x - 1 ,y ,0)) % MOD;
        ret = (ret + dapai(x - 1 ,y + 1 ,1)) % MOD;
    }
    return dp[x][y][t] = ret;
}

int main(){
    memset(dp ,-1 ,sizeof dp);
    init();
    scanf("%s" ,str);
    int n = strlen(str);
    printf("%d\n" ,dapai(n ,0 ,0));
    return 0;
} 

{\rm{subtask\ 4}}\ (S_i\in[5,8])

出题人并不会,他出这个部分分只是觉得好像有什么做法。

验题人给出了一种做法:

有一种复杂度与划分数方案相关的暴力,设 dp[st][i] 为当前划分区间状态为 st,有 i 个与世隔绝的 1 的方案数。

每次枚举【四舍五入】哪个或者选择消去一个 1 即可转移,总复杂度 O(\text{划分数}(n)\times n^2)

本地跑出结果后打表即可。

{\rm{subtask\ 5}}\ (n \le 40)

出题人出这个部分分只为放过复杂度正确但常数大的。

{\rm{subtask\ 6}}\ (n \le 64)

三句话题解:

注意到一个事实:【四舍五入】一个位置之后可以把左右两部分分成子问题。

dp_{l,r,c,1/0} 表示除了 [l,r] 位其它数位都变成了 0,还需【四舍五入】 c 次这个区间也全都变成 0,【四舍五入】第 l-1 位的时候是否对第 l 位造成进位,的方案数。进位后的数字可以直接预处理。

转移只需要枚举当前【四舍五入】哪一位 mid,以及区间 [l,mid-1],[mid+1,r] 如何分配剩余的 c-1 次【四舍五入】即可。 注意有的时候【四舍五入】进位会导致右端点 +1,直接特判掉即可。

复杂度 O(n^5) = 1,073,741,824,剪剪枝/卡卡常即可通过。

剪枝方法:预处理每个区间最多和最少【四舍五入】次数。不满足的区间则不递归。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>

const int MX = 64 + 3;
const int MOD = 998244353;

int dp[MX][MX][MX * 2][2] ,n;
int C[MX * 2][MX * 2];
char str[MX][MX];
int maxmv[MX][MX][MX] ,minmv[MX][MX][MX];

void init(){
    for(int i = 0 ; i < MX * 2 ; ++i) C[i][0] = 1;
    for(int i = 1 ; i < MX * 2 ; ++i)
        for(int j = 1 ; j < MX * 2 ; ++j)
            C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % MOD;
}

int calc(int n ,int m){
    return C[n + m][m];
}

int dapai(int l ,int r ,int c ,int jw){
    if(c < 0) return 0;
    if(l > r) return !c;
    int ind = (!jw) * n + jw * l;
    if(~dp[l][r][c][jw]) return dp[l][r][c][jw];

    if(c == 0) return !maxmv[ind][l][r];

    int cur = 0;
    int book[MX] = {0};
    int ok = 1;
    for(int k = r ; k >= l && ok ; --k){
        book[k] = ok && str[ind][k] >= 5;
        ok = ok && str[ind][k] == 9;
    }
    for(int k = l ; k <= r ; ++k){
        if(str[ind][k] == 0) continue;
        if(book[k] || (k == r && str[ind][k] >= 5)){
            cur = (cur + 1LL * dapai(l ,k - 1 ,c - 2 ,jw) * calc(c - 2 ,1)) % MOD;
            continue;
        }
        int tind = ((str[ind][k] < 5) * n + (str[ind][k] >= 5) * (k + 1));
        for(int s = minmv[ind][l][k - 1] ; s <= c - 1 ; ++s){
            if(s > maxmv[ind][l][k - 1]
            || (c - 1 - s) > maxmv[tind][k + 1][r]
            || (c - 1 - s) < minmv[tind][k + 1][r]) continue;
            cur = (cur + 1LL
                * dapai(l ,k - 1 ,s ,jw) 
                * dapai(k + 1 ,r ,c - s - 1 ,str[ind][k] >= 5) % MOD
                * calc(s ,c - s - 1)) % MOD;
        }
    }
    return dp[l][r][c][jw] = cur;
}

int main(){
    memset(dp ,-1 ,sizeof dp);
    init();
    std::cin >> str[0];
    n = strlen(str[0]);
    std::reverse(str[0] ,str[0] + n);
    str[0][n] = '0' ,n = n + 1;
    for(int i = 0 ; i < n ; ++i) str[0][i] -= '0';
    memcpy(str[n] ,str[0] ,sizeof str[0]);

    for(int i = 0 ; i < n ; ++i){ // Ä£Ä⽫´Ëλ½øÎ»
        memcpy(str[i] ,str[n] ,sizeof str[i]);
        str[i][i]++;
        for(int j = i ; j < n ; ++j){
            str[i][j + 1] += str[i][j] / 10;
            str[i][j] %= 10;
        }
    }

    for(int i = 0 ; i <= n ; ++i){
        for(int l = 0 ; l < n ; ++l){
            for(int r = l ; r < n ; ++r){
                for(int s = r ; s >= l ; --s){
                    maxmv[i][l][r] += (str[i][s] != 0);
                    maxmv[i][l][r] += (str[i][s] >= 5);
                }
                int ret = 0;
                for(int s = l ; s <= r ; ++s){
                    ret += str[i][s];
                    if(ret % 10) minmv[i][l][r]++;
                    if(ret >= 5) ret = 1;
                    else ret = 0;
                }
                if(ret % 10) minmv[i][l][r]++;
            }
        }
    }

    int Ans = 0;
    for(int i = minmv[n][0][n - 1] ; i <= maxmv[n][0][n - 1] ; ++i){
        int ret = dapai(0 ,n - 1 ,i ,0);
        Ans = (Ans + ret) % MOD;
    }
    printf("%d\n" ,Ans);
    return 0;
}