[ARC058C]Iroha and Haiku

· · 题解

  去往原博客食用更佳。

题目

  点这里看题目。

分析

  首先,由于仅仅是 " 存在 " 这一条限制很容易导致计重,且总方案数就是 10^n 。我们就可以考虑求补集,也就是不存在的情况。
  然后有一个很显然的 DP :
  f(i,S):序列长度为 i ,此时序列的\le X + Y + Z 的极长后缀的状态为 S 的方案数。
  举一个关于极长后缀的例子:

    X=2,Y=5,Z=4
    ... 1 3 ( 2 5 3 )

  括号内的就是状态需要的极长后缀。
  那么怎么存下来这个后缀呢?

暴力

  写个搜索找一下方案数:

    void DFS( const int su )
    {
        if( su > X + Y + Z ) return ;
        ++ cnt;
        for( int i = 1 ; i <= 10 ; i ++ )
            DFS( su + i ); 
    }

  得到结果:

    input: X = 5, Y = 7, Z = 5
    output: cnt = 130624

  因而,实际上, S 的状态数不会太多。我们可以全部搜索出来之后,建立 S 和序号的映射关系。对于每个 S ,我们可以直接得到它加了一个字符之后的转移,再进行 DP 。
  顺带一提, S 也可以直接被压缩成 long long 范围内的数,方便 map 使用。
  时间复杂度最大是 O(130642 \times 10\times n) ,有点卡,不过还好。

优雅解法

  我们假想出一条长度为 X+Y+Z 的空格子列。那么一个数 x 就相当于涂掉了连续的 x 个格子。我们就只需要判断:第 X 个格子是不是某次填涂的末尾?第 X + Y 个格子是否是某次填涂的末尾?第 X+Y+Z 个格子是否是某次填涂的末尾?
  进而可以想到一个状态压缩的方法:对于一个数 x ,我们就把 [1,x) 的格子留空,把第 x 个格子涂上颜色,标记这里是一次填涂的尾巴。这样我们就可以利用位运算快速查出当前状态是否合法。
  煮几个栗子:

    input: X = 3, Y = 2, Z = 4 => 001010001
    原序列         转化为格子        转化后的序列
    1 1 1 3 2 1 => 1 1 1 001 01 1 => 111001011    
                                       ^ x   ^    mismatched
    1 2 2 3 1   => 1 01 01 001 1  => 101010011    
                                       ^ ^   ^        matched

  这样应该就可以理解了。 DP 的时间复杂度是 O(10\times n2^{X+Y+Z})
  这个方法理论复杂度要大一些,不过暴力被卡常了

代码

  因为一些你知道的原因,这里给出 Vjudge 上的 AC 记录。

暴力

   AC 记录。

#include <map>
#include <cstdio>
#include <vector>
using namespace std;

typedef long long LL;
const int mod = 1e9 + 7;
const int MAXN = 45, MAXS = 2e6 + 5; 

template<typename _T>
void read( _T &x )
{
    x = 0;char s = getchar();int f = 1;
    while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
    while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
    x *= f;
}

template<typename _T>
void write( _T x )
{
    if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
    if( 9 < x ){ write( x / 10 ); }
    putchar( x % 10 + '0' );
}

map<LL, int> state;

int f[MAXN][MAXS];
int trans[MAXS][11];
int seq[MAXN], siz;
int cnt = 0;
int N, X, Y, Z;
bool legal[MAXS];

void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }

void DFS( const int su, const LL H )
{
    if( su > X + Y + Z ) return ;
    state[H] = ++ cnt;
    for( int i = 1 ; i <= 10 ; i ++ )
        DFS( su + i, H * 10 + i ); 
}

bool chk()
{
    int s[MAXN] = {};
    for( int i = 1 ; i <= siz ; i ++ ) s[i] = s[i - 1] + seq[i];
    bool flag = false;
    for( int i = 1 ; i <= siz ; i ++ )
        if( s[i] == X )
        { flag = true; break; }
    if( ! flag ) return false;
    flag = false;
    for( int i = 1 ; i <= siz ; i ++ )
        if( s[i] == X + Y )
        { flag = true; break; }
    if( ! flag ) return false;
    return s[siz] == X + Y + Z;
}

int push( const int c )
{
    int su = 0, l;
    for( l = siz ; l ; su += seq[l --] )
        if( c + su + seq[l] > X + Y + Z )
            break;
    LL val = 0;
    for( int j = l + 1 ; j <= siz ; j ++ )
        val = val * 10 + seq[j];
    if( c <= X + Y + Z ) val = val * 10 + c;
    return state[val];
}

void init( const int su )
{
    if( su > X + Y + Z ) return ;
    legal[++ cnt] = chk();
    for( int i = 1 ; i <= 10 ; i ++ ) trans[cnt][i] = push( i );
    for( int c = 1 ; c <= 10 ; c ++ )
        seq[++ siz] = c, init( su + c ), siz --;
}

int main()
{
    read( N ), read( X ), read( Y ), read( Z );
    DFS( 0, 0 ); 
    cnt = 0; init( 0 );

    f[0][1] = 1;
    for( int i = 0 ; i < N ; i ++ )
        for( int k = 1 ; k <= cnt ; k ++ )
            if( f[i][k] && ! legal[k] )
                for( int c = 1 ; c <= 10 ; c ++ )
                    if( trans[k][c] && ! legal[trans[k][c]] )
                        add( f[i + 1][trans[k][c]], f[i][k] );
    int ans = 1;
    for( int i = 1 ; i <= N ; i ++ )
        ans = 10ll * ans % mod;
    for( int i = 1 ; i <= cnt ; i ++ )
        if( ! legal[i] )
            add( ans, mod - f[N][i] );  
    write( ans ), putchar( '\n' );
    return 0;
}

优雅的解法

   AC 记录。

#include <cstdio>

const int mod = 1e9 + 7;
const int MAXN = 45, MAXS = ( 1 << 17 ) + 5;

template<typename _T>
void read( _T &x )
{
    x = 0;char s = getchar();int f = 1;
    while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
    while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
    x *= f;
}

template<typename _T>
void write( _T x )
{
    if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
    if( 9 < x ){ write( x / 10 ); }
    putchar( x % 10 + '0' );
}

int f[MAXN][MAXS];
int N, X, Y, Z, st, all;

int bit( const int i ) { return 1 << i - 1; }
bool chk( const int S ) { return ( S & st ) == st; }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int insert( const int S, const int v ) { return ( S << v | bit( v ) ) & all; }

int main()
{
    read( N ), read( X ), read( Y ), read( Z );
    int up = 1 << X + Y + Z;
    all = ( bit( X + Y + Z ) << 1 ) - 1;
    st = bit( Z ) | bit( Y + Z ) | bit( X + Y + Z );

    f[0][0] = 1;
    for( int i = 0 ; i < N ; i ++ )
        for( int S = 0 ; S < up ; S ++ )
            if( f[i][S] && ! chk( S ) )
                for( int k = 1, T ; k <= 10 ; k ++ )
                    if( ! chk( T = insert( S, k ) ) )
                        add( f[i + 1][T], f[i][S] );
    int ans = 1;
    for( int i = 1 ; i <= N ; i ++ )
        ans = 10ll * ans % mod;
    for( int S = 0 ; S < up ; S ++ )
        if( ! chk( S ) )
            add( ans, mod - f[N][S] );
    write( ans ), putchar( '\n' );
    return 0;
}