Open
Description
DirichletCalibrator throws an error when I pass l2_list as the reg_lambda parameter, the way you did in your original experiment here (with l2_list passed as shown here).
In my case, it throws the following error:
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.6/site-packages/dirichletcal/calib/fulldirichlet.py in fit(self, X, y, X_val, y_val, *args, **kwargs)
46 reg_norm=self.reg_norm,
47 ref_row=self.ref_row)
---> 48 self.calibrator_.fit(_X, y, *args, **kwargs)
49 final_loss = log_loss(y_val, self.calibrator_.predict_proba(_X_val))
50
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.6/site-packages/dirichletcal/calib/multinomial.py in fit(self, X, y, *args, **kwargs)
95 reg_mu=self.reg_mu, ref_row=self.ref_row,
96 initializer=self.initializer,
---> 97 reg_format=self.reg_format)
98 else:
99 res = scipy.optimize.fmin_l_bfgs_b(func=_objective, fprime=_gradient,
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.6/site-packages/dirichletcal/calib/multinomial.py in _newton_update(weights_0, X, XX_T, target, k, method_, maxiter, ftol, gtol, reg_lambda, reg_mu, ref_row, initializer, reg_format)
233 L_list = [raw_np.float(_objective(weights_0, X, XX_T, target, k, method_,
234 reg_lambda, reg_mu, ref_row, initializer,
--> 235 reg_format))]
236
237 weights = weights_0.copy()
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.6/site-packages/dirichletcal/calib/multinomial.py in _objective(params, *args)
151 else:
152 reg = np.zeros((k, k+1))
--> 153 loss = loss + reg_lambda * np.sum((weights - reg)**2)
154 else:
155 weights_hat = weights - np.hstack([weights[:, :-1] * np.eye(k),
/home/ec2-user/anaconda3/envs/JupyterSystemEnv/lib/python3.6/site-packages/jax/interpreters/xla.py in _forward_method(attrname, self, fun, *args)
985
986 def _forward_method(attrname, self, fun, *args):
--> 987 return fun(getattr(self, attrname), *args)
988 _forward_to_value = partial(_forward_method, "_value")
989
TypeError: only integer scalar arrays can be converted to a scalar index
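The failing frame is line 153 of `_objective`, `loss = loss + reg_lambda * np.sum((weights - reg)**2)`, which treats `reg_lambda` as a single number. What follows is my reading of the traceback, not verified against the jax source for this version. `np` there appears to be `jax.numpy`, so `np.sum(...)` is a jax array. The array declines to multiply by a plain Python list, so Python falls back to list repetition (`list * n`). That asks the array for an integer through `__index__`, which jax forwards to the wrapped NumPy value (the `_forward_method` / `_forward_to_value` frame). NumPy then raises this TypeError because the value is a float. The sketch reproduces that chain with two hypothetical stand-ins: `ForwardingArray` (not a jax class) and `regularized_loss`, a helper that has the same form as line 153.

```python
import operator

import numpy as np


class ForwardingArray:
    """Hypothetical stand-in for the jax array seen in the traceback.

    Assumption: like that array, it multiplies with plain numbers but
    returns NotImplemented for other operand types (such as a list), and
    it forwards __index__ to a wrapped NumPy value.
    """

    def __init__(self, value):
        self._value = np.asarray(value)

    def __rmul__(self, other):
        # Handle `number * array`; defer on anything else, e.g. a list.
        if isinstance(other, (int, float)):
            return other * self._value
        return NotImplemented

    def __index__(self):
        # NumPy only lets integer 0-d arrays act as an index, so a
        # float-valued array raises the TypeError from the traceback.
        return operator.index(self._value)


def regularized_loss(loss, reg_lambda, sq_norm):
    """Hypothetical helper with the same form as line 153 of _objective."""
    return loss + reg_lambda * sq_norm


sq_norm = ForwardingArray(2.5)  # float-valued, like np.sum((weights - reg)**2)
print(regularized_loss(1.0, 0.01, sq_norm))  # a scalar reg_lambda works

try:
    # A list makes Python fall back to list repetition, which calls __index__.
    regularized_loss(1.0, [1e-3, 1e-2], sq_norm)
except TypeError as err:
    print("TypeError:", err)
```

In other words, the error message concerns indexing only because of this fallback. The underlying problem is that a list reached a place that expects one scalar.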
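If the aim is to try every value in l2_list, each fit presumably needs a single scalar reg_lambda. The original experiment likely looped over the list or used it as a search grid, though I have not confirmed this. The sketch below is minimal and library-agnostic. It assumes a hypothetical `fit_and_score` callback that wraps one DirichletCalibrator fit and returns a validation loss. `select_reg_lambda` and `toy_loss` are also illustrative names, not part of dirichletcal.

```python
import math
import numbers
from typing import Callable, Iterable, Tuple


def select_reg_lambda(
    candidates: Iterable[float],
    fit_and_score: Callable[[float], float],
) -> Tuple[float, float]:
    """Fit once per scalar candidate and return (best_lambda, best_loss).

    `fit_and_score` is a hypothetical callback: it should fit a calibrator
    with the given scalar reg_lambda and return a validation loss where
    lower is better (for example, log loss on held-out data).
    """
    lambdas = list(candidates)
    if not lambdas:
        raise ValueError("candidates must be non-empty")
    for lam in lambdas:
        # Reject lists or arrays early, with a clearer message than the traceback.
        if not isinstance(lam, numbers.Real):
            raise TypeError(f"each reg_lambda must be a real scalar, got {lam!r}")
    losses = [fit_and_score(float(lam)) for lam in lambdas]
    best = min(range(len(lambdas)), key=losses.__getitem__)
    return float(lambdas[best]), losses[best]


def toy_loss(reg_lambda: float) -> float:
    # Stand-in for a real validation loss; smallest at reg_lambda = 1e-2.
    return (math.log10(reg_lambda) + 2) ** 2


l2_list = [1e-3, 1e-2, 1e-1]  # example values only, not the experiment's list
print(select_reg_lambda(l2_list, toy_loss))
```

The calibrator may follow the scikit-learn estimator API, but that is an assumption the traceback does not show. If it does, `GridSearchCV(calibrator, param_grid={'reg_lambda': l2_list}, scoring='neg_log_loss')` would express the same loop with cross-validation, again passing one scalar per fit.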