P4217 产品销售

· · 题解

首先考虑一个费用流建模:

  1. 源点向第 i 个点连边,容量为 u_i,费用为 p_i,这代表每季度的生产;
  2. i 个点向汇点连边,容量为 d_i,费用为 0,这代表每个季度客户的要求;

为了方便,我们把所有边反向,这样从源点出发的所有边都要强制满流。

但是由于 n 开到了 10^5,所以直接费用流会 T。所以使用模拟费用流。

具体地,按顺序枚举每一个点 x,尝试对这个点进行一个增广。因此我们要找到从 x 到汇点的费用最小的路径。可以发现增广路要么是往左走的,要么是往右走的。我们对两种增广路分别维护。

我们先来讨论如何增广。假设我们已经知道了 S \rightarrow x \rightarrow y \rightarrow TxT 的费用最小的路径,此时有两种情况。

在增广之后,可能会出现某些点与汇点的边满流的情况。这种情况只需要把这些点从备选的增广路终点中删去即可。

然后考虑从一个点 x 转向下一个点 x + 1 时代价的变化。首先 x + 1 及以后的点代价减去 c_x。其次 x 及以前的点代价根据 xx + 1 之间边的情况来定。如果这条边被用来增广过,那 x 之前点的代价要减去 c_x,否则加上 m_x

注意到实际上形如一个区间加,区间 \min 的问题,使用线段树维护。首先开两棵线段树维护当前点两边点的代价,然后第三棵线段树用来维护每条反向边的剩余流量。当增广时新增了反向边,就在第三棵线段树上区间加。如果要往左增广了,就在第三棵线段树上求一个区间 \min。当增广时用到了反向边,就在第三棵线段树上减。如果某条反向边满流了,就在第二棵线段树上修改点的代价。如果有一个终点满流了,那就在第二棵线段树上把它的位置加个 +\infty

在实现上,前两棵线段树是好实现的。对于最后一棵,我们先把所有元素赋成 0。接下来如果要检查是否有反向边满流,我们就暴力 dfs,找值为 0 的点,把它的流量限制赋为 +\infty。如果这是一条满流的反向边,我们就改代价;如果这条反向边根本就没出现过,那我们就不管了。这两种情况可以通过判断每条边有没有被往右增广过来分辨。

代码

