题解:AT_abc352_g [ABC352G] Socks 3
首先不难发现:
其中:
这里,
分母很好理解,就是总共的方案数,然后
所以这个式子我们要想计算,就需要计算出原数组这
其中
但这样时间复杂度就炸了。观察发现我们 DP 的过程,其实我们就是在每次对原本的这个式子乘上
这个式子的各项系数,直接使用分治 NTT 即可解决,时间复杂度
代码:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int mod = 998244353;
const int G = 3;
int fastpow(int a, int b, int mod)
{
int res = 1;
while (b) {
if (b & 1) {
res = (res * a) % mod;
}
a = (a * a) % mod;
b >>= 1;
}
return res;
}
int rev[1 << 22];
void change(vector<int> &a, int len)
{
for (int i = 0; i < len; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
}
for (int i = 0; i < len; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
}
void ntt(vector<int> &a, int len, int x)
{
change(a, len);
for (int h = 2; h <= len; h <<= 1) {
int omega = fastpow(G, (mod - 1) / h, mod);
if (x == -1) {
omega = fastpow(omega, mod - 2, mod);
}
for (int i = 0; i < len; i += h) {
int w = 1;
for (int j = i; j < i + h / 2; j++) {
int u = a[j];
int v = a[j + h / 2] * w % mod;
a[j] = (u + v) % mod;
a[j + h / 2] = (u - v + mod) % mod;
w = (w * omega) % mod;
}
}
}
if (x == -1) {
int inv = fastpow(len, mod - 2, mod);
for (int i = 0; i < len; i++) {
a[i] = (a[i] * inv) % mod;
}
}
}
vector<int> convo(vector<int> a, vector<int> b)
{
if (a.empty() || b.empty()) {
return {0};
}
int m = 1;
while (m < a.size() + b.size() - 1) {
m <<= 1;
}
a.resize(m);
b.resize(m);
ntt(a, m, 1);
ntt(b, m, 1);
for (int i = 0; i < m; i++) {
a[i] = (a[i] * b[i]) % mod;
}
ntt(a, m, -1);
return a;
}
vector<int> solve(vector<int>& a, int l, int r)
{
if (l > r) {
return {1};
}
if (l == r) {
return {1, a[l]};
}
int mid = (l + r) / 2;
return convo(solve(a, l, mid), solve(a, mid + 1, r));
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
vector<int> a(n);
int cnt = 0;
for (int i = 0; i < n; i++) {
cin >> a[i];
cnt += a[i];
}
vector<int> res = solve(a, 0, n - 1);
int ans = 1, s1 = 1, s2 = 1;
for (int i = 1; i <= n && i < res.size(); i++) {
s1 = (s1 * i) % mod;
s2 = s2 * (cnt - i + 1) % mod;
ans = (ans + (res[i] * s1) % mod * fastpow(s2, mod - 2, mod)) % mod;
}
cout << (ans + mod) % mod << '\n';
return 0;
}
AC记录