题解:Soso 的排列 2

· · 算法·理论

其他部分参见官方题解(V 题),这里快进到将 k 转化为变进制数这一步。记 m=O(\log k)

  1. 事实上最后转换出的变进制数长度为 \ell=O(m/\log m) 而不是 O(m)。证明可以用阶乘的斯特林公式,这里不展开。例如 k=10^{100000} 对应的长度 \ell = 25205,k=10^{200000} 对应的长度 \ell = 47175。这样就给原来 O(m^2) 的算法除掉了一个 \log m。
  2. 压位高精度。由于题目输入是十进制,我们只能使用压九位的 10^9 进制高精度,相当于给复杂度再除掉一个位长 w=9。

应用这两个优化就能把该部分复杂度降到 O(\frac{m^2}{w\log m}),虽然跑得有点慢,不过应该是可以过的。

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
/**
 * Integer modulo the compile-time constant `umod`.
 *
 * The stored representative `v` is always kept in [0, umod). Subtraction
 * relies on unsigned wrap-around, so `umod` must be below 2^31. Division uses
 * Fermat's little theorem and therefore assumes `umod` is prime.
 *
 * Note: the defaulted constructor leaves `v` uninitialised for automatic
 * objects; objects with static storage duration start at 0.
 */
template <unsigned umod>
struct modint { /*{{{*/
  static constexpr int mod = umod;
  unsigned v;
  modint() = default;
  /**
   * Converts any integral value to its canonical residue.
   *
   * The remainder (not the input) is tested for negativity, so negative exact
   * multiples of the modulus map to 0 instead of the out-of-range value `mod`.
   */
  template <class T, std::enable_if_t<std::is_integral<T>::value, int> = 0>
    modint(const T& y) {
      auto r = y % mod;  // sign follows the dividend for signed T
      if (std::is_signed<T>::value && r < 0) r += mod;
      v = static_cast<unsigned>(r);
    }
  modint& operator+=(const modint& rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint& operator-=(const modint& rhs) {
    v -= rhs.v;
    // A negative result wrapped around to a huge unsigned value; pull it back.
    if (v >= umod) v += umod;
    return *this;
  }
  modint& operator*=(const modint& rhs) {
    v = (unsigned)(1ull * v * rhs.v % umod);
    return *this;
  }
  /** Multiplies by the modular inverse of `rhs`; `rhs` must be non-zero. */
  modint& operator/=(const modint& rhs) {
    assert(rhs.v);
    return *this *= qpow(rhs, mod - 2);
  }
  friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }
  /** Returns a^b by binary exponentiation; `b` must be non-negative. */
  template <class T>
    friend modint qpow(modint a, T b) {
      modint r = 1;
      for (assert(b >= 0); b; b >>= 1, a *= a)
        if (b & 1) r *= a;
      return r;
    }
  /** Returns the stored representative as a plain int. */
  friend int raw(const modint& self) { return self.v; }
  friend std::ostream& operator<<(std::ostream& os, const modint& self) {
    return os << raw(self);
  }
  explicit operator bool() const { return v != 0; }
  modint operator-() const { return modint(0) - *this; }
  bool operator==(const modint& rhs) const { return v == rhs.v; }
  bool operator!=(const modint& rhs) const { return v != rhs.v; }
}; /*}}}*/
// Modular integer type used for every answer in this file (modulus 998244353).
using mint = modint<998244353>;

/**
 * Converts a non-negative decimal integer into the factorial number system.
 *
 * The number is held as base-10^9 limbs and repeatedly divided in place by
 * 1, 2, 3, ...; the remainder of the division by d is the digit of (d-1)!.
 *
 * @param s decimal digits, most significant first. Leading zeros are allowed
 *          and an empty string is treated as 0.
 * @return digits d such that value(s) == sum(d[i] * i!) with 0 <= d[i] <= i.
 *         d[0] is always 0; the result for the value 0 is {0}.
 * @throws std::invalid_argument if `s` contains a non-digit character.
 */