#include <iostream>
#define int long long
using namespace std;
const int inf = 21474836470000;
int d[100005], u[100005], p[100005], m[100005], c[100005], pre[100005];
int n;
struct Segment_Tree1 {
    int mn[400005], mnp[400005];
    int tg[400005];
    inline void tag(int x, int y) { mn[x] += y, tg[x] += y; }
    inline void pushdown(int o) {
        if (!tg[o]) 
            return;
        tag(o << 1, tg[o]);
        tag(o << 1 | 1, tg[o]);
        tg[o] = 0;
    }
    inline void pushup(int o) {
        if (mn[o << 1] < mn[o << 1 | 1]) 
            mn[o] = mn[o << 1], mnp[o] = mnp[o << 1];
        else 
            mn[o] = mn[o << 1 | 1], mnp[o] = mnp[o << 1 | 1];
    }
    void Build(int o, int l, int r) {
        if (l == r) {
            mn[o] = p[l];
            mnp[o] = r;
            return;
        }
        int mid = (l + r) >> 1;
        Build(o << 1, l, mid);
        Build(o << 1 | 1, mid + 1, r);
        pushup(o);
    }
    void Add(int o, int l, int r, int L, int R, int v) {
        if (L > R) 
            return;
        if (L <= l && r <= R) {
            tag(o, v);
            return;
        }
        pushdown(o);
        int mid = (l + r) >> 1;
        if (L <= mid) 
            Add(o << 1, l, mid, L, R, v);
        if (R > mid) 
            Add(o << 1 | 1, mid + 1, r, L, R, v);
        pushup(o);
    }
    int Query(int o, int l, int r, int L, int R, int& p) {
        if (L > R) {
            p = 0;
            return inf;
        }
        if (L <= l && r <= R) {
            p = mnp[o];
            return mn[o];
        }
        pushdown(o);
        int mid = (l + r) >> 1;
        if (R <= mid) 
            return Query(o << 1, l, mid, L, R, p);
        if (L > mid) 
            return Query(o << 1 | 1, mid + 1, r, L, R, p);
        int lp, lv = Query(o << 1, l, mid, L, R, lp);
        int rp, rv = Query(o << 1 | 1, mid + 1, r, L, R, rp);
        if (lv > rv) 
            swap(lv, rv), swap(lp, rp);
        return p = lp, lv;
    }
} seg1, seg2, seg3;
void Deal(int o, int l, int r) {
    if (seg3.mn[o] != 0) 
        return;
    if (l == r) {
        seg3.mn[o] = inf;
        if (pre[l + 1]) 
            seg2.Add(1, 1, n, 1, l, c[l] + m[r]);
        return;
    }
    seg3.pushdown(o);
    int mid = (l + r) >> 1;
    if (seg3.mn[o << 1] == 0) 
        Deal(o << 1, l, mid);
    if (seg3.mn[o << 1 | 1] == 0) 
        Deal(o << 1 | 1, mid + 1, r);
    seg3.pushup(o);
}
void Deal(int o, int l, int r, int L, int R) {
    if (!L || !R || L > R) 
        return;
    if (L <= l && r <= R) {
        Deal(o, l, r);
        return;
    }
    seg3.pushdown(o);
    int mid = (l + r) >> 1;
    if (L <= mid) 
        Deal(o << 1, l, mid, L, R);
    if (R > mid) 
        Deal(o << 1 | 1, mid + 1, r, L, R);
    seg3.pushup(o);
}
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n;
    for (int i = 1; i <= n; i++) cin >> d[i];
    for (int i = 1; i <= n; i++) cin >> u[i];
    for (int i = 1; i <= n; i++) cin >> p[i];
    seg1.Build(1, 1, n), seg2.Build(1, 1, n);
    for (int i = 1; i < n; i++) cin >> m[i];
    for (int i = 1; i < n; i++) cin >> c[i], seg1.Add(1, 1, n, i + 1, n, c[i]);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        pre[i] += pre[i - 1];
        seg1.Add(1, 1, n, i, n, -c[i - 1]);
        if (pre[i]) 
            seg2.Add(1, 1, n, 1, i - 1, -c[i - 1]);
        else 
            seg2.Add(1, 1, n, 1, i - 1, m[i - 1]);
        Deal(1, 1, n - 1, i - 1, i - 1);
        while (d[i]) {
            int x;
            int pl, ml = seg2.Query(1, 1, n, 1, i - 1, pl);
            int pr, mr = seg1.Query(1, 1, n, i, n, pr);
            if (mr < ml) {
                int f = min(d[i], u[pr]);
                d[i] -= f, u[pr] -= f;
                ans += f * mr;
                pre[i]++, pre[pr + 1]--;
                seg3.Add(1, 1, n - 1, i, pr - 1, f);
                if (!u[pr]) {
                    seg1.Add(1, 1, n, pr, pr, inf);
                    seg2.Add(1, 1, n, pr, pr, inf);
                }
            } else {
                int x;
                int f = min(seg3.Query(1, 1, n - 1, pl, i - 1, x), min(u[pl], d[i]));
                d[i] -= f, u[pl] -= f;
                seg3.Add(1, 1, n - 1, pl, i - 1, -f);
                Deal(1, 1, n - 1, pl, i - 1);
                ans += f * ml;
                if (!u[pl]) {
                    seg1.Add(1, 1, n, pl, pl, inf);
                    seg2.Add(1, 1, n, pl, pl, inf);
                }
            }
        }
    }
    cout << ans << "\n";
    return 0;
}

关于复杂度

我代码里写的是 while(d[i]) { ... }。这样复杂度是正确的。每次增广我们要么流满这个点,要么流满增广路的终点,要么流满增广路上的一条反向边。注意到这些东西都至多会被流满一次,所以至多增广 O(n) 次。