题解 CF607E 【Cross Sum】

xht

2020-02-20 18:06:35

Solution

> [CF607E Cross Sum](https://codeforces.com/contest/607/problem/E) ## 题意 - 在平面直角坐标系中,给定 $n$ 条不同的直线,第 $i$ 条直线为 $y=a_i x+b_i$。 - 可重点集 $\mathcal I$ 为 $n$ 条直线两两的交点,可重数集 $\mathcal D$ 为 $\mathcal I$ 中所有点与 $(p,q)$ 的距离。 - 你需要求出 $\mathcal D$ 中前 $m$ 小的数之和。 - $n \le 5 \times 10^4$,$m \le 3 \times 10^7$,$|p|,|q|,|a_i|,|b_i| \le 10^3$ 且输入的整数值为实际值的 $10^3$ 倍,答案精度误差 $\le 10^6$。 ## 题解 为了方便,我们可以直接将 $(p,q)$ 作为坐标系原点,则每条直线可以转成 $y+q=a_i(x+p)+b_i$,也就是 $y=a_ix + a_ip+b_i-q$。 首先显然要找到一个半径最小的圆使其能够覆盖至少 $m$ 个交点。 考虑二分答案转化为判定问题。我们可以求出每条直线与圆的交点坐标并极角排序和离散化,问题就变成若干条线段有多少两两相交,是经典的扫描线问题,可以用树状数组求出。这部分时间复杂度 $\mathcal O(n \log n \log w)$。 找到这个半径之后,我们现在要计算答案。 与上面的过程类似,但由于此时交点数最多只有 $m$,因此我们可以用 set 代替树状数组,直接得到所有相交的线段,然后求出它们的交点与原点的距离。 但可能有多个点恰好在圆上,因此我们在统计距离的时候只统计严格在圆内的点,如果个数少于 $m$,则剩下的距离全都是半径。这部分时间复杂度 $\mathcal O(m \log n)$。 总时间复杂度 $\mathcal O(n \log n \log w + m \log n)$,需要注意一下精度问题。 ## 代码 ```cpp const int N = 5e4 + 7; int n, m, x, y, t, v[N], c[N*2]; ld p, q, a[N], b[N]; pair<ld, int> s[N*2]; inline void add(int x) { while (x <= t) ++c[x], x += x & -x; } inline int ask(int x) { int k = 0; while (x) k += c[x], x -= x & -x; return k; } inline void work(ld r) { t = 0; for (int i = 1; i <= n; i++) { ld a = ::a[i], b = ::b[i]; ld d = 4 * a * a * b * b - 4 * (a * a + 1) * (b * b - r * r); if (d < eps) continue; d = sqrt(d); ld x1 = (-2 * a * b - d) / (2 * (a * a + 1)); ld x2 = (-2 * a * b + d) / (2 * (a * a + 1)); ld y1 = a * x1 + b, y2 = a * x2 + b; s[++t] = mp(P(x1, y1).a(), i), s[++t] = mp(P(x2, y2).a(), i); v[i] = 0; } sort(s + 1, s + t + 1); for (int i = t; i; i--) if (!v[s[i].se]) v[s[i].se] = i; } inline bool pd(ld r) { work(r); for (int i = 1; i <= t; i++) c[i] = 0; ll ret = 0; for (int i = 1; i <= t; i++) if (v[s[i].se] != i) { ret += ask(v[s[i].se]) - ask(i - 1); add(v[s[i].se]); } return ret >= m; } inline ld calc(int i, int j) { ld x = (b[j] - b[i]) / (a[i] - a[j]), y = a[i] * x + b[i]; return sqrt(x * x + y * y); } inline ld calc(ld r) { work(r); set<pi> st; ld ret = 0; int cnt = 0; for (int i = 1; i <= t; i++) if (v[s[i].se] != i) { auto it = st.upper_bound(mp(i + 1, 0)); while (it != st.end() && (it -> fi) < v[s[i].se]) ret += calc(s[i].se, it -> se), ++it, ++cnt; st.insert(mp(v[s[i].se], s[i].se)); } return ret + (m - cnt) * r; } int main() { rd(n), rd(x), rd(y), rd(m), p = x / 1e3, q = y / 1e3; for (int i = 1; i <= n; i++) rd(x), rd(y), a[i] = x / 1e3, b[i] = y / 1e3 + a[i] * p - q; ld l = 0, r = 1e10; for (int i = 1; i <= 100; i++) { ld mid = (l + r) / 2; if (pd(mid)) r = mid; else l = mid; } printf("%.10Lf\n", max(0.0L, calc(r))); return 0; } ```