|
17 | 17 | from .bregman import sinkhorn, jcpot_barycenter |
18 | 18 | from .lp import emd |
19 | 19 | from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots |
20 | | -from .utils import list_to_array, check_params, BaseEstimator |
| 20 | +from .utils import list_to_array, check_params, BaseEstimator, deprecated |
21 | 21 | from .unbalanced import sinkhorn_unbalanced |
| 22 | +from .gaussian import empirical_bures_wasserstein_mapping |
22 | 23 | from .optim import cg |
23 | 24 | from .optim import gcg |
24 | 25 |
|
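For context, the newly imported `deprecated` is POT's deprecation helper from `ot.utils`. A minimal sketch of the pattern it implements (a simplified illustration, not POT's exact implementation, which also handles classes and docstring rewriting):

import functools
import warnings


def deprecated(func, extra=""):
    """Wrap a callable so that calling it emits a DeprecationWarning."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            "%s is deprecated. %s" % (func.__name__, extra),
            category=DeprecationWarning,
            stacklevel=2,
        )
        return func(*args, **kwargs)
    return wrapper
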
@@ -679,112 +680,7 @@ def df(G): |
679 | 680 | return G, L |
680 | 681 |
|
681 | 682 |
|
682 | | -def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, |
683 | | - wt=None, bias=True, log=False): |
684 | | - r"""Return OT linear operator between samples. |
685 | | -
|
686 | | - The function estimates the optimal linear operator that aligns the two |
687 | | - empirical distributions. This is equivalent to estimating the closed |
688 | | - form mapping between two Gaussian distributions :math:`\mathcal{N}(\mu_s,\Sigma_s)` |
689 | | - and :math:`\mathcal{N}(\mu_t,\Sigma_t)` as proposed in |
690 | | - :ref:`[14] <references-OT-mapping-linear>` and discussed in remark 2.29 in |
691 | | - :ref:`[15] <references-OT-mapping-linear>`. |
692 | | -
|
693 | | - The linear operator from source to target :math:`M` |
694 | | -
|
695 | | - .. math:: |
696 | | - M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b} |
697 | | -
|
698 | | - where : |
699 | | -
|
700 | | - .. math:: |
701 | | - \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2} |
702 | | - \Sigma_s^{-1/2} |
703 | | -
|
704 | | - \mathbf{b} &= \mu_t - \mathbf{A} \mu_s |
705 | | -
|
706 | | - Parameters |
707 | | - ---------- |
708 | | - xs : array-like (ns,d) |
709 | | - samples in the source domain |
710 | | - xt : array-like (nt,d) |
711 | | - samples in the target domain |
712 | | - reg : float, optional
713 | | - regularization added to the diagonals of covariances (>0) |
714 | | - ws : array-like (ns,1), optional |
715 | | - weights for the source samples |
716 | | - wt : array-like (nt,1), optional
717 | | - weights for the target samples |
718 | | - bias: boolean, optional |
719 | | - estimate bias :math:`\mathbf{b}` if True, else set :math:`\mathbf{b} = 0` (default: True)
720 | | - log : bool, optional |
721 | | - record log if True |
722 | | -
|
723 | | -
|
724 | | - Returns |
725 | | - ------- |
726 | | - A : (d, d) array-like |
727 | | - Linear operator |
728 | | - b : (1, d) array-like |
729 | | - bias |
730 | | - log : dict |
731 | | - log dictionary, returned only if log==True in parameters
732 | | -
|
733 | | -
|
734 | | - .. _references-OT-mapping-linear: |
735 | | - References |
736 | | - ---------- |
737 | | - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of |
738 | | - distributions", Journal of Optimization Theory and Applications |
739 | | - Vol 43, 1984 |
740 | | -
|
741 | | - .. [15] Peyré, G., & Cuturi, M. (2019). "Computational Optimal
742 | | - Transport", Foundations and Trends in Machine Learning, 11(5-6).
743 | | -
|
744 | | -
|
745 | | - """ |
746 | | - xs, xt = list_to_array(xs, xt) |
747 | | - nx = get_backend(xs, xt) |
748 | | - |
749 | | - d = xs.shape[1] |
750 | | - |
751 | | - if bias: |
752 | | - mxs = nx.mean(xs, axis=0)[None, :] |
753 | | - mxt = nx.mean(xt, axis=0)[None, :] |
754 | | - |
755 | | - xs = xs - mxs |
756 | | - xt = xt - mxt |
757 | | - else: |
758 | | - mxs = nx.zeros((1, d), type_as=xs) |
759 | | - mxt = nx.zeros((1, d), type_as=xs) |
760 | | - |
761 | | - if ws is None: |
762 | | - ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] |
763 | | - |
764 | | - if wt is None: |
765 | | - wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] |
766 | | - |
767 | | - Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) |
768 | | - Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) |
769 | | - |
770 | | - Cs12 = nx.sqrtm(Cs) |
771 | | - Cs_12 = nx.inv(Cs12) |
772 | | - |
773 | | - M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) |
774 | | - |
775 | | - A = dots(Cs_12, M0, Cs_12) |
776 | | - |
777 | | - b = mxt - nx.dot(mxs, A) |
778 | | - |
779 | | - if log: |
780 | | - log = {} |
781 | | - log['Cs'] = Cs |
782 | | - log['Ct'] = Ct |
783 | | - log['Cs12'] = Cs12 |
784 | | - log['Cs_12'] = Cs_12 |
785 | | - return A, b, log |
786 | | - else: |
787 | | - return A, b |
| 683 | +OT_mapping_linear = deprecated(empirical_bures_wasserstein_mapping) |
788 | 684 |
|
789 | 685 |
|
790 | 686 | def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, |
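The removed closed-form estimator now lives in `ot.gaussian` as `empirical_bures_wasserstein_mapping`, with the same `reg`/`ws`/`wt`/`bias`/`log` parameters, and the old name remains as a deprecated alias. A minimal usage sketch (the toy data below is illustrative, not from this commit):

import numpy as np
import ot

rng = np.random.RandomState(0)
xs = rng.randn(200, 2)                                   # source samples
L = np.array([[2.0, 0.0], [0.0, 0.5]])                   # ground-truth linear map
xt = xs.dot(L) + np.array([4.0, 4.0])                    # target = affine image of source

# Closed-form Gaussian (Bures-Wasserstein) mapping, at its new location
A, b = ot.gaussian.empirical_bures_wasserstein_mapping(xs, xt, reg=1e-6, bias=True)

xs_mapped = xs.dot(A) + b                                # apply x -> Ax + b row-wise
print(np.allclose(xs_mapped, xt, atol=1e-2))             # True up to the small reg term

# The old name still works but now emits a DeprecationWarning
A_old, b_old = ot.da.OT_mapping_linear(xs, xt)
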
@@ -1378,10 +1274,10 @@ class label |
1378 | 1274 | self.mu_t = self.distribution_estimation(Xt) |
1379 | 1275 |
|
1380 | 1276 | # coupling estimation |
1381 | | - returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, |
1382 | | - ws=nx.reshape(self.mu_s, (-1, 1)), |
1383 | | - wt=nx.reshape(self.mu_t, (-1, 1)), |
1384 | | - bias=self.bias, log=self.log) |
| 1277 | + returned_ = empirical_bures_wasserstein_mapping(Xs, Xt, reg=self.reg, |
| 1278 | + ws=nx.reshape(self.mu_s, (-1, 1)), |
| 1279 | + wt=nx.reshape(self.mu_t, (-1, 1)), |
| 1280 | + bias=self.bias, log=self.log) |
1385 | 1281 |
|
1386 | 1282 | # deal with the value of log |
1387 | 1283 | if self.log: |
|
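The `fit` shown above belongs to POT's linear mapping estimator (presumably `ot.da.LinearTransport`; the class name is not visible in this excerpt). A hedged usage sketch of the estimator whose internals changed here, with made-up sample data:

import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.randn(100, 2)                                    # source domain samples
Xt = Xs.dot(np.array([[1.5, 0.0], [0.0, 0.7]])) + 2.0     # scaled/shifted target domain

# Assuming the class above is ot.da.LinearTransport: fit() now delegates to
# empirical_bures_wasserstein_mapping, so the fitted map is unchanged.
mapping = ot.da.LinearTransport(reg=1e-6, bias=True)
mapping.fit(Xs=Xs, Xt=Xt)

Xs_transported = mapping.transform(Xs=Xs)                 # apply the learned affine map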