P1933 [NOI2010]旅行路线

· · 题解

先吐槽一下

我做这一题的时候,还只做过一道插头dp的模板题。然而我仍然看出了这是一道插头dp,但之后就不会做了。于是我上网搜题解。结果

一篇题解都搜不到

没有办法,只能自己动手,丰衣足食

于是我用一个下午复习插头dp,做了两道典型例题,然后第二天上午用半天时间A掉了这道目前通过数6的题目(还有两个是输出答案的)。然后激动之余专门开通了洛谷博客写本题题解

事实上,找到性质之后这题也是比较套路的那种插头dp了

下面开始说正话

状态记录

题意可以转化为在方格图上填 1...n\times m 这些数,使其连成一条路径,并且1必须填在边缘的方案数,考虑如何记录状态才能转移

显然,需要知道轮廓线上的n+1个插头的状态,还要知道轮廓线上填的数是什么(才知道现在该填什么),然后这样仍然无法转移,还需要知道到目前为止哪些数已经填过了。总共需要记录这三个东西。

然而由于多了后面那个150位的二进制位来表示每一个数填没填,这些状态无法压到一个int或long long里,而且显然这个状态有点多,跟暴力没什么区别了。

所以需要一个重要的性质

如果两个状态轮廓线上的插头完全相同,轮廓线上填的数也完全相同,那么他们已经填过的数的集合也是相同的

也就是说只要前两个东西就可以区分状态了

这个性质大概可以理解为,轮廓线已经完全确定了,又要保证是一条路径,那么两边该填哪些数其实也确定了。(反正我没有举出来反例,就认为他是对的了)

那么就可以记录状态了。

定义3种插头,0表示无插头(不向后延伸),1表示向递减的数延伸,2表示向递增的数延伸

我们需要n个150以内的数,每个数可以用8个二进制位记录

我们需要n+1个0/1/2的数,每个数可以用2个二进制位记录

正好可以压在一个32位整形里

然后就可以每次解码编码,存在哈希表里了,用比较套路的方法

每个状态还要开一个150位的bitset,表示每个数是否填过,为了后面可以转移,已经说过,每个状态只会有一个bitset

转移

一般插头dp转移都要大力讨论,此题也不例外 但本题的讨论过于繁琐,很容易重复或遗留因此可以换一种方法

先大概考虑一下当前可以填哪些数,记录在一个数组里,去重

然后一一考虑,判断是否合法。判断合法这个过程相对简单,大力continue就可以了,想到一个条件就continue一下

考虑完填那个数,之后考虑新的插头长什么样。同样还是枚举所有可能的插头,然后大力continue

这部分我为了避免没有考虑的情况,写的比较冗杂,应该有很多条件都是可以精简的。但如果不考虑代码的精致程度的话,只要无脑判断所有非法情况就可以了

如果当前插头也合法,那么转移即可

因此实际上还是很套路的。。。

下面是代码,可读性应该还是很高的

