题解

又写了一遍KM算法,这题刚好是把最大最小KM拼在一起写的,感觉比较有记录价值

感觉KM始终不熟啊QAQ

算法流程大抵如下,原理就是每次我们通过减少最少的匹配量达成最大匹配,所以获得的一定是最大价值

1.我们先给左部点求一个期望大小,如果是最大KM,期望大小就是最大的那条边的权值,如果是最小KM,期望大小就是最小的那条边的权值
2.然后跑二分图匹配,两个点能匹配的条件是左点\(u\)的期望值加右点\(v\)的期望值刚好是边权
3.给无法访问的点更新断层大小,如果是最小匹配,那么断层就是\(c[u][v] - (ex_l[u] + ex_r[v])\),如果是最大匹配,就是\((ex_l[u] + ex_r[v]) - c[u][v]\)
4.在没有被访问的右点里寻找最小的减少量\(d\)
5.给访问过的左点和右点,如果是最小匹配,左点加上\(d\),因为要包括进一些更大的边,右点减去\(d\),如果是最大匹配,左点减去\(d\),因为要包括进一些更小的边,右点加上\(d\)

代码

#include <bits/stdc++.h>
#define fi first
#define se second
#define pii pair<int, int>
#define pdi pair<db, int>
#define mp make_pair
#define pb push_back
#define enter putchar('\n')
#define space putchar(' ')
#define eps 1e-8
#define mo 974711
#define MAXN 205
//#define ivorysi
using namespace std;
typedef long long int64;
typedef double db;
template <class T>
void read(T &res) {
    res = 0;
    char c = getchar();
    T f = 1;
    while (c < '0' || c > '9') {
        if (c == '-')
            f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        res = res * 10 + c - '0';
        c = getchar();
    }
    res *= f;
}
template <class T>
void out(T x) {
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    if (x >= 10) {
        out(x / 10);
    }
    putchar('0' + x % 10);
}
int N;
int c[105][105];
int ex_l[105], ex_r[105], slack[105], matc[105];
bool vis_l[105], vis_r[105];
bool match(int u, int on) {
    vis_l[u] = 1;
    for (int v = 1; v <= N; ++v) {
        if (!vis_r[v]) {
            if (vis_r[v])
                continue;
            int gap;
            if (on == 0)
                gap = c[u][v] - ex_l[u] - ex_r[v];
            else
                gap = ex_l[u] + ex_r[v] - c[u][v];
            if (gap == 0) {
                vis_r[v] = 1;
                if (!matc[v] || match(matc[v], on)) {
                    matc[v] = u;
                    return true;
                }
            } else
                slack[v] = min(slack[v], gap);
        }
    }
    return false;
}
int KM(int on) {
    for (int i = 1; i <= N; ++i) {
        ex_r[i] = 0;
        if (on == 0)
            ex_l[i] = 0x7fffffff;
        else
            ex_l[i] = 0;
        for (int j = 1; j <= N; ++j) {
            if (on == 0)
                ex_l[i] = min(ex_l[i], c[i][j]);
            else
                ex_l[i] = max(ex_l[i], c[i][j]);
        }
    }
    memset(matc, 0, sizeof(matc));
    for (int i = 1; i <= N; ++i) {
        for (int j = 1; j <= N; ++j) slack[j] = 0x7fffffff;
        while (1) {
            memset(vis_l, 0, sizeof(vis_l));
            memset(vis_r, 0, sizeof(vis_r));
            if (match(i, on))
                break;
            int d = 0x7fffffff;
            for (int j = 1; j <= N; ++j) {
                if (!vis_r[j])
                    d = min(d, slack[j]);
            }
            for (int j = 1; j <= N; ++j) {
                if (on == 0) {
                    if (vis_l[j])
                        ex_l[j] += d;
                    if (vis_r[j])
                        ex_r[j] -= d;
                    else
                        slack[j] -= d;
                } else {
                    if (vis_l[j])
                        ex_l[j] -= d;
                    if (vis_r[j])
                        ex_r[j] += d;
                    else
                        slack[j] -= d;
                }
            }
        }
    }
    int res = 0;
    for (int v = 1; v <= N; ++v) {
        res += c[matc[v]][v];
    }
    return res;
}
void Solve() {
    read(N);
    for (int i = 1; i <= N; ++i) {
        for (int j = 1; j <= N; ++j) {
            read(c[i][j]);
        }
    }
    out(KM(0));
    enter;
    out(KM(1));
    enter;
}
int main() {
#ifdef ivorysi
    freopen("f1.in", "r", stdin);
#endif
    Solve();
    return 0;
}