题解 P5068 【[Ynoi2015] 我回来了】
eee_hoho
2021-01-12 16:02:11
分享一个被吊踩的做法/kk
首先考虑加入一个数 $x$ 对伤害为 $i$ 的亵渎的贡献,假设当前伤害为 $i$ 的亵渎操作次数为 $ans_i$ 次,那么如果 $0<x-i\times ans_i\le i$ ,就可以使 $ans_i++$ ,而如果在插入这个数之前插入了满足更大的 $ans_i$ 的要求,那么就可以继续增加。
而我们发现一个伤害为 $i$ 的亵渎操作次数为 $\lfloor\frac{n}{i}\rfloor+1$ ,所以 $\sum_{i=1}^n ans_i$ 的最大值是 $nlogn$ 级别的,那么这启发我们加入一个数之后暴力往后跳有关的 $ans$ 数组,然后可以用一个树状数组进行单点修改和区间和。
那么现在就变成了我有 $(i,i*j),(j*i\le n)$ ,这么多个点,每次给定一个 $x$ ,会给所有的 $(i,\lfloor\frac{x-1}{i}\rfloor)$ 上的点打上标记。
这就可以看作一个整除分块,一次操作只会有 $2\sqrt{x-1}$ 种 $y$ 坐标,那么我们可以每次直接提取出来相同的一段 $y$ 进行覆盖,然后用并查集维护出来这个点的下一个应该覆盖的点是哪个点,复杂度就是正确的了。
复杂度 $O(n\sqrt{n}+mlogn+nlogn\alpha(n)+nlog^2n)$ 。
**Code**
``` cpp
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int N = 1e5;
const int M = 1e6;
const int MAXN = 2e7;
using namespace std;
struct qry
{
int opt,l,r,id;
}q[M + 5];
int n,m,vis[N + 5],ans[N + 5],pool[MAXN + 5],pool1[MAXN + 5];
int *pp = pool,*d[N + 5],*p1 = pool1;
long long s;
struct Bit
{
long long c[N + 5];
int lowbit(int x)
{
return x & (-x);
}
void add(int x,int v)
{
for (;x <= n;x += lowbit(x))
c[x] += v;
}
long long query(int x)
{
long long ans = 0;
for (;x;x -= lowbit(x))
ans += c[x];
return ans;
}
}c;
struct Dsu
{
int *fa,ts;
void init(int n)
{
fa = p1;
p1 += n;
ts = n - 1;
for (int i = 1;i < n;i++)
fa[i] = i;
}
int find(int x)
{
if (fa[x] == x)
return x;
return fa[x] = find(fa[x]);
}
}fd[N + 5];
inline long long read()
{
long long X(0);int w(0);char ch(0);
while (!isdigit(ch)) w |= ch == '-',ch = getchar();
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
return w ? -X : X;
}
int main()
{
n = read();m = read();
int opt,l,r;
for (int i = 1;i <= n;i++)
{
d[i] = pp;
pp += n / i + 1;
fd[i].init(n / i + 2);
}
fd[0].init(n + 2);
while (m--)
{
opt = read();l = read();
if (opt == 1)
{
if (vis[l])
continue;
vis[l] = 1;
int h = l - 1;
for (int l = 1,r,v;l <= n;l = r + 1)
{
v = h / l;
if (v)
r = min(h / v,n);
else
r = n;
//cout<<l<<" "<<r<<" "<<v<<endl;
for (int i = l;i <= r;i++)
{
//cout<<v<<" "<<i<<endl;
i = fd[v].find(i);
if (i > r)
break;
d[i][v] = 1;
int las = ans[i];
while (d[i][ans[i]] && ans[i] <= n / i)
ans[i]++;
if (ans[i] - las != 0)
c.add(i,ans[i] - las);
if (i + 1 <= fd[v].ts)
fd[v].fa[i] = fd[v].find(fd[v].fa[i + 1]);
}
}
}
else
{
r = read();
printf("%lld\n",c.query(r) - c.query(l - 1) + r - l + 1);
}
}
return 0;
}
```