题解:Soso 的排列 2
yukimianyan · · 算法·理论
- Soso 的排列:https://www.luogu.com.cn/problem/P15394 ($k\leq 10^{12}$)
- Soso 的排列 2:https://www.luogu.com.cn/problem/U667696 ($k\leq 10^{10^5}$)
其他部分参见官方题解(V 题),这里快进到将 $k$ 转换为变进制数(阶乘进制)这一步,下文记 $m$ 为 $k$ 的十进制位数。
- 事实上最后转换出的变进制数长度为 $\ell=O(m/\log m)$ 而不是 $O(m)$。证明考虑用阶乘的斯特林公式,具体不写了。例如 $k=10^{100000}$ 对应的长度 $\ell = 25205$,$k=10^{200000}$ 对应的长度 $\ell = 47175$。这样就给原来 $O(m^2)$ 的算法除了一个 $\log m$。
- 压位高精度。由于题目输入是十进制,因此没办法,我们只能压九位,即使用 $10^9$ 进制高精度,也算是给复杂度除了一个位长 $w=9$。
应用这两个优化就能把该部分复杂度降到 $O\left(\dfrac{m^2}{w\log m}\right)$。代码如下:
#include <bits/stdc++.h>
using namespace std;
// Debug tracing: when LOCAL is defined, debug(...) forwards its arguments to
// fprintf(stderr, ...). Otherwise debug(...) expands to a no-op expression and
// `endl` is redefined as the string "\n".
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
// Shorthand for a 64-bit signed integer.
using LL = long long;
/// Integer modulo the compile-time constant `umod`, kept as a canonical
/// residue in [0, umod).
///
/// The wrap-around checks in += / -= need 2 * umod to fit in `unsigned`, and
/// division uses Fermat's little theorem, so `umod` is expected to be a prime
/// below 2^31.
template <unsigned umod>
struct modint { /*{{{*/
  static constexpr int mod = umod;
  unsigned v;  // canonical residue; left uninitialised by the default ctor

  modint() = default;

  /// Converts any integral value, including negative ones, to its residue.
  ///
  /// The remainder of a negative operand is in (-mod, 0]; it is shifted up by
  /// `mod` only when it is strictly negative, so that negative multiples of
  /// the modulus map to 0 rather than to the non-canonical value `mod`.
  template <class T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
  modint(const T& y) {
    long long r = static_cast<long long>(y % mod);
    if (r < 0) r += mod;
    v = static_cast<unsigned>(r);
  }

  modint& operator+=(const modint& rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint& operator-=(const modint& rhs) {
    // On underflow the unsigned value wraps to something >= umod; adding umod
    // wraps it back into range.
    v -= rhs.v;
    if (v >= umod) v += umod;
    return *this;
  }
  modint& operator*=(const modint& rhs) {
    v = static_cast<unsigned>(1ull * v * rhs.v % umod);
    return *this;
  }
  /// Multiplies by rhs^(mod-2); asserts that rhs is non-zero.
  modint& operator/=(const modint& rhs) {
    assert(rhs.v);
    return *this *= qpow(rhs, mod - 2);
  }

  friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }

  /// Binary exponentiation: returns a^b for a non-negative exponent b.
  template <class T>
  friend modint qpow(modint a, T b) {
    assert(b >= 0);
    modint r = 1;
    for (; b; b >>= 1, a *= a)
      if (b & 1) r *= a;
    return r;
  }

  /// Returns the stored residue as a plain int.
  friend int raw(const modint& self) { return self.v; }
  friend std::ostream& operator<<(std::ostream& os, const modint& self) {
    return os << raw(self);
  }

