题解 P3747 【[六省联考2017]相逢是问候】

题解 (solution write-up)

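Solution

Approach: by the extended Euler theorem, for $x \ge \varphi(m)$ we have $c^{x} \equiv c^{(x \bmod \varphi(m)) + \varphi(m)} \pmod m$, while for $x < \varphi(m)$ the exponent can be used directly. Repeatedly applying $m \mapsto \varphi(m)$ starting from $p$ reaches $1$ after $O(\log p)$ steps. The code stores this chain in `d` and stops updating a position once it has received more than `d[0]` operations. A segment tree keeps the range sum modulo $p$ together with, per node, the minimum number of operations applied to its leaves, so fully saturated ranges are skipped. To decide which case of the theorem applies, the code keeps the exact value of each tower $c^{c^{\cdots^{a_i}}}$ while it stays $\le 10^8$. It evaluates $c^y \bmod d_k$ with two lookups into precomputed tables of block size $2\cdot 10^4$.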

Code

#include <bits/stdc++.h>

using namespace std;

// Shorthand for the 64-bit type used for intermediate products.
#define ll long long
// Left / right child of segment-tree node p (root is node 1, see build(1, n, 1)).
// NOTE(review): the expansions are unparenthesised; they are only safe where the
// macro forms a whole argument or index on its own, as everywhere in this file.
#define p2 p << 1
#define p3 p << 1 | 1

template <class t>
inline void read(t & res)
{
    // Reads the next non-negative decimal integer from stdin into res.
    // Any non-digit characters before it (whitespace, and also a '-' sign) are
    // skipped, so negative numbers are not supported.
    // At end of input res is set to 0 instead of spinning forever.
    int ch;  // int, not char: getchar() may return EOF, and isdigit() needs an
             // unsigned-char value or EOF (a negative char is undefined behaviour)
    while ((ch = getchar()) != EOF && !isdigit(ch))
    {
    }
    if (ch == EOF)
    {
        res = 0;
        return;
    }
    res = ch ^ 48;  // '0'..'9' ^ 48 yields the digit value
    while ((ch = getchar()) != EOF && isdigit(ch))
        res = res * 10 + (ch ^ 48);
}

// Array bounds: e / o size the per-position arrays (presumably n <= 5e4 -- assumed
// input bound), the segment tree uses 4 * o nodes, bl is the block size of the
// two-level power tables c1/c2, and z bounds the length of the phi chain d.
const int e = 5e4 + 5, o = 5e4 + 5, bl = 2e4, z = 105;
// a: input values (1-based); cnt: per segment-tree node, the minimum number of
// operations applied to any leaf below it; c1 / c2: per modulus d[k],
// c1[k][q] = c^(q * bl) and c2[k][r] = c^r; n, m, c, mod: input parameters
// (mod is the problem's p).
// NOTE(review): phi is not referenced anywhere in this file.
int a[e], cnt[o * 4], c1[z][20005], c2[z][20005], n, m, c, mod, phi[e];
// num[i]: largest tower height whose exact value is stored in b[i];
// d: phi chain, d[0] = its length, d[1] == 1, d[d[0]] == mod after init();
// c3 / c4: c3[k] = c^k exactly for k <= c4, the powers that stay <= 1e8;
// b[i][k]: exact value of the height-k tower over a[i] (b[i][0] = a[i]) while small;
// sum: per segment-tree node, the range sum modulo mod.
int num[e], d[e], c4, c3[e], b[e][20], sum[o * 4];

inline void upt(int &x, int y)
{
    // Writes y into x, folded into [0, mod) by at most one subtraction.
    // Callers are expected to pass 0 <= y < 2 * mod (a sum of two residues).
    x = (y < mod) ? y : y - mod;
}

inline int solve(int x)
{
    // Euler's totient phi(x) for x >= 1, by trial-division factorisation.
    // Each distinct prime factor q scales the result by (q - 1) / q; dividing
    // before multiplying keeps every intermediate value <= x (no overflow).
    // The bound i * i <= x is exact integer arithmetic (no floating sqrt) and
    // uses the already-reduced x, so the loop ends early once small factors are gone.
    int res = x;
    for (int i = 2; (long long)i * i <= x; i++)
    {
        if (x % i != 0) continue;
        while (x % i == 0) x /= i;
        res = res / i * (i - 1);
    }
    // Anything left above 1 is a single prime factor.
    if (x > 1) res = res / x * (x - 1);
    return res;
}

inline void init()
{   
    d[d[0] = 1] = mod;
    for (;;)
    {
        d[0]++;
        d[d[0]] = solve(d[d[0] - 1]);
        if (d[d[0]] == 1) break;
    }
    reverse(d + 1, d + d[0] + 1);
    int i, j;
    ll y = 1;
    c3[0] = 1;
    for (i = 1; i < bl; i++) 
    {
        if (y <= 1e8)
        {
            y = y * c; 
            c3[i] = y;
            if (y <= 1e8) c4 = i;
        }
        else break;
    }
    for (i = 1; i <= d[0]; i++)
    {
        c1[i][0] = c2[i][0] = 1;
        for (j = 1; j < bl; j++) c2[i][j] = (ll)c2[i][j - 1] * c % d[i];
        c1[i][1] = (ll)c2[i][bl - 1] * c % d[i];
        for (j = 2; j < bl; j++) c1[i][j] = (ll)c1[i][j - 1] * c1[i][1] % d[i];
    }
    for (i = 1; i <= n; i++)
    {
        b[i][0] = a[i];
        for (j = 1; j <= 6; j++)
        {
            int x = b[i][j - 1], p1 = x / bl, p0 = x % bl;
            if (c3[p0] > 1e8 || p1 || p0 > c4) break;
            b[i][j] = c3[p0];
            num[i] = j;
        }
    }
}

inline int ksm(int y, int id)
{
    return (ll)c1[id][y / bl] * c2[id][y % bl] % d[id]; 
}

inline int f(int i, int j, int mod)
{
    // Value of element i after j operations, reduced modulo d[mod].
    // Here `mod` is an index into the phi chain d, not the global modulus.
    // Extended Euler theorem: with t = d[mod - 1] = phi(d[mod]), c^x mod d[mod]
    // equals c^x directly when x < t, and c^(x mod t + t) otherwise.
    if (mod == 1) return 0;          // d[1] == 1, so every value is 0
    if (!j) return a[i] % d[mod];    // no operation applied yet
    const int tot = d[mod - 1];
    // The exponent is the height-(j - 1) tower. Its exact value is stored only
    // when j - 1 <= num[i]. For c == 1 every power is 1, so no "+ tot" is used.
    const bool below = c == 1 || (j - 1 <= num[i] && b[i][j - 1] < tot);
    int exponent = f(i, j - 1, mod - 1);
    if (!below) exponent += tot;
    return ksm(exponent, mod);
}

inline void collect(int p)
{
    // Pull-up for node p. Its sum is the children's sum mod `mod`; its count is
    // the smaller child count (the fewest operations any leaf below received).
    const int lc = p2;
    const int rc = p3;
    upt(sum[p], sum[lc] + sum[rc]);
    cnt[p] = min(cnt[lc], cnt[rc]);
}

inline void build(int l, int r, int p)
{
    if (l == r)
    {
        sum[p] = a[l];
        return;
    }
    int mid = l + r >> 1;
    build(l, mid, p2);
    build(mid + 1, r, p3);
    collect(p);
}

inline void update(int l, int r, int s, int t, int p)
{
    // Applies one operation to every position of [s, t] inside node p's range [l, r].
    // A subtree is skipped once all its leaves have received more than d[0]
    // operations; past that point this code no longer recomputes them.
    if (cnt[p] > d[0]) return;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (s <= mid) update(l, mid, s, t, p2);
        if (mid < t) update(mid + 1, r, s, t, p3);
        collect(p);
        return;
    }
    // Leaf: count the operation, then recompute its value from scratch.
    ++cnt[p];
    sum[p] = f(l, cnt[p], d[0]);
}

inline int query(int l, int r, int s, int t, int p)
{
    // Sum of positions [s, t] modulo mod. Node p covers [l, r] with
    // l <= s <= t <= r; each recursive call clips [s, t] to the child's range.
    if (s == l && t == r) return sum[p];
    const int mid = (l + r) >> 1;
    if (t <= mid) return query(l, mid, s, t, p2);
    if (mid < s) return query(mid + 1, r, s, t, p3);
    int res = 0;
    upt(res, query(l, mid, s, mid, p2) + query(mid + 1, r, mid + 1, t, p3));
    return res;
}

int main()
{
    // Input: n m p c, then a[1..n], then m operations "opt l r".
    // opt == 0 applies the operation on [l, r]; any other opt prints the range sum.
    read(n);
    read(m);
    read(mod);
    read(c);
    for (int k = 1; k <= n; ++k) read(a[k]);
    init();
    build(1, n, 1);
    while (m-- > 0)
    {
        int opt, l, r;
        read(opt);
        read(l);
        read(r);
        if (opt == 0)
            update(1, n, l, r, 1);
        else
            printf("%d\n", query(1, n, l, r, 1));
    }
    return 0;
}