diff --git a/sky_area/sky_area_clustering.py b/sky_area/sky_area_clustering.py index cc79f16..8148cfc 100644 --- a/sky_area/sky_area_clustering.py +++ b/sky_area/sky_area_clustering.py @@ -6,8 +6,8 @@ from astropy.utils.misc import NumpyRNGContext import healpy as hp import numpy as np -import numpy.linalg as nl from scipy.stats import gaussian_kde +from scipy.cluster.vq import kmeans, vq from lalinference.bayestar import distance, moc from functools import partial from six.moves import copyreg @@ -15,98 +15,29 @@ __all__ = ('Clustered2DSkyKDE', 'Clustered3DSkyKDE', 'Clustered2Plus1DSkyKDE') -def km_assign(mus, cov, pts): - """Implements the assignment step in the k-means algorithm. Given a - set of centers, ``mus``, a covariance matrix used to produce a - metric on the space, ``cov``, and a set of points, ``pts`` (shape - ``(npts, ndim)``), assigns each point to its nearest center, - returning an array of indices of shape ``(npts,)`` giving the - assignments. - - """ - k = mus.shape[0] - n = pts.shape[0] - - dists = np.zeros((k, n)) - - for i, mu in enumerate(mus): - dx = pts - mu - try: - dists[i, :] = np.sum(dx * nl.solve(cov, dx.T).T, axis=1) - except nl.LinAlgError: - dists[i, :] = np.nan - - return np.nanargmin(dists, axis=0) - - -def km_centroids(pts, assign, k): - """Implements the centroid-update step of the k-means algorithm. - Given a set of points, ``pts``, of shape ``(npts, ndim)``, and an - assignment of each point to a region, ``assign``, and the number - of means, ``k``, returns an array of shape ``(k, ndim)`` giving - the centroid of each region. - - """ - - mus = np.zeros((k, pts.shape[1])) - for i in range(k): - sel = assign == i - if np.sum(sel) > 0: - mus[i, :] = np.mean(pts[sel, :], axis=0) - else: - mus[i, :] = pts[np.random.randint(pts.shape[0]), :] - - return mus - - -def k_means(pts, k): - """Implements k-means clustering on the set of points. 
def whiten(pts):
    """Linearly transform points so their sample covariance is the identity.

    The points are multiplied by the inverse symmetric square root of
    their sample covariance matrix, so that plain Euclidean distance
    between whitened points equals the covariance-weighted distance
    between the original points.  The mean is not subtracted.

    :param pts: Array of shape ``(npts, ndim)``, one point per row.

    :return: Array of the same shape as ``pts`` holding the whitened
      points.

    """
    # np.cov collapses single-column input to a 0-d array, which eigh
    # rejects; promote it back to a (1, 1) matrix.
    cov = np.atleast_2d(np.cov(pts, rowvar=False))
    eigvals, eigvecs = np.linalg.eigh(cov)
    # cov**-0.5 == V diag(w**-0.5) V^T; dividing the columns of V by
    # sqrt(w) avoids building the diagonal matrix explicitly.  The
    # result is symmetric, so no transpose is needed when applying it.
    inv_sqrt_cov = np.dot(eigvecs / np.sqrt(eigvals), eigvecs.T)
    return np.dot(pts, inv_sqrt_cov)


def _cluster(cls, pts, whitened_pts, trials, i, seed):
    """Run one clustering trial and score the resulting model.

    :param cls: Class called as ``cls(pts, assign=assign)``; the
      resulting object must expose ``bic`` and ``kdes`` attributes.

    :param pts: Array of shape ``(npts, ndim)`` giving the original
      points, which are handed to ``cls``.

    :param whitened_pts: The same points after whitening; k-means is
      run on these so that cluster membership is decided in the metric
      induced by the sample covariance.

    :param trials: Number of trials per cluster count.  The number of
      clusters for this call is ``i // trials``.

    :param i: Index of this trial.

    :param seed: Base random seed; the trial runs under a NumPy RNG
      context seeded with ``i + seed``.

    :return: ``(bic, k, kdes)`` on success, or the 1-tuple
      ``(-inf,)`` if constructing ``cls`` raises
      ``numpy.linalg.LinAlgError``.

    :raises ValueError: if ``i // trials`` is zero.

    """
    k = i // trials
    if k == 0:
        raise ValueError('Expected at least one cluster')
    elif k == 1:
        # A single cluster needs no k-means: everything is in region 0.
        assign = np.zeros(len(pts), dtype=np.intp)
    else:
        with NumpyRNGContext(i + seed):
            # Cluster in whitened coordinates; clustering the raw points
            # would ignore the covariance metric entirely.
            means, _ = kmeans(whitened_pts, k, iter=1, thresh=0)
            assign, _ = vq(whitened_pts, means)
    try:
        obj = cls(pts, assign=assign)
    except np.linalg.LinAlgError:
        return -np.inf,
    return obj.bic, k, obj.kdes
seed = np.random.randint(0, 2**32 - max_k * trials) - func = partial(_cluster, type(self), pts, trials, seed=seed) + # Pre-whiten the points + whitened_pts = whiten(pts) + func = partial( + _cluster, type(self), pts, whitened_pts, trials, seed=seed) self.bic, self.k, self.kdes = max( self._map(func, range(trials, (max_k + 1) * trials))) else: