P11362 [NOIP2024] 遗失的赋值 题解

· · 题解

Sol

提供一种题解区没有的做法。

本年所有比赛中 DP 最光彩的一集。

下文中所有的下标 i,j 均表示第 i,j 个赋值限制。

考虑 DP,设 f_i 表示满足从 1i 的合法方案数。

显然答案就是 f_m\times k^{2\times (n - c_m)} 次。

考虑如何转移,考虑容斥,所有的方案数为 k^{2\times c_i} 种。

思考得到不合法的状态当且仅当有相邻的赋值限制,使得两个赋值限制不能同时满足。

接下来考虑怎样不重不漏的删去不合法的方案数,对于任意一种方案,我们考虑用最后一个 j 来表示这种方案,使得从 j 开始,使用二元限制后 j + 1 的赋值限制一定不成立。

那么我们可以暴力枚举 j,然后构造一种方案使其强制不能满足第 j+1 个赋值限制,并且 j+1\sim i 之间的二元限制应当合法。

不难推出方案数为 f_j\times k^{c_{j+1} - c_j - 1}\times (k-1)\times k^{2\times (c_i-c_{j+1})}

那么转移方程就是 f_i = \displaystyle \sum_{j=1}^{i-1}f_j\times k^{c_{j+1} - c_j - 1}\times (k-1)\times k^{2\times (c_i-c_{j+1})}=\sum_{j=1}^{i-1}(f_j\times k^{c_{j+1}-c_j-1-2\times c_{j+1}})\times (k-1)\times k^{2\times c_i}

朴素转移是 O(n^2) 的。

如果你仔细观察这个式子,那么你可以发现,对于一个固定的 i,第一个括号内的所有内容都是与 j 相关的,其余的都是常量,那么就可以前缀和优化 DP 来实现。

Code

赛时代码。

#include <bits/stdc++.h>
#define x first
#define y second
#define pb push_back
#define pf push_front
#define IOS ios :: sync_with_stdio (false),cin.tie (0),cout.tie (0)
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair <int,int> PII;
template <typename T1,typename T2> void tomax (T1 &x,T2 y) {
    if (y > x) x = y;
}
template <typename T1,typename T2> void tomin (T1 &x,T2 y) {
    if (y < x) x = y;
}
int fastio = (IOS,0);
#define endl '\n'
#define puts(s) cout << s << endl
const int N = 100010,MOD = 1e9 + 7;
LL n,k;
int m;
PII a[N];
LL f[N];
LL power (LL a,LL b,LL p) {
    LL ans = 1;
    while (b) {
        if (b & 1) ans = ans * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return ans;
}
int main () {
    freopen ("assign.in","r",stdin);
    freopen ("assign.out","w",stdout);
    int T;
    cin >> T;
    while (T--) {
        cin >> n >> m >> k;
        for (int i = 1;i <= m;i++) cin >> a[i].x >> a[i].y;
        sort (a + 1,a + m + 1);
        bool flag = 0;
        for (int i = 1;i <= m - 1;i++) {
            if (a[i].x == a[i + 1].x && a[i].y != a[i + 1].y) {
                flag = 1;
                break;
            }
        }
        if (flag) {
            puts ("0");
            continue;
        }
        m = unique (a + 1,a + m + 1) - a - 1;
        // for (int i = 1;i <= m;i++) {
        //     f[i] = power (k * k % MOD,a[i].x - 1,MOD);
        //     for (int j = 1;j < i;j++) {
        //         LL w = power (k,a[j + 1].x - a[j].x - 1,MOD) * (k - 1) % MOD;
        //         f[i] = (f[i] - f[j] * w % MOD * power (k * k,a[i].x - a[j + 1].x,MOD) % MOD + MOD) % MOD;
        //     }
        // }
        LL sum = 0;
        for (int i = 1;i <= m;i++) {
            f[i] = power (k * k % MOD,a[i].x - 1,MOD);
            f[i] = (f[i] - sum * power (k * k % MOD,a[i].x,MOD) % MOD + MOD) % MOD;
            if (i + 1 <= m) {
                sum = (sum + f[i] * (power (k,a[i + 1].x - a[i].x - 1,MOD) * (k - 1) % MOD) % MOD * power (power (k * k % MOD,a[i + 1].x,MOD),MOD - 2,MOD) % MOD) % MOD;
            }
        }
        cout << f[m] * power (k * k % MOD,n - a[m].x,MOD) % MOD << endl;
    }
    return 0;
}