题解 P5068 【[Ynoi2015] 我回来了】

eee_hoho

2021-01-12 16:02:11

Solution

分享一个被吊踩的做法/kk 首先考虑加入一个数 $x$ 对伤害为 $i$ 的亵渎的贡献,假设当前伤害为 $i$ 的亵渎操作次数为 $ans_i$ 次,那么如果 $0<x-i\times ans_i\le i$ ,就可以使 $ans_i++$ ,而如果在插入这个数之前插入了满足更大的 $ans_i$ 的要求,那么就可以继续增加。 而我们发现一个伤害为 $i$ 的亵渎操作次数为 $\lfloor\frac{n}{i}\rfloor+1$ ,所以 $\sum_{i=1}^n ans_i$ 的最大值是 $nlogn$ 级别的,那么这启发我们加入一个数之后暴力往后跳有关的 $ans$ 数组,然后可以用一个树状数组进行单点修改和区间和。 那么现在就变成了我有 $(i,i*j),(j*i\le n)$ ,这么多个点,每次给定一个 $x$ ,会给所有的 $(i,\lfloor\frac{x-1}{i}\rfloor)$ 上的点打上标记。 这就可以看作一个整除分块,一次操作只会有 $2\sqrt{x-1}$ 种 $y$ 坐标,那么我们可以每次直接提取出来相同的一段 $y$ 进行覆盖,然后用并查集维护出来这个点的下一个应该覆盖的点是哪个点,复杂度就是正确的了。 复杂度 $O(n\sqrt{n}+mlogn+nlogn\alpha(n)+nlog^2n)$ 。 **Code** ``` cpp #include <iostream> #include <cstdio> #include <algorithm> #include <cstring> const int N = 1e5; const int M = 1e6; const int MAXN = 2e7; using namespace std; struct qry { int opt,l,r,id; }q[M + 5]; int n,m,vis[N + 5],ans[N + 5],pool[MAXN + 5],pool1[MAXN + 5]; int *pp = pool,*d[N + 5],*p1 = pool1; long long s; struct Bit { long long c[N + 5]; int lowbit(int x) { return x & (-x); } void add(int x,int v) { for (;x <= n;x += lowbit(x)) c[x] += v; } long long query(int x) { long long ans = 0; for (;x;x -= lowbit(x)) ans += c[x]; return ans; } }c; struct Dsu { int *fa,ts; void init(int n) { fa = p1; p1 += n; ts = n - 1; for (int i = 1;i < n;i++) fa[i] = i; } int find(int x) { if (fa[x] == x) return x; return fa[x] = find(fa[x]); } }fd[N + 5]; inline long long read() { long long X(0);int w(0);char ch(0); while (!isdigit(ch)) w |= ch == '-',ch = getchar(); while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar(); return w ? -X : X; } int main() { n = read();m = read(); int opt,l,r; for (int i = 1;i <= n;i++) { d[i] = pp; pp += n / i + 1; fd[i].init(n / i + 2); } fd[0].init(n + 2); while (m--) { opt = read();l = read(); if (opt == 1) { if (vis[l]) continue; vis[l] = 1; int h = l - 1; for (int l = 1,r,v;l <= n;l = r + 1) { v = h / l; if (v) r = min(h / v,n); else r = n; //cout<<l<<" "<<r<<" "<<v<<endl; for (int i = l;i <= r;i++) { //cout<<v<<" "<<i<<endl; i = fd[v].find(i); if (i > r) break; d[i][v] = 1; int las = ans[i]; while (d[i][ans[i]] && ans[i] <= n / i) ans[i]++; if (ans[i] - las != 0) c.add(i,ans[i] - las); if (i + 1 <= fd[v].ts) fd[v].fa[i] = fd[v].find(fd[v].fa[i + 1]); } } } else { r = read(); printf("%lld\n",c.query(r) - c.query(l - 1) + r - l + 1); } } return 0; } ```