题解:P12246 电 van

· · 题解

题意:

给你一个长度为 n 的只包括 \texttt{v}\texttt{a}\texttt{n} 三个字符的字符串,有 m 次操作,每次操作交换 s_xs_{x+1},求每次操作后字符串中 \texttt{van} 作为子序列的出现次数。

思路:

暴力解法直接排除,思考怎么快速计算。

我们可以寻找每一个字符 \texttt{a},统计其前面的 \texttt{v} 以及后面 \texttt{n} 的数量,相乘便是含有这个 \texttt{a} 的子序列的数量,累加便是最后的答案。

再考虑修改,每次修改只会交换前后的字符,重要的是,每次修改只会影响一个 \texttt{a} 前后 \texttt{v}\texttt{n} 的数量。 具体需要分类讨论:

s_x 为字符 \texttt{a} 时,讨论 s_{x+1}

s_{x+1} 为字符 \texttt{a} 时,讨论 s_{x}

都不是则没有影响,因为只有当 \texttt{a} 是被修改的两个字符中其中一个才会受到影响,具体证明很简单,不再赘述。

每次修改减去当前 \texttt{a} 的贡献,再更新前后 \texttt{v}\texttt{n} 的数量,最后加上更新后 \texttt{a} 的贡献就是最后的答案。

代码:

#include<bits/stdc++.h>
#define int long long
using namespace std;
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0)
const int N=1e6+5;
int n,m;
string s;
int va[N],an[N];
vector<int>a;
int numa;
signed main(){
    IOS;
    cin>>n>>m;
    cin>>s;
    s=" "+s;
    int numv=0,numn=0;
    for(int i=1;i<=n;i++){
        if(s[i]=='v')numv++;
        if(s[i]=='a'){
            va[i]=numv;
            a.push_back(i);
            numa++;
        }
    }
    for(int i=n;i>=1;i--){
        if(s[i]=='n')numn++;
        if(s[i]=='a'){
            an[i]=numn;
        }
    }
    int ans=0;
    for(int i=0;i<numa;i++){
        ans=ans+va[a[i]]*an[a[i]];
    }
    while(m--){
        int x;
        cin>>x;
        int l=0,r=numa-1;
        int mid;
        while(l<r){
            mid=l+r>>1;
            a[mid]>=x?r=mid:l=mid+1;
        }
        int aid=a[r];
        if(aid==x){
            if(s[x+1]=='a'){
                cout<<ans<<endl;
            }
            if(s[x+1]=='v'){
                ans=ans-va[x]*an[x];
                swap(va[x],va[x+1]);
                swap(an[x],an[x+1]);
                va[x+1]++;
                ans=ans+va[x+1]*an[x+1];
                a[r]++;
                cout<<ans<<endl;
            }
            if(s[x+1]=='n'){
                ans=ans-va[x]*an[x];
                swap(va[x],va[x+1]);
                swap(an[x],an[x+1]);
                an[x+1]--;
                ans=ans+va[x+1]*an[x+1];
                a[r]++;
                cout<<ans<<endl;
            }
        }
        else if(aid==x+1){
            if(s[x]=='v'){
                ans=ans-va[x+1]*an[x+1];
                swap(va[x],va[x+1]);
                swap(an[x],an[x+1]);
                va[x]--;
                ans=ans+va[x]*an[x];
                a[r]--;
                cout<<ans<<endl;
            }
            if(s[x]=='n'){
                ans=ans-va[x+1]*an[x+1];
                swap(va[x],va[x+1]);
                swap(an[x],an[x+1]);
                an[x]++;
                a[r]--;
                ans=ans+va[x]*an[x];
                cout<<ans<<endl;
            }
        }
        else{
            cout<<ans<<endl;
        }
        swap(s[x],s[x+1]);
        /*cout<<s<<endl;
        for(int i=1;i<=numa;i++){
            cout<<a[i-1]<<" ";
            cout<<va[a[i-1]]<<" "<<an[a[i-1]]<<endl;
        }*/
    }
    return 0;
}

很明显可以优化,但作者懒得优化(疑似最劣解)。