
A Primer on Optimal Transport #NIPS2017


I wrote these notes in December 2017 after attending my first NeurIPS and never published them. Eight years later, I’m hitting publish with no edits — just this note as context. Reading them back is strange. The five takeaways from that week — Bayesian deep learning, fairness and bias, the need for theory, deep RL, GANs — all became defining research arcs of the decade that followed. The Rahimi & Recht “alchemy” talk, which was controversial at the time, looks prophetic now. And the throwaway concern about “generating realistic fake videos with geopolitical consequences” landed harder than I think anyone in that room expected.

The corporate circus section is its own kind of time capsule: an Intel Flo Rida concert, Nvidia handing out $3k GPUs to the audience, and an invite-only Tesla party with Elon Musk and Andrej Karpathy. A different era.

What strikes me most, though, is how optimistic and open everything felt. The field was moving fast but still felt legible — you could go to one conference and get your arms around the major themes. That’s long gone. Anyway — notes from the before-times, published from the future.


A Primer on Optimal Transport

Marco Cuturi | Justin M Solomon

What is optimal transport? The natural geometry for probability measures — it puts a distance on the space of distributions, enabling comparisons of “bags of features.” Given statistical models $p$ and $p'$, OT gives us a notion of divergence to measure how well we’re doing.

Outline: intro → algorithms → apps ($W$ as a loss) → apps ($W$ for estimation). Slides at optimaltransport.github.io.

1. The Monge Problem

Monge (1781): move a pile of dirt $\mu$ into a hole $\nu$ (with shovels).

  • $\mu(x)$: height of pile at $x$
  • $y = T(x)$: destination for point $x$ — a map from $\mu$ to $\nu$
  • $d(x, T(x))$: distance traveled
  • $\mu(x) \cdot d(x, T(x))$: work done at $x$

$T$ must satisfy the pushforward constraint $T_\#\mu = \nu$: the mass arriving in any set $B$ must equal the mass sent there, e.g. $\mu(A_1) + \mu(A_2) = \nu(B)$ when $A_1, A_2$ are the preimages mapping into $B$.

The Monge problem: find the map TT that minimizes total work:

$$\min_{T \,:\, T_\#\mu = \nu} \int d(x,\, T(x))\, d\mu(x)$$

Caveat: an optimal Monge map $T^*$ does not always exist (e.g., when $\mu$ is a Dirac mass and $\nu$ is not).

1a. The Kantorovich Relaxation

Kantorovich’s insight: instead of a deterministic map, allow measure couplings (joint distributions) $\gamma \in \Pi(\mu, \nu)$ — i.e., $\gamma(x, y) \geq 0$ with marginals $\mu$ and $\nu$. This is just a linear program.

Primal (Kantorovich):

$$\min_{\gamma \in \Pi(\mu,\nu)} \int d(x,y)\, d\gamma(x,y)$$

Dual (potential functions):

$$\max_{f,\, g} \int f\, d\mu + \int g\, d\nu \quad \text{s.t.} \quad f(x) + g(y) \leq d(x,y)$$

The dual is elegant: $f$ and $g$ are the potential functions, and the dual tells you which points are most “expensive” to transport.
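For discrete measures the primal Kantorovich problem really is just a small LP, which can be sketched with SciPy. A minimal sketch; `kantorovich_lp` and the toy two-point histograms are my own names and data, not from the tutorial:

```python
# Discrete Kantorovich OT as a linear program: min <C, gamma>
# s.t. row sums of gamma equal mu, column sums equal nu, gamma >= 0.
import numpy as np
from scipy.optimize import linprog

def kantorovich_lp(mu, nu, C):
    n, m = C.shape
    A_eq = np.zeros((n + m, n * m))
    for i in range(n):
        A_eq[i, i * m:(i + 1) * m] = 1.0   # row-sum (source marginal) constraints
    for j in range(m):
        A_eq[n + j, j::m] = 1.0            # column-sum (target marginal) constraints
    b_eq = np.concatenate([mu, nu])
    res = linprog(C.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None))
    return res.x.reshape(n, m), res.fun

# Two tiny histograms on the line with |x - y| ground cost.
x = np.array([0.0, 1.0]); y = np.array([0.0, 2.0])
mu = np.array([0.5, 0.5]); nu = np.array([0.5, 0.5])
C = np.abs(x[:, None] - y[None, :])
gamma, cost = kantorovich_lp(mu, nu, C)   # optimal plan is the diagonal coupling
```

Here the optimal plan keeps the mass at 0 in place and moves the mass at 1 to 2, for a total cost of 0.5.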

Proposition: for well-behaved cost functions, if $\mu$ has a density then an optimal Monge map $T^*$ between $\mu$ and $\nu$ exists.

1b. $p$-Wasserstein Distance

The $p$-Wasserstein distance between probability measures $\mu$ and $\nu$:

$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu,\nu)} \int d(x,y)^p\, d\gamma(x,y) \right)^{1/p}$$

This is a true metric on the space of probability measures. The geometry it induces is very different from information-theoretic metrics like KL divergence. McCann (1995) showed it gives rise to displacement interpolation — geodesics in the space of measures. (Solomon ‘15 has nice applications.)

2. How to Compute OT

Four cases:

  1. discrete → discrete
  2. discrete → continuous
  3. continuous → continuous
  4. continuous → discrete (open)

Cases 2–3 are largely “up for grabs.” Easy special cases:

  • Univariate: compute CDFs and quantile functions. $W_p$ has a closed form: $W_p(\mu,\nu)^p = \int_0^1 |F_\mu^{-1}(t) - F_\nu^{-1}(t)|^p\, dt$
  • Gaussians: closed form, $T$ is linear.
  • Dirac masses: $W_p(\delta_x, \delta_y) = d(x,y)$ — Wasserstein distance between point masses equals the ground distance.
  • Equal number of points: reduces to the Monge problem (an assignment problem).
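The univariate closed form above is easy to check: for two empirical measures with $n$ equal-mass points each, the quantile functions are just the sorted samples. A sketch, with my own function name:

```python
# 1-D Wasserstein distance via sorted samples (empirical quantile functions).
import numpy as np

def wasserstein_1d(x, y, p=2):
    """W_p between two empirical measures with n equal-mass points each."""
    xs, ys = np.sort(x), np.sort(y)  # sorting realizes the monotone optimal map
    return np.mean(np.abs(xs - ys) ** p) ** (1.0 / p)

a = np.array([0.0, 1.0, 2.0])
b = np.array([1.0, 2.0, 3.0])
d = wasserstein_1d(a, b, p=1)  # every point moves by exactly 1
```

SciPy also ships `scipy.stats.wasserstein_distance` for the $p = 1$ case.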

Complexity of the LP: $O(n^3 \log n)$ via min-cost flow. Ouch.

Entropic Regularization and Sinkhorn

Optimal solutions $P^*$ to the LP are vertices of a polytope — unstable, non-unique, and non-differentiable. We want something faster, scalable, and differentiable.

Entropic regularization (Shannon entropy $H(\gamma) = -\sum_{ij} \gamma_{ij} \log \gamma_{ij}$):

$$\min_{\gamma \in \Pi(\mu,\nu)} \langle C, \gamma \rangle - \varepsilon\, H(\gamma)$$

As $\varepsilon \to \infty$, the solution approaches the independent coupling $\mu \otimes \nu$; as $\varepsilon \to 0$, it recovers the unregularized Kantorovich solution. The regularization makes the problem strictly convex and smooth. [Wilson ‘62]

This can be solved with simple Lagrangians — leading to Sinkhorn’s algorithm: alternately rescale rows and columns of the kernel matrix $K_{ij} = e^{-C_{ij}/\varepsilon}$ to match the marginals $\mu$ and $\nu$.
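The alternating rescaling fits in a few lines of numpy. A minimal dense-matrix sketch (my own code; for very small $\varepsilon$ you would want log-domain stabilization, which this omits):

```python
# Sinkhorn iterations: alternately rescale rows and columns of K = exp(-C/eps)
# until the coupling gamma = diag(u) K diag(v) has marginals mu and nu.
import numpy as np

def sinkhorn(mu, nu, C, eps=0.1, n_iter=500):
    K = np.exp(-C / eps)
    u = np.ones_like(mu)
    for _ in range(n_iter):
        v = nu / (K.T @ u)   # match column marginals
        u = mu / (K @ v)     # match row marginals
    gamma = u[:, None] * K * v[None, :]
    return gamma, np.sum(gamma * C)

# Same toy problem as before: two-point histograms, |x - y| cost.
x = np.array([0.0, 1.0]); y = np.array([0.0, 2.0])
mu = np.array([0.5, 0.5]); nu = np.array([0.5, 0.5])
C = np.abs(x[:, None] - y[None, :])
gamma, cost = sinkhorn(mu, nu, C, eps=0.05)  # cost is close to the LP value 0.5
```

With a small $\varepsilon$ the regularized cost lands within a whisker of the exact LP optimum; larger $\varepsilon$ blurs the plan toward $\mu \otimes \nu$.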

Sinkhorn = block coordinate ascent on the dual. [Altschuler et al. ‘17]

  • Convergence: linear, at $O(nm)$ cost per iteration in general; $O(n \log n)$ on gridded spaces using convolutions.
  • Sinkhorn interpolates between $W$ (hard OT) and MMD (kernel two-sample test).

Sample complexity caveat [Hashimoto ‘16, Bonneel ‘16, Shalit ‘16]: error in $W$ decreases very slowly in $n$ — bad sample complexity. The Wasserstein LP is not well-suited for high-dimensional data directly.

3. Applications

Retrieval: [Kusner ‘15] Word Mover’s Distance — document similarity via OT over word embeddings.

Barycenters: averaging measures under $W$ vs. $L^2$ gives very different results. The Wasserstein barycenter of distributions $\{\mu_k\}$ with weights $\{\lambda_k\}$:

$$\bar{\mu} = \arg\min_{\mu} \sum_k \lambda_k W_2(\mu, \mu_k)^2$$

Averaging histograms is an LP; or use primal descent on regularized $W$ [Cuturi ‘14]. Application: brain imaging, finding smooth interpolations between distributions.
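In 1-D the $W_2$ barycenter has a closed form: its quantile function is the $\lambda_k$-weighted average of the input quantile functions, so for equal-size empirical measures it is just a weighted average of sorted samples. A sketch under those assumptions, with my own names:

```python
# 1-D W2 barycenter of equal-size empirical measures: average the sorted
# samples (i.e., the quantile functions) with the given weights.
import numpy as np

def barycenter_1d(samples, weights):
    sorted_samples = [np.sort(s) for s in samples]
    return sum(w * s for w, s in zip(weights, sorted_samples))

mu1 = np.array([0.0, 1.0])
mu2 = np.array([2.0, 3.0])
bary = barycenter_1d([mu1, mu2], [0.5, 0.5])
# the barycenter's support sits midway between the two inputs
```

Contrast with the $L^2$ average of the two densities, which would be bimodal rather than a single shifted distribution — that is exactly the “very different results” point above.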

Wasserstein Posterior (WASP): aggregate distributed posteriors using Wasserstein barycenters [Srivastava ‘15].

Wasserstein Propagation [Solomon ‘14]: semi-supervised learning on graphs — propagate label distributions via OT. Could fix label noise or handle missing data.

Dictionary learning / topic models [Rolet ‘16]: represent documents as mixtures of dictionary elements under a Wasserstein loss.

Wasserstein PCA: generalized principal geodesics in the space of measures (negative curvature space — worth investigating further).

Distributionally robust optimization [Esfahani ‘17]: learning with Wasserstein ambiguity — robust to perturbations of the training distribution (minimax formulation).

Domain adaptation:

  1. Estimate transport map $T$ from source to target domain
  2. Transport labeled source samples to target domain
  3. Train classifier on transported samples
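The three steps above can be sketched with an entropic coupling plus the barycentric projection $\hat{T}(x_i) = \sum_j \gamma_{ij}\, y_j / \mu_i$ as the estimated map. My own sketch, not the tutorial’s code:

```python
# Domain adaptation via OT: estimate a coupling between source and target
# samples, then push labeled source points into the target domain using the
# barycentric projection of the coupling.
import numpy as np

def sinkhorn_coupling(mu, nu, C, eps=0.5, n_iter=500):
    K = np.exp(-C / eps)
    u = np.ones_like(mu)
    for _ in range(n_iter):
        v = nu / (K.T @ u)
        u = mu / (K @ v)
    return u[:, None] * K * v[None, :]

xs = np.array([0.0, 1.0])             # labeled source samples (1-D for brevity)
xt = np.array([5.0, 6.0])             # unlabeled target samples
mu = np.full(2, 0.5); nu = np.full(2, 0.5)
C = (xs[:, None] - xt[None, :]) ** 2  # squared-distance ground cost
gamma = sinkhorn_coupling(mu, nu, C)
xs_mapped = (gamma @ xt) / mu         # step 2: transport the source samples
# xs_mapped now lives in the target domain; train the classifier on it (step 3).
```

The transported points land inside the target cluster while preserving their relative order, so their source labels carry over.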

Generative models: density fitting via maximum likelihood is just minimizing $\text{KL}(p_\text{data} \| p_\theta)$. Instead, use a low-dimensional latent space with pushforward $f_\theta : \mathcal{Z} \to \mathcal{X}$:

$$\min_\theta\, W(f_{\theta\,\#}\,\mu_z,\, \nu_\text{data})$$

This is the Wasserstein GAN formulation [Arjovsky et al. ‘17] — use $W$ as the loss between data and model rather than JS divergence.