【题解】P3706 [SDOI2017]硬币游戏

· · 题解

讲一种不用脑子的做法。

暴力怎么做?建出 ACAM,对每个节点设一个变量,为包含该节点对应字符串的概率,每个点的概率会平均对其后继产生贡献,再套个高斯消元即可。复杂度 O(n^3m^3)

如果你做过 Boring Problem 的话,你会发现相同的套路可以直接搬过来用。我们优化的核心在于减少变量的个数,即仅对每个终止状态设变量,并用这 n 个变量表示其余所有节点。

ACAM 有很好的性质:对所有节点标记上在 Trie 树(不是 fail 树!)上的深度,则所有的转移要么直接指向儿子,要么指向深度小于等于自己的节点。对应到该题,即为所有的前驱要么是 Trie 树上的父亲,要么是深度大于等于自己的节点。按照深度从大到小遍历的话,对每个节点都可以保证,除了父亲以外,所有前驱都已经被遍历。

直接转移的式子为 p=1/2\left(\sum_{x} p_x\right),单独拎出父亲写成 p_{fa}=2p-\sum_{x\neq fa}p_x,我们按照深度从大到小遍历,对于一个节点,用自己的值和其他前驱的值去表示父亲,这样就可以用 n 个变量表示 ACAM 上所有节点的值。

那方程在哪里?来源有二:

  1. 不同的点可能有共同的父亲,此时两组值必然是相同的。由于一共有 n 个叶子,所以共会有 n-1 次这样的合并;

完事以后直接消元就好了,注意消元时的精度问题,时空复杂度均为 \mathcal{O}(n^3)。LOJ 上可以过,洛谷上被卡空间了。

upd(感谢 @阿丑):对每一个 x_i 都分别跑 bfs,存系数的空间就可以降到 \mathcal{O}(nm),就可以过了。

using db = double;

const int max_n = 300, max_l = 300, max_nd = max_n * max_l + 1, cs = 2;

int tr[max_nd][cs], fail[max_nd], fa[max_nd], ind = 1;
int hd[max_nd], des[max_nd + 1], nxt[max_nd + 1], st[max_n], e_cnt = 0;
db ce[max_nd], m[max_n][max_n + 1];
bool ok[max_nd];

void add(int s, int t)
{
    des[e_cnt] = t;
    nxt[e_cnt] = hd[s];
    hd[s] = e_cnt++;
}

char s[max_l + 1];
signed main()
{
    memset(tr, -1, sizeof tr);

    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, l;

    cin >> n >> l;
    fa[0] = -1;

    for (int i = 0; i < n; i++)
    {
        cin >> s;

        int ptr = 0;
        for (int j = 0; j < l; j++)
        {
            int d = (s[j] == 'H');
            if (tr[ptr][d] == -1)
            {
                fa[ind] = ptr;
                tr[ptr][d] = ind++;
            }
            ptr = tr[ptr][d];
        }
        ok[ptr] = true;
        st[i] = ptr;
    }
    fill(hd, hd + ind, -1);

    queue<int> q;
    for (int i = 0; i < cs; i++)
        if (tr[0][i] == -1)
            tr[0][i] = 0;
        else
        {
            fail[tr[0][i]] = 0;
            q.push(tr[0][i]);
        }
    while (!q.empty())
    {
        int cur = q.front(); q.pop();
        for (int i = 0; i < cs; i++)
            if (tr[cur][i] == -1)
            {
                tr[cur][i] = tr[fail[cur]][i];
                if (!ok[cur])
                    add(tr[cur][i], cur);
            }
            else
            {
                fail[tr[cur][i]] = tr[fail[cur]][i];
                q.push(tr[cur][i]);
            }
    }

    for (int cp = 0; cp < n; cp++)
    {
        fill(ok, ok + ind, false);
        for (int i = 0; i < n; i++)
        {
            q.push(st[i]);
            ok[st[i]] = true;
            ce[st[i]] = (cp == i);
        }
        int fc = 0;
        while (!q.empty())
        {
            int cur = q.front(); q.pop();
            if (cur == 0)
                break;
            db tmp = ce[cur] * 2;
            for (int p = hd[cur], dst; p != -1; p = nxt[p])
            {
                dst = des[p];
                tmp -= ce[dst];
            }
            if (!ok[fa[cur]])
            {
                ok[fa[cur]] = true;
                ce[fa[cur]] = tmp;
                q.push(fa[cur]);
            }
            else
                m[fc++][cp] = ce[fa[cur]] - tmp;
        }
    }
    fill(m[n - 1], m[n - 1] + n + 1, 1);

    for (int i = 0; i < n; i++)
    {
        int mxp = -1;
        double mxc = fabs(m[i][i]);
        for (int j = i + 1; j < n; j++)
            if (fabs(m[j][i]) > mxc)
                mxc = fabs(m[j][i]), mxp = j;
        if (mxp != -1)
            swap(m[mxp], m[i]);
        for (int j = i + 1; j < n; j++)
        {
            db rto = m[j][i] / m[i][i];
            for (int k = i; k <= n; k++)
                m[j][k] -= rto * m[i][k];
        }
    }
    for (int i = n - 1; i >= 0; i--)
        for (int j = 0; j < i; j++)
        {
            db rto = m[j][i] / m[i][i];
            for (int k = i; k <= n; k++)
                m[j][k] -= rto * m[i][k];
        }

    cout << fixed;
    for (int i = 0; i < n; i++)
        cout << setprecision(10) << m[i][n] / m[i][i] << "\n";

    return 0;
}