  explicit operator bool() const { return v != 0; }
  modint operator-() const { return modint(0) - *this; }
  bool operator==(const modint& rhs) const { return v == rhs.v; }
  bool operator!=(const modint& rhs) const { return v != rhs.v; }
}; /*}}}*/
// Modular integer type used throughout: arithmetic modulo 998244353.
using mint = modint<998244353>;
/// Converts the decimal string `s` into mixed-radix digits.
///
/// The number is repeatedly divided by 1, 2, 3, ... and each remainder is
/// appended to the result, so element j of the result is the digit whose
/// weight is j! (element 0 is always 0). Division stops as soon as the
/// quotient becomes zero; an empty string or "0" yields {0}.
///
/// The big number is held as base-10^9 limbs, most significant first, and each
/// short division overwrites the limb buffer in place with the quotient.
std::vector<int> solve(std::string s) {
  constexpr std::uint32_t kBase = 1000000000;  // nine decimal digits per limb
  constexpr std::size_t kChunk = 9;

  // Cut the string into limbs from the front. Only the first chunk may be
  // shorter than nine characters, so all later chunks stay aligned to the
  // right end of the string.
  std::vector<std::uint32_t> limbs;
  limbs.reserve(s.size() / kChunk + 1);
  std::size_t width = s.size() % kChunk;
  if (width == 0) width = kChunk;
  for (std::size_t pos = 0; pos < s.size(); pos += width, width = kChunk) {
    limbs.push_back(static_cast<std::uint32_t>(std::stoi(s.substr(pos, width))));
  }

  std::vector<int> digits;
  std::size_t live = limbs.size();  // number of limbs currently in use
  for (std::uint64_t radix = 1;; ++radix) {
    std::uint64_t rem = 0;
    std::size_t kept = 0;  // limbs of the quotient written so far
    for (std::size_t i = 0; i < live; ++i) {
      const std::uint64_t cur = rem * kBase + limbs[i];
      const std::uint32_t quot = static_cast<std::uint32_t>(cur / radix);
      rem = cur % radix;
      // Skip leading zero limbs so the quotient shrinks over time.
      if (kept > 0 || quot > 0) limbs[kept++] = quot;
    }
    digits.push_back(static_cast<int>(rem));
    if (kept == 0) return digits;  // quotient is zero: conversion finished
    live = kept;
  }
}
// Size of the global arrays and segment trees below (200010).
constexpr int N = 2e5 + 10;
/// Point-update segment tree storing subtree sums of type T.
///
/// Nodes live in a heap-style array (`ans[1]` is the root, children of `p` are
/// `2p` and `2p + 1`). Every operation takes the node `p` and the index range
/// [l, r] it covers; callers start with the root and the full range.
template <class T, int N>
struct segtree {
  T ans[N << 2];

  /// Fills the tree so that leaf x holds f(x), then recomputes all sums.
  template <class Func>
  void build(Func&& f, int p, int l, int r) {
    if (l == r) {
      ans[p] = f(l);
      return;
    }
    const int mid = (l + r) >> 1;
    const int lc = p << 1;
    const int rc = lc | 1;
    build(std::forward<Func>(f), lc, l, mid);
    build(std::forward<Func>(f), rc, mid + 1, r);
    maintain(p);
  }

  /// Recomputes node p as the sum of its two children.
  void maintain(int p) { ans[p] = ans[p << 1] + ans[p << 1 | 1]; }

  /// Overwrites leaf x with k and refreshes the sums on the path to it.
  void setValue(int x, T k, int p, int l, int r) {
    if (l == r) {
      ans[p] = k;
      return;
    }
    const int mid = (l + r) >> 1;
    if (x > mid) {
      setValue(x, k, p << 1 | 1, mid + 1, r);
    } else {
      setValue(x, k, p << 1, l, mid);
    }
    maintain(p);
  }

  /// Descends to the leaf reached by going left whenever k fits in the left
  /// child's sum, otherwise subtracting that sum and going right.
  int getKth(T k, int p, int l, int r) {
    if (l == r) return l;
    const int mid = (l + r) >> 1;
    const int lc = p << 1;
    if (k <= ans[lc]) return getKth(k, lc, l, mid);
    return getKth(k - ans[lc], lc | 1, mid + 1, r);
  }

