AT_diverta2019_e XOR Partitioning 题解

· · 题解

以下讨论,下标从 0 开始。

先求前缀异或和数组 s

例如 a=[1,3,2,1,3] 的前缀异或和数组 s=[0,1,2,0,1,2]

假设划分位置是 i_0,i_1,i_2,\cdots,i_m

划分出来的每段子数组的异或和相同,等价于

s[0]\oplus s[i_0] = s[i_0] \oplus s[i_1] = s[i_1] \oplus s[i_2] = \cdots = s[i_m]\oplus s[n]

这意味着

\begin{aligned} &s[0] = s[i_1] = s[i_3] = \cdots \\ &s[i_0] = s[i_2] = s[i_4] = \cdots \end{aligned}

相当于从 s 中选择一个交替子序列(注意第一个和最后一个数必须选)。例如 s=[0,1,2,0,1,2] 中可以选 [0,2,0,2],对应划分 [1,3],[2],[1,3]

第一种情况

如果 s[n]> 0,这意味着交替子序列只能是 0,s[n],0,s[n],\cdots

那么定义:

如果 s[i] = 0,则从 i 左边的值为 s[n] 的位置转移过来,即

f[i][0] = f[j_1][1] + f[j_2][1] + \cdots

其中 s[j_1] 满足 j_1 < is[j_1] = s[n],其余 j_2 等同理。

如果 s[i] = s[n],则从 i 左边的值为 0 的位置转移过来,即

f[i][1] = f[j_1][0] + f[j_2][0] + \cdots

其中 s[j_1] 满足 j_1 < is[j_1] = 0,其余 j_2 等同理。

初始值 f[0][0] = 1

答案为 f[n][1]

这样写每次转移是 \mathcal{O}(n) 的。我们可以把 f[j_1][0] + f[j_2][0] + \cdots 记作 s_0,把 f[j_1][1] + f[j_2][1] + \cdots 记作 s_1。通过维护这两个变量的值,就可以做到 \mathcal{O}(1) 转移了。

第二种情况

本题难就难在 s[n]=0 的情况。

首先,我们可以选全为 0 的子序列,这有 2^{\textit{cnt}-2} 种方案,其中 \textit{cnt}s0 的出现次数,-2 是因为 s 的第一个和最后一个数都是 0 且必须选。

然后来讨论交替子序列的个数。

例如 s=[0,1,2,0,1,2,0],此时我们不但可以选 [0,2,0,2,0],还可以选 [0,1,0,1,0]。如果像第一种情况那样 DP,对于 1 我们需要算一遍 DP,对于 2 也需要算一遍 DP。可以预见,在 0 比较多的情况下,总共需要 \mathcal{O}(n^2) 的时间。

那要怎么做?

遇到 0 的时候,「延迟」计算 DP:只在遇到非 0 的时候,才去计算 DP。

例如两个 2 之间有三个 0,那么当遍历到第二个 2 的时候,才去计算关于 s[i]=0 的状态转移。这三个 0 的转移来源是完全一样的,可以一起计算。

那么,怎么知道两个 2 之间有多少个 0 呢?

方法很多,比如在遍历的同时维护 0 的个数 \textit{cnt},在遍历到第一个 2 的时候记录一下 \textit{cnt},在遍历到第二个 2 的时候用当前的 \textit{cnt} 减去上一次记录的 \textit{cnt},就可以知道两个 2 之间有多少个 0 了。

请看代码:

package main
import("bufio";."fmt";"os")

const mod = 1_000_000_007

func main() {
    in := bufio.NewReader(os.Stdin)
    var n, v, xor int
    // f[xor] 相当于只看前缀异或和中的 0 和 xor,求 DP
    f := [1 << 20]struct{ s0, s1, pre0 int }{}
    for i := range f {
        f[i].s0 = 1
    }
    cnt0 := 1 // 前缀异或和的第一个数是 0
    for Fscan(in, &n); n > 0; n-- {
        Fscan(in, &v)
        xor ^= v
        if xor == 0 {
            cnt0++
        } else {
            t := &f[xor]
            // f[i][0] = 一堆 f[j][1] 的和 = s1,这里直接把 f[i][0] 加到 s0 中
            t.s0 = (t.s0 + t.s1*(cnt0-t.pre0)) % mod
            // f[i][1] = 一堆 f[j][0] 的和 = s0,这里直接把 f[i][1] 加到 s1 中
            t.s1 = (t.s1 + t.s0) % mod
            t.pre0 = cnt0
        }
    }
    if xor > 0 {
        // 答案 = f[n][1] = 一堆 f[j][0] 的和 = s0
        // 注意不能写 f[xor].s1,因为前缀异或和的末尾如果有多个 xor,我们只能选一个
        Print(f[xor].s0)
    } else {
        ans := pow(2, cnt0-2) // 只选 0 的方案数
        for _, t := range f {
            // 答案 = f[n][0] = 一堆 f[j][1] 的和 = s1
            // 注意不能写 t.s0,因为前缀异或和的末尾如果有多个 0,我们只能选一个
            ans += t.s1
        }
        Print(ans % mod)
    }
}

func pow(x, n int) (res int) {
    res = 1
    for ; n > 0; n /= 2 {
        if n%2 > 0 {
            res = res * x % mod
        }
        x = x * x % mod
    }
    return
}