std::vector<int> solve(std::string s) {
  constexpr std::uint32_t BASE = 1000000000;  // 9 decimal digits per limb
  constexpr std::size_t DIGITS_PER_LIMB = 9;

  for (const char c : s) {
    if (c < '0' || c > '9') {
      throw std::invalid_argument("solve: expected a string of decimal digits");
    }
  }

  // Build the limb array, most significant limb first. Only the leading limb
  // may be shorter than DIGITS_PER_LIMB.
  std::vector<std::uint32_t> num;
  num.reserve(s.size() / DIGITS_PER_LIMB + 1);
  const std::size_t head = s.size() % DIGITS_PER_LIMB;
  for (std::size_t pos = 0; pos < s.size();) {
    const std::size_t end = (pos == 0 && head != 0) ? head : pos + DIGITS_PER_LIMB;
    std::uint32_t limb = 0;
    for (; pos < end; ++pos) limb = limb * 10 + static_cast<std::uint32_t>(s[pos] - '0');
    num.push_back(limb);
  }

  std::vector<int> ans;
  std::size_t cur_len = num.size();  // limbs still in use
  for (std::uint32_t divisor = 1;; ++divisor) {
    std::uint64_t remainder = 0;
    std::size_t new_len = 0;  // limbs of the quotient written so far
    for (std::size_t i = 0; i < cur_len; ++i) {
      // remainder < divisor, so remainder * BASE + limb fits in 64 bits.
      remainder = remainder * BASE + num[i];
      const std::uint32_t q = static_cast<std::uint32_t>(remainder / divisor);
      remainder %= divisor;
      // Drop leading zero limbs so later passes get shorter.
      if (new_len != 0 || q != 0) num[new_len++] = q;
    }
    ans.push_back(static_cast<int>(remainder));
    if (new_len == 0) break;  // quotient reached 0: conversion finished
    cur_len = new_len;
  }
  return ans;
}
// Capacity (200010) of the global arrays and segment trees declared below.
constexpr int N = 2e5 + 10;
/**
 * Sum segment tree over a fixed-capacity array.
 *
 * Nodes are numbered heap-style: node p has children 2p and 2p+1. Every
 * operation takes the node index together with the closed index range [l, r]
 * that node covers; callers start at the root with the full range.
 */
template <class T, int N>
struct segtree {
  T ans[N << 2];

  /** Recomputes node p as the sum of its two children. */
  void maintain(int p) { ans[p] = ans[p << 1] + ans[p << 1 | 1]; }

  /** Fills the subtree of p so that leaf i holds f(i). */
  template <class Func>
    void build(Func&& f, int p, int l, int r) {
      if (l == r) {
        ans[p] = f(l);
        return;
      }
      const int mid = (l + r) >> 1;
      const int lc = p << 1;
      build(f, lc, l, mid);
      build(f, lc | 1, mid + 1, r);
      ans[p] = ans[lc] + ans[lc | 1];
    }

  /** Overwrites leaf x with k and refreshes the sums on the path to p. */
  void setValue(int x, T k, int p, int l, int r) {
    if (l == r) {
      ans[p] = k;
      return;
    }
    const int mid = (l + r) >> 1;
    const int lc = p << 1;
    if (x > mid) {
      setValue(x, k, lc | 1, mid + 1, r);
    } else {
      setValue(x, k, lc, l, mid);
    }
    maintain(p);
  }

  /**
   * Descends from p towards the leaf selected by k: goes left while k is at
   * most the left child's sum, otherwise subtracts that sum and goes right.
   * Returns the index of the leaf reached.
   */
  int getKth(T k, int p, int l, int r) {
    while (l != r) {
      const int mid = (l + r) >> 1;
      const int lc = p << 1;
      if (k <= ans[lc]) {
        p = lc;
        r = mid;
      } else {
        k = k - ans[lc];
        p = lc | 1;
        l = mid + 1;
      }
    }
    return l;
  }

