P1933 [NOI2010]旅行路线
ren482933891 · · 题解
先吐槽一下
我做这一题的时候,还只做过一道插头dp的模板题。然而我仍然看出了这是一道插头dp,但之后就不会做了。于是我上网搜题解。结果
一篇题解都搜不到
没有办法,只能自己动手,丰衣足食
于是我用一个下午复习插头dp,做了两道典型例题,然后第二天上午用半天时间A掉了这道目前通过数6的题目(还有两个是输出答案的)。然后激动之余专门开通了洛谷博客写本题题解
事实上,找到性质之后这题也是比较套路的那种插头dp了
下面开始说正话
状态记录
题意可以转化为在方格图上填
显然,需要知道轮廓线上的n+1个插头的状态,还要知道轮廓线上填的数是什么(才知道现在该填什么),然后这样仍然无法转移,还需要知道到目前为止哪些数已经填过了。总共需要记录这三个东西。
然而由于多了后面那个150位的二进制位来表示每一个数填没填,这些状态无法压到一个int或long long里,而且显然这个状态有点多,跟暴力没什么区别了。
所以需要一个重要的性质
如果两个状态轮廓线上的插头完全相同,轮廓线上填的数也完全相同,那么他们已经填过的数的集合也是相同的
也就是说只要前两个东西就可以区分状态了
这个性质大概可以理解为,轮廓线已经完全确定了,又要保证是一条路径,那么两边该填哪些数其实也确定了。(反正我没有举出来反例,就认为他是对的了)
那么就可以记录状态了。
定义3种插头,0表示无插头(不向后延伸),1表示向递减的数延伸,2表示向递增的数延伸
我们需要n个150以内的数,每个数可以用8个二进制位记录
我们需要n+1个0/1/2的数,每个数可以用2个二进制位记录
正好可以压在一个32位整形里
然后就可以每次解码编码,存在哈希表里了,用比较套路的方法
每个状态还要开一个150位的bitset,表示每个数是否填过,为了后面可以转移,已经说过,每个状态只会有一个bitset
转移
一般插头dp转移都要大力讨论,此题也不例外 但本题的讨论过于繁琐,很容易重复或遗留因此可以换一种方法
先大概考虑一下当前可以填哪些数,记录在一个数组里,去重
然后一一考虑,判断是否合法。判断合法这个过程相对简单,大力continue就可以了,想到一个条件就continue一下
考虑完填那个数,之后考虑新的插头长什么样。同样还是枚举所有可能的插头,然后大力continue
这部分我为了避免没有考虑的情况,写的比较冗杂,应该有很多条件都是可以精简的。但如果不考虑代码的精致程度的话,只要无脑判断所有非法情况就可以了
如果当前插头也合法,那么转移即可
因此实际上还是很套路的。。。
下面是代码,可读性应该还是很高的
#include <bits/stdc++.h>
using namespace std;
const int maxn = 5, maxm = 55, maxl = 155, mod = 11192869, mo = 500000;
typedef long long ll;
typedef unsigned int uint;
int a[maxn][maxm], L[maxl];
int n, m, pos[maxn], plug[maxn], head[2][500010], tot[2], cur, pre, chc[maxl], ans;
struct State{
bitset<maxl> used;
uint stt;
int val, nxt;
State() { val = nxt = stt = 0; used.reset(); }
}ptr[2][1001000];
void hah(uint stt, int val, bitset<maxl> &used) {
int x = stt % mo;
for(int i = head[cur][x]; i; i = ptr[cur][i].nxt) if(ptr[cur][i].stt == stt) {
ptr[cur][i].val = (ptr[cur][i].val + val) % mod; return;
}
ptr[cur][++tot[cur]].stt = stt;
ptr[cur][tot[cur]].val = val;
ptr[cur][tot[cur]].used = used;
ptr[cur][tot[cur]].nxt = head[cur][x];
head[cur][x] = tot[cur];
}
uint encode() {
uint stt = 0;
for(int i = 1; i <= n; i++) stt = (stt << 8) + pos[i];
for(int i = 0; i <= n; i++) stt = (stt << 2) + plug[i];
return stt;
}
void decode(uint stt) {
for(int i = n; i >= 0; i--) plug[i] = stt & 3, stt >>= 2;
for(int i = n; i; i--) pos[i] = stt & 255, stt >>= 8;
}
void solve() {
bitset<maxl> used;
used.reset();
cur = 0; pre = 1; hah(0, 1, used);
for(int j = 1; j <= m; j++) {
// 新的一行要把plug整体右移
for(int t = 1; t <= tot[cur]; t++) {
decode(ptr[cur][t].stt);
for(int i = n - 1; i >= 0; i--) plug[i + 1] = plug[i];
plug[0] = 0;
ptr[cur][t].stt = encode();
}
for(int i = 1; i <= n; i++) {
swap(cur, pre); tot[cur] = 0;
memset(head[cur], 0, sizeof(head[cur]));
for(int t = 1; t <= tot[pre]; t++) {
uint stt = ptr[pre][t].stt;
int val = ptr[pre][t].val;
used = ptr[pre][t].used;
decode(stt);
int r = plug[i - 1], d = plug[i];
int cnt = 0;
if(!r && !d) for(int i = 1; i <= n * m; i++) chc[++cnt] = i;
else {
if(r == 1) chc[++cnt] = pos[i-1] - 1;
else if(r == 2) chc[++cnt] = pos[i-1] + 1;
if(d == 1) chc[++cnt] = pos[i] - 1;
else if(d == 2) chc[++cnt] = pos[i] + 1;
}
// 当前位置可能会填哪些数
sort(chc + 1, chc + 1 + cnt);
cnt = unique(chc + 1, chc + 1 + cnt) - chc - 1;
for(int hh = 1; hh <= cnt; hh++) {
int x = chc[hh]; // 枚举当前位置填的数,判断是否合法
if(a[i][j] != L[x]) continue; if(used[x]) continue;
if(r == 1 && x != pos[i - 1] - 1) continue;
if(r == 2 && x != pos[i - 1] + 1) continue;
if(d == 1 && x != pos[i] - 1) continue;
if(d == 2 && x != pos[i] + 1) continue;
if(x == 1 && i > 1 && i < n && j > 1 && j < m) continue;
if(i == n && j == m) ans = (ans + val) % mod;
used[x] = 1; int od = pos[i]; pos[i] = x;
// cout << x << endl;
for(int npr = 0; npr <= 2; npr++)
for(int npd = 0; npd <= 2; npd++) {
// 枚举新的插头,判断是否合法,这部分我写的比较冗杂,或许可以精简一下
int pnum = (r > 0) + (d > 0) + (npr > 0) + (npd > 0);
if(x != 1 && x != n * m && pnum != 2) continue;
if((x == 1 || x == n * m) && pnum != 1) continue;
if(npr == npd && npr) continue;
if(j == m && npr) continue; if(i == n && npd) continue;
if((npr == 1 || npd == 1) && used[x - 1]) continue;
if((npr == 2 || npd == 2) && used[x + 1]) continue;
if(npr == 1 && a[i][j+1] != L[x - 1]) continue;
if(npr == 2 && a[i][j+1] != L[x + 1]) continue;
if(npd == 1 && a[i+1][j] != L[x - 1]) continue;
if(npd == 2 && a[i+1][j] != L[x + 1]) continue;
// 当前转移合法,更新下一位置的状态和dp值
plug[i - 1] = npr; plug[i] = npd;
hah(encode(), val, used);
plug[i - 1] = r; plug[i] = d;
}
used[x] = 0; pos[i] = od;
}
}
}
}
}
int main() {
// printf("%lf\n", (double)(&b2-&b1)/1024/1024);
// freopen("trip.in", "r", stdin);
// freopen("trip.out", "w", stdout);
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) for(int j = 1; j <= m; j++) scanf("%d", &a[i][j]);
for(int i = 1; i <= n * m; i++) scanf("%d", &L[i]);
L[0] = L[n * m + 1] = 521;
for(int i = 0; i <= m + 1; i++) a[0][i] = a[n + 1][i] = 233;
for(int i = 1; i <= n; i++) a[i][0] = a[i][m + 1] = 233;
solve();
printf("%d\n", ans);
return 0;
}