-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fixed description * Updated weighted matching * Added weighted matching stress tests * updated desc
- Loading branch information
Showing
2 changed files
with
86 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,47 @@ | ||
/** | ||
* Author: Stanford | ||
* Date: Unknown | ||
* Source: Stanford Notebook | ||
* Description: Min cost bipartite matching. Negate costs for max cost. | ||
* Time: O(N^3) | ||
* Status: tested during ICPC 2015 | ||
* Author: Benjamin Qi, Chilli | ||
* Date: 2020-04-04 | ||
* License: CC0 | ||
* Source: https://github.com/bqi343/USACO/blob/master/Implementations/content/graphs%20(12)/Matching/Hungarian.h | ||
* Description: Given array of (possibly negative) costs to complete $N$ jobs | ||
* w/ $M$ workers $(N \le M)$, finds min cost to complete all jobs s.t. each | ||
* worker is assigned to at most one job. Takes cost[N][M], where cost[i][j] = | ||
* cost for i-th job to be completed by j-th worker and returns (min cost, | ||
* match), where match[i] = worker assigned to i-th job. Negate costs for max | ||
* cost. | ||
* Time: O(N^2M) | ||
* Status: Tested on kattis:cordonbleu, stress-tested | ||
*/ | ||
#pragma once | ||
|
||
typedef vector<double> vd; | ||
bool zero(double x) { return fabs(x) < 1e-10; } | ||
double minCostMatching(const vector<vd>& cost, vi& L, vi& R) { | ||
int n = sz(cost), mated = 0; | ||
vd dist(n), u(n), v(n); | ||
vi dad(n), seen(n); | ||
|
||
/// construct dual feasible solution | ||
rep(i,0,n) { | ||
u[i] = cost[i][0]; | ||
rep(j,1,n) u[i] = min(u[i], cost[i][j]); | ||
} | ||
rep(j,0,n) { | ||
v[j] = cost[0][j] - u[0]; | ||
rep(i,1,n) v[j] = min(v[j], cost[i][j] - u[i]); | ||
} | ||
|
||
/// find primal solution satisfying complementary slackness | ||
L = R = vi(n, -1); | ||
rep(i,0,n) rep(j,0,n) { | ||
if (R[j] != -1) continue; | ||
if (zero(cost[i][j] - u[i] - v[j])) { | ||
L[i] = j; | ||
R[j] = i; | ||
mated++; | ||
break; | ||
} | ||
} | ||
|
||
for (; mated < n; mated++) { // until solution is feasible | ||
int s = 0; | ||
while (L[s] != -1) s++; | ||
fill(all(dad), -1); | ||
fill(all(seen), 0); | ||
rep(k,0,n) | ||
dist[k] = cost[s][k] - u[s] - v[k]; | ||
|
||
int j = 0; | ||
for (;;) { /// find closest | ||
j = -1; | ||
rep(k,0,n){ | ||
if (seen[k]) continue; | ||
if (j == -1 || dist[k] < dist[j]) j = k; | ||
pair<int, vi> hungarian(const vector<vi> &a) { | ||
if (a.empty()) return {0, {}}; | ||
int n = sz(a) + 1, m = sz(a[0]) + 1; | ||
vi u(n), v(m), p(m), ans(n - 1); | ||
rep(i,1,n) { | ||
p[0] = i; | ||
int j0 = 0; // add "dummy" worker 0 | ||
vi dist(m, INT_MAX), pre(m, -1); | ||
vector<bool> done(m + 1); | ||
do { // dijkstra | ||
done[j0] = true; | ||
int i0 = p[j0], j1, delta = INT_MAX; | ||
rep(j,1,m) if (!done[j]) { | ||
auto cur = a[i0 - 1][j - 1] - u[i0] - v[j]; | ||
if (cur < dist[j]) dist[j] = cur, pre[j] = j0; | ||
if (dist[j] < delta) delta = dist[j], j1 = j; | ||
} | ||
seen[j] = 1; | ||
int i = R[j]; | ||
if (i == -1) break; | ||
rep(k,0,n) { /// relax neighbors | ||
if (seen[k]) continue; | ||
auto new_dist = dist[j] + cost[i][k] - u[i] - v[k]; | ||
if (dist[k] > new_dist) { | ||
dist[k] = new_dist; | ||
dad[k] = j; | ||
} | ||
rep(j,0,m) { | ||
if (done[j]) u[p[j]] += delta, v[j] -= delta; | ||
else dist[j] -= delta; | ||
} | ||
j0 = j1; | ||
} while (p[j0]); | ||
while (j0) { // update alternating path | ||
int j1 = pre[j0]; | ||
p[j0] = p[j1], j0 = j1; | ||
} | ||
|
||
/// update dual variables | ||
rep(k,0,n) { | ||
if (k == j || !seen[k]) continue; | ||
auto w = dist[k] - dist[j]; | ||
v[k] += w, u[R[k]] -= w; | ||
} | ||
u[s] += dist[j]; | ||
|
||
/// augment along path | ||
while (dad[j] >= 0) { | ||
int d = dad[j]; | ||
R[j] = R[d]; | ||
L[R[j]] = j; | ||
j = d; | ||
} | ||
R[j] = s; | ||
L[s] = j; | ||
} | ||
auto value = vd(1)[0]; | ||
rep(i,0,n) value += cost[i][L[i]]; | ||
return value; | ||
rep(j,1,m) if (p[j]) ans[p[j] - 1] = j - 1; | ||
return {-v[0], ans}; // min cost | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#include "../utilities/template.h" | ||
#include "../utilities/utils.h" | ||
#include "../utilities/random.h" | ||
|
||
#include "../../content/graph/WeightedMatching.h" | ||
#include <bits/extc++.h> /// include-line, keep-include | ||
#include "../../content/graph/MinCostMaxFlow.h" | ||
|
||
void test(int N, int mxCost, int iters) { | ||
for (int it = 0; it < iters; it++) { | ||
int n = randRange(1, N), m = randRange(1, N); | ||
if (n > m) | ||
swap(n, m); | ||
|
||
MCMF mcmf(n + m + 2); | ||
int s = 0; | ||
int t = 1; | ||
for (int i = 0; i < n; i++) | ||
mcmf.addEdge(s, i + 2, 1, 0); | ||
for (int i = 0; i < m; i++) | ||
mcmf.addEdge(2 + n + i, t, 1, 0); | ||
|
||
vector<vi> cost(n, vi(m)); | ||
for (int i = 0; i < n; i++) { | ||
for (int j = 0; j < m; j++) { | ||
cost[i][j] = randRange(-mxCost, mxCost); | ||
mcmf.addEdge(i + 2, 2 + n + j, 1, cost[i][j]); | ||
} | ||
} | ||
mcmf.setpi(s); | ||
auto maxflow = mcmf.maxflow(s, t); | ||
auto matching = hungarian(cost); | ||
assert(maxflow.first == n); | ||
assert(maxflow.second == matching.first); | ||
int matchSum = 0; | ||
for (int i = 0; i < n; i++) | ||
matchSum += cost[i][matching.second[i]]; | ||
assert(matchSum == matching.first); | ||
return; | ||
} | ||
} | ||
signed main() { | ||
test(25, 5, 1000); | ||
test(100, 1000, 100); | ||
test(100, 1, 50); | ||
test(5, 5, 10000); | ||
cout << "Tests passed!" << endl; | ||
} |