  /** Returns the sum of all leaves with index strictly below x. */
  T getRank(int x, int p, int l, int r) {
    T below = 0;
    while (l != r) {
      const int mid = (l + r) >> 1;
      const int lc = p << 1;
      if (x <= mid) {
        p = lc;
        r = mid;
      } else {
        // Everything in the left child lies before x.
        below = below + ans[lc];
        p = lc | 1;
        l = mid + 1;
      }
    }
    return below;
  }
};
// n: permutation length; p[1..n]: the permutation currently being processed
// (read from input in main, later overwritten with the shifted permutation).
int n, p[N];
// ans[i]: accumulated answer for position i; fac[i] = i! modulo the mint modulus.
mint ans[N], fac[N];
// s[i]: count of unused values smaller than p[i] (computed in main); s[0]
// receives the carry when k is added to these digits.
int s[N];
// k as read from input, kept as a decimal string and passed to solve().
string k_str;
// t holds 0/1 flags per value; calc() uses getRank on it to count smaller unused values.
segtree<int, N> t;
// t2 holds the values themselves; calc() uses getRank on it to sum smaller unused values.
segtree<mint, N> t2;
void calc(int k) {
  mint sum = 1;
  for (int i = n; i >= 1; i--) {
    // q[i] == p[i]
    ans[i] += sum * p[i] * k;
    debug("ans[%d] += %d * %d * %d\n", i, raw(sum), p[i], k);
    sum += fac[n - i] * s[i];
  }
  t.build([](int) -> int { return 1; }, 1, 1, n);
  t2.build([](int x) -> int { return x; }, 1, 1, n);
  sum = 0;
  for (int i = 1; i <= n; i++) {
    // q[i] < p[i]
    t.setValue(p[i], 0, 1, 1, n);
    t2.setValue(p[i], 0, 1, 1, n);
    ans[i] += sum * k;
    mint pre = t2.getRank(p[i], 1, 1, n);
    int pnum = t.getRank(p[i], 1, 1, n);
    ans[i] += pre * fac[n - i] * k;
    debug("i = %d, sum = %d, pre = %d, pnum = %d\n", i, raw(sum), raw(pre), pnum);
    debug("ans[%d] += (%d + %d * %d) * %d\n", i, raw(sum), raw(pre), raw(fac[n - i]), k);
    if (i < n) {
      sum += (t2.ans[1] - pre + p[i]) * fac[n - i - 1] * pnum;
      sum += pre * fac[n - i - 1] * (pnum - 1);
    }
  }
}
/**
 * Reads n, k (as a decimal string) and the permutation p[1..n] from stdin and
 * prints ans[1..n], one value per line.
 *
 * Steps: compute the digits s[] of p, subtract calc() for the starting
 * permutation, add k to s[] in the factorial number system, add the
 * contribution of the overflow beyond n!, rebuild p from the new digits and
 * add calc() for it.
 */
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  cin >> n >> k_str;
  fac[0] = 1;
  for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i;
  for (int i = 1; i <= n; i++) cin >> p[i];

  // s[i] = number of positions j > i with p[j] < p[i].
  t.build([](int) -> int { return 0; }, 1, 1, n);
  for (int i = n; i >= 1; i--) {
    s[i] = t.getRank(p[i], 1, 1, n);
    t.setValue(p[i], 1, 1, 1, n);
  }
  calc(-1);

  const auto res = solve(k_str);
  debug("res: ");
  for (int i = 0; i < (int)res.size(); i++) debug("%d, ", res[i]);
  debug("\n");

  // Add k to s[] digit by digit: s[i] carries the weight (n-i)! and has
  // radix n-i+1. The final carry ends up in s[0] with weight n!.
  for (int i = n; i >= 1; i--) {
    if (n - i < (int)res.size()) s[i] += res[n - i];
    int num = n - i + 1;
    s[i - 1] += s[i] / num;
    s[i] %= num;
  }

  // ext = everything at or above n!: the carry plus the digits res[i], i >= n.
  mint ext = fac[n] * s[0], tmp = 1;
  for (int i = 0; i < (int)res.size(); i++) {
    // Keep tmp == i!. Multiplying by i at i == 0 would zero tmp for good and
    // silently drop every high digit of k.
    if (i > 0) tmp *= i;
    if (i >= n) ext += tmp * res[i];
  }
  static constexpr int inv2 = (mint::mod + 1) / 2;
  // ext permutations make whole passes over all n! permutations; across them
  // every position averages (n + 1) / 2.
  for (int i = 1; i <= n; i++) ans[i] += mint(n + 1) * inv2 * ext;

  // Rebuild p from the updated digits: p[i] is the (s[i]+1)-th unused value.
  t.build([](int) -> int { return 1; }, 1, 1, n);
  for (int i = 1; i <= n; i++) {
    int x = t.getKth(s[i] + 1, 1, 1, n);
    p[i] = x;
    debug("%d, ", x);
    t.setValue(x, 0, 1, 1, n);
  }
  debug("\n");
  calc(1);
  for (int i = 1; i <= n; i++) cout << ans[i] << endl;
  return 0;
}