#include <bits/stdc++.h>
using namespace std;
const int maxn = 5, maxm = 55, maxl = 155, mod = 11192869, mo = 500000;
typedef long long ll;
typedef unsigned int uint;
int a[maxn][maxm], L[maxl];
int n, m, pos[maxn], plug[maxn], head[2][500010], tot[2], cur, pre, chc[maxl], ans;
struct State{
    bitset<maxl> used;
    uint stt;
    int val, nxt;
    State() { val = nxt = stt = 0; used.reset(); }
}ptr[2][1001000];
void hah(uint stt, int val, bitset<maxl> &used) {
    int x = stt % mo;
    for(int i = head[cur][x]; i; i = ptr[cur][i].nxt) if(ptr[cur][i].stt == stt) {
        ptr[cur][i].val = (ptr[cur][i].val + val) % mod; return;
    }
    ptr[cur][++tot[cur]].stt = stt;
    ptr[cur][tot[cur]].val = val;
    ptr[cur][tot[cur]].used = used;
    ptr[cur][tot[cur]].nxt = head[cur][x];
    head[cur][x] = tot[cur];
}
uint encode() {
    uint stt = 0;
    for(int i = 1; i <= n; i++) stt = (stt << 8) + pos[i];
    for(int i = 0; i <= n; i++) stt = (stt << 2) + plug[i];
    return stt;
}
void decode(uint stt) {
    for(int i = n; i >= 0; i--) plug[i] = stt & 3, stt >>= 2;
    for(int i = n; i; i--) pos[i] = stt & 255, stt >>= 8; 
}
void solve() {
    bitset<maxl> used;
    used.reset();
    cur = 0; pre = 1; hah(0, 1, used);
    for(int j = 1; j <= m; j++) {
        // 新的一行要把plug整体右移
        for(int t = 1; t <= tot[cur]; t++) {
            decode(ptr[cur][t].stt);
            for(int i = n - 1; i >= 0; i--) plug[i + 1] = plug[i]; 
            plug[0] = 0;
            ptr[cur][t].stt = encode();
        }
        for(int i = 1; i <= n; i++) {
            swap(cur, pre); tot[cur] = 0;
            memset(head[cur], 0, sizeof(head[cur]));
            for(int t = 1; t <= tot[pre]; t++) {
                uint stt = ptr[pre][t].stt; 
                int val = ptr[pre][t].val;
                used = ptr[pre][t].used;
                decode(stt);
                int r = plug[i - 1], d = plug[i];
                int cnt = 0;
                if(!r && !d) for(int i = 1; i <= n * m; i++) chc[++cnt] = i;
                else {
                    if(r == 1) chc[++cnt] = pos[i-1] - 1;
                    else if(r == 2) chc[++cnt] = pos[i-1] + 1;
                    if(d == 1) chc[++cnt] = pos[i] - 1;
                    else if(d == 2) chc[++cnt] = pos[i] + 1;
                }
                // 当前位置可能会填哪些数
                sort(chc + 1, chc + 1 + cnt);
                cnt = unique(chc + 1, chc + 1 + cnt) - chc - 1;
                for(int hh = 1; hh <= cnt; hh++) {
                    int x = chc[hh]; // 枚举当前位置填的数,判断是否合法
                    if(a[i][j] != L[x]) continue; if(used[x]) continue;
                    if(r == 1 && x != pos[i - 1] - 1) continue;
                    if(r == 2 && x != pos[i - 1] + 1) continue;
                    if(d == 1 && x != pos[i] - 1) continue;
                    if(d == 2 && x != pos[i] + 1) continue;
                    if(x == 1 && i > 1 && i < n && j > 1 && j < m) continue;
                    if(i == n && j == m) ans = (ans + val) % mod;
                    used[x] = 1; int od = pos[i]; pos[i] = x;
                    // cout << x << endl;
                    for(int npr = 0; npr <= 2; npr++) 
                        for(int npd = 0; npd <= 2; npd++) {
                            // 枚举新的插头,判断是否合法,这部分我写的比较冗杂,或许可以精简一下
                            int pnum = (r > 0) + (d > 0) + (npr > 0) + (npd > 0);
                            if(x != 1 && x != n * m && pnum != 2) continue;
                            if((x == 1 || x == n * m) && pnum != 1) continue;
                            if(npr == npd && npr) continue;
                            if(j == m && npr) continue; if(i == n && npd) continue;
                            if((npr == 1 || npd == 1) && used[x - 1]) continue;
                            if((npr == 2 || npd == 2) && used[x + 1]) continue;
                            if(npr == 1 && a[i][j+1] != L[x - 1]) continue;
                            if(npr == 2 && a[i][j+1] != L[x + 1]) continue;
                            if(npd == 1 && a[i+1][j] != L[x - 1]) continue;
                            if(npd == 2 && a[i+1][j] != L[x + 1]) continue;
                            // 当前转移合法,更新下一位置的状态和dp值
                            plug[i - 1] = npr; plug[i] = npd;
                            hah(encode(), val, used);
                            plug[i - 1] = r; plug[i] = d;
                        }
                    used[x] = 0; pos[i] = od;
                }
            }
        }
    }
}
int main() {
    // printf("%lf\n", (double)(&b2-&b1)/1024/1024);
    // freopen("trip.in", "r", stdin);
    // freopen("trip.out", "w", stdout);
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) for(int j = 1; j <= m; j++) scanf("%d", &a[i][j]);
    for(int i = 1; i <= n * m; i++) scanf("%d", &L[i]);
    L[0] = L[n * m + 1] = 521;
    for(int i = 0; i <= m + 1; i++) a[0][i] = a[n + 1][i] = 233;
    for(int i = 1; i <= n; i++) a[i][0] = a[i][m + 1] = 233;
    solve();
    printf("%d\n", ans);
    return 0;
}