Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions dirichletcal/calib/fulldirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,27 @@


class FullDirichletCalibrator(BaseEstimator, RegressorMixin):
def __init__(self, reg_lambda=0.0, reg_mu=None, weights_init=None,
def __init__(self, reg_lambda_list=[0.0], reg_mu_list=[None], weights_init=None,
initializer='identity', reg_norm=False, ref_row=True):

"""
Params:
weights_init: (nd.array) weights used for initialisation, if None then idendity matrix used. Shape = (n_classes - 1, n_classes + 1)
comp_l2: (bool) If true, then complementary L2 regularization used (off-diagonal regularization)
"""
self.reg_lambda = reg_lambda
self.reg_mu = reg_mu # Complementary L2 regularization. (Off-diagonal regularization)
self.reg_lambda_list = reg_lambda_list
self.reg_mu_list = reg_mu_list # Complementary L2 regularization. (Off-diagonal regularization)
self.weights_init = weights_init # Input weights for initialisation
self.initializer = initializer
self.reg_norm = reg_norm
self.ref_row = ref_row

def __setup(self):
self.reg_lambda = 0.0
self.reg_mu = None
self.calibrator_ = None
self.weights_ = self.weights_init

def fit(self, X, y, X_val=None, y_val=None, *args, **kwargs):

self.weights_ = self.weights_init
Expand All @@ -40,13 +46,31 @@ def fit(self, X, y, X_val=None, y_val=None, *args, **kwargs):
_X_val = np.copy(X_val)
_X_val = np.log(clip_for_log(X_val))

self.calibrator_ = MultinomialRegression(method='Full',
reg_lambda=self.reg_lambda,
reg_mu=self.reg_mu,
reg_norm=self.reg_norm,
ref_row=self.ref_row)
self.calibrator_.fit(_X, y, *args, **kwargs)
final_loss = log_loss(y_val, self.calibrator_.predict_proba(_X_val))
for i in range(0, len(self.reg_lambda_list)):
for j in range(0, len(self.reg_mu_list)):
tmp_cal = MultinomialRegression(method='Full',
reg_lambda=self.reg_lambda_list[i],
reg_mu=self.reg_mu_list[j],
reg_norm=self.reg_norm,
ref_row=self.ref_row)
tmp_cal.fit(_X, y, *args, **kwargs)
tmp_loss = log_loss(y_val, tmp_cal.predict_proba(_X_val))

if (i + j) == 0:
final_cal = tmp_cal
final_loss = tmp_loss
final_reg_lambda = self.reg_lambda_list[i]
final_reg_mu = self.reg_mu_list[j]
elif tmp_loss < final_loss:
final_cal = tmp_cal
final_loss = tmp_loss
final_reg_lambda = self.reg_lambda_list[i]
final_reg_mu = self.reg_mu_list[j]

self.calibrator_ = final_cal
self.reg_lambda = final_reg_lambda
self.reg_mu = final_reg_mu
self.weights_ = self.calibrator_.weights_

return self

Expand All @@ -65,9 +89,9 @@ def intercept_(self):
return self.calibrator_.intercept_

def predict_proba(self, S):
S = np.log(clip_for_log(S))
return np.asarray(self.calibrator_.predict_proba(S))
_S = np.log(clip_for_log(np.copy(S)))
return np.asarray(self.calibrator_.predict_proba(_S))

def predict(self, S):
S = np.log(clip_for_log(S))
return np.asarray(self.calibrator_.predict(S))
_S = np.log(clip_for_log(np.copy(S)))
return np.asarray(self.calibrator_.predict(_S))
2 changes: 1 addition & 1 deletion dirichletcal/calib/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def fit(self, X, y, *args, **kwargs):
self.reg_lambda = self.reg_lambda / (k * (k - 1))
self.reg_mu = self.reg_mu / k

target = label_binarize(y, self.classes)
target = label_binarize(y, classes=self.classes)

if k == 2:
target = np.hstack([1-target, target])
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
scipy
scikit-learn
numpy>=1.14.2
scipy>=1.0.0
scikit-learn>=0.19.1
jax
jaxlib
autograd
12 changes: 4 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

with open("README.md", 'r') as f:
long_description = f.read()

with open("requirements.txt") as fh:
requirements = fh.read().splitlines()

main_ns = {}
ver_path = convert_path('dirichletcal/version.py')
Expand All @@ -28,12 +31,5 @@
"Operating System :: OS Independent",
],
python_requires='>=3.6',
install_requires = [
'numpy>=1.14.2'
'scipy>=1.0.0'
'scikit-learn>=0.19.1'
'jax'
'jaxlib'
'autograd'
]
install_requires=requirements,
)