题解 P6690 【一次函数】
_sys
·
·
题解
我们设 g 为模 p 意义下的原根,G(x) = x + g。
则 G^i (1 \leq i \leq p(p-1))遍历了所有的 ax + b(0\leq a < p, 1\leq b < p)。
考虑 G^i=(x+g)^i=ig^{i-1}x+g^i。
于是我们可以把所有的 $ax + b$ 取 “离散对数”,方法是对常数项 BSGS,一次项通过解方程求出。得到一个数组 $\textrm{arr}$,$Ax + B$ 转变为 $\textrm{goal}$。
现在的问题变为:给你一个数组 $\textrm{arr}$,问至少从中选出多少个使得他们的某个线性组合等于 $\textrm{goal}$。
我们设选出来的数组为 $s$。
根据裴蜀定理,我们得知若 $\gcd \{s\}|\textrm{goal}$,则一定有解。
我们将 $\textrm{arr}$ 中的每个数变为 $\gcd(\textrm{arr}_i, \textrm{goal})$。于是问题变为,选出尽可能少的数使得 $\gcd\{s\}=1$。
我们可以通过简单的状压 dp 解决这个问题,状态数为 $2^{\omega(p-1)+1}$。
得到了选出的数的集合,我们便可以逐位构造方案。
我们假设已知 $a$,要构造出 $\sum_{i=1}^k a_ix_i=b$ 的一组解,我们将式子写成 $a_1x_1+\gcd_{i=2}^k\{a_i\}X=b$,发现 $\sum_{i=2}^ka_ix'_i=b'$ 有解当且仅当 $\gcd_{i=2}^k\{a_i\}|b'$。所以我们通过 exgcd 求出 $x_1$ 和 $X$,之后解 $\sum_{i=2}^ka_ix_i=\gcd_{i=2}^k\{a_i\}X$ 即可。
时间复杂度为 $O(2^{\omega(p)+1}n+\sqrt{np}\log p)$。
```cpp
#include <bits/stdc++.h>
using namespace std;
const int Maxi = 1030, Maxn = 5005;
int n, m, ct, bloc, f[Maxn][Maxi], pos[Maxi], fac[Maxi];
pair <int, int> from[Maxn][Maxi];
long long d, p, goal, val[Maxn], arr[Maxn];
pair <int, long long> ans[Maxn];
unordered_map <int, int> Ma;
typedef pair <long long, long long> pll;
pll G;
pll operator * (const pll &x, const pll &y)
{
return make_pair((x.first * y.second + x.second * y.first) % p, x.second * y.second % p);
}
long long gcd(long long x, long long y)
{
return x == 0 ? y : gcd(y % x, x);
}
long long fast_pow(long long x, long long y)
{
long long ans = 1, now = x;
while (y)
{
if (y & 1) ans = ans * now % p;
now = now * now % p;
y >>= 1;
}
return ans;
}
pll fast_pow(pll x, long long y)
{
pll ans = make_pair(0, 1), now = x;
while (y)
{
if (y & 1) ans = ans * now;
now = now * now;
y >>= 1;
}
return ans;
}
void dp(void)
{
int maxi = (1 << ct) - 1;
memset(f, 0x3f, sizeof(f));
f[0][maxi] = 0;
for (int i = 1; i <= n; i++)
{
int sta = 0;
long long tmp_val = gcd(val[i], p * (p - 1));
tmp_val = tmp_val / gcd(tmp_val, goal);
for (int j = 1; j <= ct; j++)
if (tmp_val % fac[j] == 0) sta |= 1 << (j - 1);
for (int j = 1; j <= maxi; j++)
if (f[pos[j]][j] + 1 < f[pos[j & sta]][j & sta])
{
pos[j & sta] = i;
f[i][j & sta] = f[pos[j]][j] + 1;
from[i][j & sta] = make_pair(pos[j], j);
}
}
}
int get_unit(void)
{
int maxi = sqrt(p - 1);
for (int i = 2; ; i++)
{
for (int j = 2; j <= maxi; j++)
if ((p - 1) % j == 0)
{
if (fast_pow(i, j) == 1) goto END;
if (fast_pow(i, (p - 1) / j) == 1) goto END;
}
return i;
END:;
}
}
void get_factor(void)
{
int maxi = sqrt(p - 1), tmp = p - 1;
for (int i = 2; i <= maxi; i++)
if (tmp % i == 0)
{
fac[++ct] = i;
while (tmp % i == 0) tmp /= i;
}
if (tmp != 1) fac[++ct] = tmp;
fac[++ct] = p;
}
pll exgcd(long long x, long long y)
{
if (!x)
return make_pair(0, 1);
else
{
pll tmp = exgcd(y % x, x);
return make_pair(tmp.second - tmp.first * (y / x), tmp.first);
}
}
long long multi(long long x, long long y)
{
x = (x % (p * (p - 1)) + (p * (p - 1))) % (p * (p - 1));
y = (y % (p * (p - 1)) + (p * (p - 1))) % (p * (p - 1));
long long ans = 0, now = x;
while (y)
{
if (y & 1) ans = ((unsigned long long) ans + now) % (p * (p - 1));
now = (now * 2ull) % (p * (p - 1));
y >>= 1;
}
return ans;
}
void work(int lt, long long goal_now)
{
long long g = 0, div;
for (int i = lt + 1; i <= m; i++)
g = gcd(g, arr[i]);
div = gcd(arr[lt], gcd(goal_now, g));
for (int i = lt; i <= m; i++)
arr[i] /= div;
g /= div, goal_now /= div;
pll result = exgcd(arr[lt], g);
ans[lt].second = multi(result.first, goal_now);
if (lt == m - 1) return ;
work(lt + 1, multi(multi(result.second, goal_now), g));
}
void BSGS_init(void)
{
bloc = sqrt(n * p) / 2;
long long now = 1;
d = get_unit();
for (int i = 0; i < bloc; i++)
{
Ma[now] = i;
now = now * d % p;
}
}
long long BSGS(pll val_now)
{
long long inv = fast_pow(fast_pow(d, bloc), p - 2), now_inv = 1;
for (int j = 0; j < p / bloc + 1; j++)
{
if (Ma.find(now_inv * val_now.second % p) != Ma.end())
{
long long tmp = Ma[now_inv * val_now.second % p] + j * (long long) bloc;
pll tmp_val = val_now * fast_pow(make_pair(1, d), p * (p - 1) - tmp);
return ((p - 1) * (p - tmp_val.first * d % p) + tmp) % (p * (p - 1));
}
(now_inv *= inv) %= p;
}
}
int main()
{
scanf("%d%lld%lld%lld", &n, &p, &G.first, &G.second);
if (G == make_pair(0LL, 1LL))
{
puts("0");
return 0;
}
BSGS_init();
get_factor();
goal = BSGS(G);
long long g = gcd(p * (p - 1), goal);
for (int i = 1; i <= n; i++)
{
pll x;
scanf("%lld%lld", &x.first, &x.second);
val[i] = BSGS(x), g = gcd(g, val[i]);
}
for (int i = 1; i <= n; i++) val[i] /= g;
goal /= g;
dp();
if (f[pos[0]][0] > n)
{
puts("-1");
return 0;
}
pair <int, int> now = make_pair(pos[0], 0);
while (now.first)
{
arr[++m] = val[now.first];
ans[m].first = now.first;
now = from[now.first][now.second];
}
arr[++m] = p * (p - 1);
work(1, goal);
printf("%d\n", m - 1);
for (int i = 1; i < m; i++)
printf("%d_%lld\n", ans[i].first, ans[i].second);
return 0;
}
```