题解:P11362 [NOIP2024] 遗失的赋值(民间数据)

· · 题解

柯爱柿子题

题目分析

看到 n \le 10^9m \le 10 ^5,很容易想到分段考虑。

先将数组 c 排序,把 1c_1 分为一段,c_1c_2 分为一段,以此类推到 c_mn

可以发现,若一段区间 [l,r] 满足 l = r 那么这个区间贡献为 1

接下来就是计算剩余每一段的贡献。

假设这段区间为 [l,r]。若 k 使得存在 c_k=l,那么会发现 l 上的数有 v 种选择,而其中 v-1 种选择剩下的 [l,r) 全部不确定,而又因为 r-1 不确定,所以 b_{r-1} 也有 v 种选择,所以总共 v^{2\times(r-l+1)-1} \times (v-1)。 否则 l+1 确定,且又有 v 种选择,于是将区间 l+1,r 继续做一遍上面的计数,将答案多乘一个 v

重复上面的步骤直到区间 [r-1,r]。我们发现,若存在 k 使得 c_k=r,因为 a_{r-1} 等于 r-1 上的数,所以贡献为 1,否则贡献为 v

那么我们就发现每一段区间 [l,r] 的答案就是这样:

v^{2\times(r-l)-1} \times (v-1) + v^{2\times(r-l)-2} \times (v-1) + \cdots + v^{2\times(r-l)-r+l} \times (v-1) + f(r)

其中 f(r) 为若存在 k 使得 c_k=r,答案为 v^{2\times(r-l)-r + l - 1},否则贡献为 v^{2\times(r-l)-r+l}

然后赛时就死在了这里。

接着我们发现可以将 f(r) 前面的那坨东西整理下,就变成了这样:

(v-1)\times(v^{2\times(r-l)-1} + v^{2\times(r-l)-2} + \cdots + v^{2\times(r-l)-r+l} )+ f(r)

我们发现一个公式:

$x\times S(x,n) = x^1+x^2+x^3+\cdots+x^{n+1}$。 $x\times S(x,n)-S(x,n) = x^{n+1}-1 S(x,n) = \frac{x^{n+1}-1}{x-1}

于是答案便成为了:

(S(v,2\times(r-l)-1)-S(v,2\times(r-l)-r+l-1))\times(v-1) + f(r)

最后将所有区间的贡献乘起来即可,时间复杂度 O(Tmlog_2n),注意要把 [1,c_1] 这个区间要特殊判断。

代码

#include <bits/stdc++.h>
namespace FAST_IO {
const int LEN = 1 << 18;
char BUF[LEN], PUF[LEN], space = ' ', line = '\n';
int Pin = LEN, Pout;
inline void flushin() {
    memcpy(BUF, BUF + Pin, LEN - Pin), fread(BUF + LEN - Pin, 1, Pin, stdin),
        Pin = 0;
    return;
}
inline void flushout() {
    fwrite(PUF, 1, Pout, stdout), Pout = 0;
    return;
}
inline char Getc() {
    return (Pin == LEN ? (fread(BUF, 1, LEN, stdin), Pin = 0) : 0), BUF[Pin++];
}
inline char Get() { return BUF[Pin++]; }
inline void Putc(char x) {
    if (Pout == LEN)
        flushout(), Pout = 0;
    PUF[Pout++] = x;
}
inline void Put(char x) { PUF[Pout++] = x; }
template <typename tp = int> inline void read(int &X) {
    (Pin + 32 >= LEN) ? flushin() : void();
    tp res = 0;
    char f = 1, ch = ' ';
    for (; ch < '0' || ch > '9'; ch = Get())
        if (ch == '-')
            f = -1;
    for (; ch >= '0' && ch <= '9'; ch = Get())
        res = (res << 3) + (res << 1) + ch - 48;
    X = res * f;
    return;
}
template <typename tp = char> inline void read(char &X) {
    X = Getc();
    return;
}
template <typename tp> inline void wt(tp a) {
    if (a > 9)
        wt(a / 10);
    Put(a % 10 + '0');
    return;
}
template <typename tp = int> inline void write(int a) {
    static int stk[20], top;
    (Pout + 32 >= LEN) ? flushout() : void();
    if (a < 0)
        Put('-'), a = -a;
    else if (a == 0)
        Put('0');
    for (top = 0; a; a /= 10)
        stk[++top] = a % 10;
    for (; top; --top)
        Put(stk[top] ^ 48);
    return;
}
template <typename tp = char> inline void write(char a) {
    Put(a);
    return;
}
template <typename T, typename... Args> void read(T &tmp, Args &...tmps) {
    read(tmp), read(tmps...);
}
template <typename T, typename... Args> void write(T &tmp, Args &...tmps) {
    write(tmp), write(tmps...);
}
} // namespace FAST_IO
using namespace FAST_IO;
const int mod = 1e9 + 7;
int t, n, m, v, ans, lst, dth;
struct st {
    int c, d;
} a[100002];
int qpow(int x, int y, int mod) {
    if (y < 1)
        return 1;
    int tmp = qpow(x, y / 2, mod);
    if (y & 1)
        return 1ll * x * tmp % mod * tmp % mod;
    return 1ll * tmp * tmp % mod;
}
inline int inv(int x) { return qpow(x, mod - 2, mod); }
inline int S(int x, int n) {
    return 1ll * (qpow(x, n + 1, mod) - 1) % mod * dth % mod;
}
void solve(int l, int r) {
    int val = S(v, 2 * (r - l) - 1) - S(v, 2 * (r - l) - r + l - 1);
    val = (val + mod) % mod;
    val = 1ll * val * (v - 1) % mod;
    val = (val + qpow(v, 2 * (r - l) - r + l, mod)) % mod;
    ans = 1ll * ans * val % mod;
}
void sol(int l, int r) {
    int val = S(v, 2 * (r - l) - 1) - S(v, 2 * (r - l) - r + l - 1);
    val = (val + mod) % mod;
    val = 1ll * val * (v - 1) % mod;
    val = (val + qpow(v, 2 * (r - l) - r + l - 1, mod)) % mod;
    ans = 1ll * ans * val % mod;
}
signed main() {
    read(t);
    for (; t--;) {
        read(n, m, v), ans = lst = 1, dth = inv(v - 1);
        for (int i = 1; i <= m; i++)
            read(a[i].c, a[i].d);
        std::sort(a + 1, a + 1 + m, [](st x, st y) { return x.c < y.c; });
        for (int i = 1; i <= m; i++)
            if (a[i].c == a[i - 1].c && a[i].d != a[i - 1].d)
                ans = 0;
        if (!ans) {
            puts("0");
            continue;
        }
        for (int i = 1; i <= m; i++) {
            if (a[i].c == a[i - 1].c)
                continue;
            if (lst == 1 && i == 1 && a[i].c != 1)
                solve(lst, a[i].c);
            else if (a[i].c != 1)
                sol(lst, a[i].c);
            lst = a[i].c;
        }
        if (a[m].c != n)
            solve(a[m].c, n);
        printf("%d\n", ans);
    }
    return 0;
}
/*
    1 0 0 0 1
(v-1)*v*v*v*v*v*v*v
(v-1)*v*v*v*v*v*v*v
(v-1)*v*v*v*v*v*v*v
(v-1)*v*v*v*v*v*v*v
*/

注意需要略微卡常。