  /// Returns the sum of all leaves with index strictly below x.
  T getRank(int x, int p, int l, int r) {
    if (l == r) return 0;
    const int mid = (l + r) >> 1;
    const int lc = p << 1;
    if (x > mid) return getRank(x, lc | 1, mid + 1, r) + ans[lc];
    return getRank(x, lc, l, mid);
  }
};
// Shared state for main() and calc(); positions and values are 1-indexed.
// n and p[1..n] are read in main(); p is later overwritten with the
// permutation rebuilt from s[] before the second calc() call.
int n, p[N];
// ans[i]: accumulated output for position i (printed by main()).
// fac[i]: i! modulo 998244353, filled in main().
mint ans[N], fac[N];
// s[i]: initially the number of j > i with p[j] < p[i]; main() then adds the
// digits returned by solve() and propagates carries, with s[0] taking the
// carry out of position 1.
int s[N];
// The value k exactly as read from stdin (decimal string).
string k_str;
// Tree over values with 0/1 leaves, used for counting and k-th selection.
segtree<int, N> t;
// Tree over values built in calc() with leaf x holding x (0 once x is used).
segtree<mint, N> t2;
void calc(int k) {
mint sum = 1;
for (int i = n; i >= 1; i--) {
// q[i] == p[i]
ans[i] += sum * p[i] * k;
debug("ans[%d] += %d * %d * %d\n", i, raw(sum), p[i], k);
sum += fac[n - i] * s[i];
}
t.build([](int) -> int { return 1; }, 1, 1, n);
t2.build([](int x) -> int { return x; }, 1, 1, n);
sum = 0;
for (int i = 1; i <= n; i++) {
// q[i] < p[i]
t.setValue(p[i], 0, 1, 1, n);
t2.setValue(p[i], 0, 1, 1, n);
ans[i] += sum * k;
mint pre = t2.getRank(p[i], 1, 1, n);
int pnum = t.getRank(p[i], 1, 1, n);
ans[i] += pre * fac[n - i] * k;
debug("i = %d, sum = %d, pre = %d, pnum = %d\n", i, raw(sum), raw(pre), pnum);
debug("ans[%d] += (%d + %d * %d) * %d\n", i, raw(sum), raw(pre), raw(fac[n - i]), k);
if (i < n) {
sum += (t2.ans[1] - pre + p[i]) * fac[n - i - 1] * pnum;
sum += pre * fac[n - i - 1] * (pnum - 1);
}
}
}
/// Reads n, the decimal string k and the permutation p[1..n] from stdin, then
/// prints ans[1..n], one value per line.
///
/// Steps: compute the digits s[] of p, subtract p's contribution via calc(-1),
/// add k (converted to mixed-radix digits by solve()) onto s[] with carries,
/// account for everything that overflows past position 1, rebuild p from the
/// new digits and add its contribution via calc(1).
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  cin >> n >> k_str;
  fac[0] = 1;
  for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
  for (int i = 1; i <= n; i++) cin >> p[i];

  // s[i] = number of values to the right of position i that are below p[i].
  t.build([](int) -> int { return 0; }, 1, 1, n);
  for (int i = n; i >= 1; i--) {
    s[i] = t.getRank(p[i], 1, 1, n);
    t.setValue(p[i], 1, 1, 1, n);
  }
  calc(-1);

  auto res = solve(k_str);
  debug("res: "); for (int i = 0; i < (int)res.size(); i++) debug("%d, ", res[i]); debug("\n");

  // Add the low digits of k onto s[]; position i has radix n - i + 1 and the
  // carry moves towards position 0.
  for (int i = n; i >= 1; i--) {
    if (n - i < (int)res.size()) s[i] += res[n - i];
    int num = n - i + 1;
    s[i - 1] += s[i] / num;
    s[i] %= num;
  }

  // ext collects everything at or above weight n!: the carry s[0] plus the
  // digits res[i] with i >= n, each weighted by i!. The loop starts at 1 so
  // that tmp equals i!; starting at 0 would multiply tmp by zero and discard
  // every high digit. res[0] is always 0, so skipping it loses nothing.
  mint ext = fac[n] * s[0], tmp = 1;
  for (int i = 1; i < (int)res.size(); i++) {
    tmp *= i;
    if (i >= n) ext += tmp * res[i];
  }
  static constexpr int inv2 = (mint::mod + 1) / 2;
  for (int i = 1; i <= n; i++) ans[i] += mint(n + 1) * inv2 * ext;

  // Rebuild p from the updated digits: position i takes the (s[i] + 1)-th
  // smallest value that is still unused.
  t.build([](int) -> int { return 1; }, 1, 1, n);
  for (int i = 1; i <= n; i++) {
    int x = t.getKth(s[i] + 1, 1, 1, n);
    p[i] = x;
    debug("%d, ", x);
    t.setValue(x, 0, 1, 1, n);
  }
  debug("\n");
  calc(1);

  for (int i = 1; i <= n; i++) cout << ans[i] << endl;
  return 0;
}