Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 18 additions & 84 deletions sky_area/sky_area_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,107 +6,38 @@
from astropy.utils.misc import NumpyRNGContext
import healpy as hp
import numpy as np
import numpy.linalg as nl
from scipy.stats import gaussian_kde
from scipy.cluster.vq import kmeans, vq
from lalinference.bayestar import distance, moc
from functools import partial
from six.moves import copyreg

__all__ = ('Clustered2DSkyKDE', 'Clustered3DSkyKDE', 'Clustered2Plus1DSkyKDE')


def km_assign(mus, cov, pts):
    """Implements the assignment step in the k-means algorithm.

    Each point is assigned to the center that is nearest under the
    metric induced by ``cov``, i.e. the center ``mu`` that minimises
    ``(pt - mu)^T cov^-1 (pt - mu)``.

    :param mus: Array of shape ``(k, ndim)`` giving the current centers.

    :param cov: Array of shape ``(ndim, ndim)``; covariance matrix used
      to produce a metric on the space.

    :param pts: Array of shape ``(npts, ndim)`` giving the points.

    :return: Array of shape ``(npts,)`` holding, for each point, the
      index into ``mus`` of its nearest center.

    :raises numpy.linalg.LinAlgError: If ``cov`` is singular.

    """
    k = mus.shape[0]
    n = pts.shape[0]

    dists = np.empty((k, n))

    for i, mu in enumerate(mus):
        dx = pts - mu
        # ``cov`` is identical on every iteration, so a singular matrix
        # fails for all centers at once.  Masking the failure with NaNs
        # would only defer it to an "All-NaN slice" ValueError from
        # ``nanargmin``; let the LinAlgError propagate so that callers
        # handling degenerate covariances can actually catch it.
        dists[i, :] = np.sum(dx * nl.solve(cov, dx.T).T, axis=1)

    # ``nanargmin`` ignores NaN distances that stem from NaN coordinates.
    return np.nanargmin(dists, axis=0)


def km_centroids(pts, assign, k):
    """Implements the centroid-update step of the k-means algorithm.

    :param pts: Array of shape ``(npts, ndim)`` giving the points.

    :param assign: Array of shape ``(npts,)`` giving the region index
      assigned to each point.

    :param k: Number of means.

    :return: Array of shape ``(k, ndim)`` giving the centroid of each
      region.  A region with no members is re-seeded with a randomly
      chosen point from ``pts`` instead of a mean.

    """
    centroids = np.zeros((k, pts.shape[1]))

    for label in range(k):
        members = pts[assign == label, :]
        if len(members) > 0:
            centroids[label, :] = members.mean(axis=0)
        else:
            # Empty region: restart it from a random input point.
            pick = np.random.randint(pts.shape[0])
            centroids[label, :] = pts[pick, :]

    return centroids


def whiten(pts):
    """Linearly transform the points so their sample covariance is the
    identity.

    :param pts: Array of shape ``(npts, ndim)``.

    :return: Array of shape ``(npts, ndim)`` equal to ``pts`` multiplied
      by the symmetric inverse square root of its covariance matrix.

    """
    cov = np.cov(pts, rowvar=0)
    # Symmetric inverse square root from the eigendecomposition
    # cov = V diag(W) V^T  ==>  cov^(-1/2) = V diag(W^(-1/2)) V^T.
    W, V = np.linalg.eigh(cov)
    inv_sqrt_cov = np.dot(V * W**-0.5, V.T)
    return np.dot(pts, inv_sqrt_cov.T)


def k_means(pts, k):
    """Implements k-means clustering on the set of points.

    :param pts: Array of shape ``(npts, ndim)`` giving the points on
      which k-means is to operate.

    :param k: Positive integer giving the number of regions.

    :return: ``(centroids, assign)``, where ``centroids`` is an ``(k,
      ndim)`` array giving the centroid of each region, and ``assign``
      is a ``(npts,)`` array of integers between 0 (inclusive) and k
      (exclusive) indicating the assignment of each point to a region.

    """
    assert pts.shape[0] > k, 'must have more points than means'

    # The sample covariance defines the metric used by ``km_assign``;
    # it has to be bound in this scope because the assignment step
    # below reads it.
    cov = np.cov(pts, rowvar=0)

    # Start from k distinct input points chosen at random.
    mus = np.random.permutation(pts)[:k, :]
    assign = km_assign(mus, cov, pts)
    while True:
        old_assign = assign

        mus = km_centroids(pts, assign, k)
        assign = km_assign(mus, cov, pts)

        # Converged once an update leaves every assignment unchanged.
        if np.all(assign == old_assign):
            break

    return mus, assign


# NOTE(review): this listing is a rendered diff whose indentation and +/-
# markers were stripped, so two revisions of ``_cluster`` are interleaved
# below.  The text does not say which lines belong to which revision --
# confirm against the repository before relying on either variant.
#
# Purpose (common to both variants): evaluate one clustering trial.  ``i`` is
# a flattened trial index and ``k = i // trials`` is the number of clusters to
# try; the result is ``(obj.bic, k, obj.kdes)`` taken from a ``cls`` instance
# built with the computed per-point assignment.
def _cluster(cls, pts, trials, i, seed):
# Second signature: additionally accepts ``whitened_pts``.  The visible body
# never reads that argument (``kmeans`` and ``vq`` are both given ``pts``) --
# TODO confirm whether the whitened points were meant to be clustered.
def _cluster(cls, pts, whitened_pts, trials, i, seed):
k = i // trials
# k == 0 can only happen for i < trials; the visible caller starts its range
# at ``trials``.
if k == 0:
raise ValueError('Expected at least one cluster')
elif k == 1:
# A single cluster needs no k-means run: every point gets label 0.
assign = np.zeros(len(pts), dtype=np.intp)
else:
# Variant based on scipy.cluster.vq: build a codebook of k means, then
# label every point with its nearest codebook entry.  The RNG context is
# keyed on ``i + seed`` so that each trial index gets its own seed.
with NumpyRNGContext(i + seed):
means, _ = kmeans(pts, k, iter=1, thresh=0)
assign, _ = vq(pts, means)
# Variant based on this module's own ``k_means``.  A ``LinAlgError`` raised
# while clustering or while constructing ``cls`` is reported as the 1-tuple
# ``(-inf,)`` instead of propagating.
try:
if k == 1:
assign = np.zeros(len(pts), dtype=np.intp)
else:
with NumpyRNGContext(i + seed):
_, assign = k_means(pts, k)
obj = cls(pts, assign=assign)
except np.linalg.LinAlgError:
return -np.inf,
else:
return obj.bic, k, obj.kdes
# Duplicate of the return above; presumably the same statement from the other
# revision of the diff.
return obj.bic, k, obj.kdes


class _mapfunc(object):
Expand All @@ -133,7 +64,10 @@ def __init__(self, pts, max_k=40, trials=5, assign=None,
# The seed must be an unsigned 32-bit integer, so if there are n
# threads, then s must be drawn from the interval [0, 2**32 - n).
seed = np.random.randint(0, 2**32 - max_k * trials)
func = partial(_cluster, type(self), pts, trials, seed=seed)
# Pre-whiten the points
whitened_pts = whiten(pts)
func = partial(
_cluster, type(self), pts, whitened_pts, trials, seed=seed)
self.bic, self.k, self.kdes = max(
self._map(func, range(trials, (max_k + 1) * trials)))
else:
Expand Down