-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscikitLL.py
More file actions
31 lines (23 loc) · 1.02 KB
/
scikitLL.py
File metadata and controls
31 lines (23 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
__author__ = "Varun Nayyar <nayyarv@gmail.com>"
from .base import LikelihoodEvaluator
import numpy as np
# Import at top level so that import errors surface earlier rather than later.
from sklearn.mixture.gaussian_mixture import GaussianMixture, _compute_precision_cholesky
class ScikitLL(LikelihoodEvaluator):
    """
    Fastest Single Core Version so far!

    Evaluates the log-likelihood of a diagonal-covariance Gaussian mixture
    over a fixed data set by delegating to scikit-learn's ``GaussianMixture``.
    The estimator is fitted once at construction; every call to
    ``loglikelihood`` then overwrites its fitted parameters with the
    caller's values before scoring the stored points.
    """

    def __init__(self, Xpoints, numMixtures):
        """
        Build the evaluator and fit the underlying estimator on ``Xpoints``.

        Args:
            Xpoints: data points to score; passed unchanged to
                ``GaussianMixture.fit`` / ``GaussianMixture.score``
                (presumably an ``(n_samples, n_features)`` array -- confirm
                against the base class).
            numMixtures: number of mixture components.
        """
        super().__init__(Xpoints, numMixtures)
        # Keyword arguments: newer scikit-learn releases make every parameter
        # after ``n_components`` keyword-only, so the positional form
        # ``GaussianMixture(numMixtures, 'diag')`` raises TypeError there.
        self.evaluator = GaussianMixture(
            n_components=numMixtures, covariance_type='diag'
        )
        self.Xpoints = Xpoints
        # The fitted values themselves are discarded: loglikelihood()
        # replaces them on every call. Presumably the fit is only here to put
        # the estimator into a "fitted" state so score() can be called.
        self.evaluator.fit(Xpoints)

    def __str__(self):
        """Return a short human-readable name for this evaluator."""
        return "SciKit's learn implementation Implementation"

    def loglikelihood(self, means, diagCovs, weights):
        """
        Return the log-likelihood of the stored points under the given GMM.

        Args:
            means: component means, assigned to ``evaluator.means_``.
            diagCovs: diagonal covariances, assigned to
                ``evaluator.covariances_`` and used to derive
                ``precisions_cholesky_``.
            weights: mixture weights, assigned to ``evaluator.weights_``.

        Returns:
            ``self.numPoints`` times the estimator's ``score`` over
            ``self.Xpoints``. ``score`` is documented by scikit-learn as the
            per-sample average log-likelihood, so the product is the total
            over the data set (assuming ``numPoints`` -- not set in this
            class, presumably by the base class -- equals the number of
            stored points).
        """
        gmm = self.evaluator
        gmm.weights_ = weights
        gmm.covariances_ = diagCovs
        gmm.means_ = means
        # score() reads precisions_cholesky_ rather than covariances_, so it
        # has to be recomputed whenever the covariances are replaced.
        gmm.precisions_cholesky_ = _compute_precision_cholesky(diagCovs, "diag")
        return self.numPoints * np.sum(gmm.score(self.Xpoints))