Skip to content

Commit c68dd02

Browse files
fix KDE code with scores generation
1 parent 84c250f commit c68dd02

2 files changed

Lines changed: 6 additions & 5 deletions

File tree

docs/source/modules/using_aggregative.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ Examples
4141
# Create a classifier
4242
clf = LogisticRegression()
4343
44-
ac = AC()
44+
ac = AC(clf)
4545
4646
# used when you just have the sample to predict
4747
ac.fit(X_train, y_train)
4848
prevalence = ac.predict(X_test)
4949
5050
# used when you already have the sample predictions (usually from cross-validation for training predictions)
51-
prevalence = ac.aggregate(posteriors, train_posteriors, y_train)
51+
ac2 = AC()
52+
prevalence = ac2.aggregate(posteriors, train_posteriors, y_train)
5253
5354
5455

mlquantify/neighbors/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def __init__(self, learner=None, bandwidth: float = 0.1, kernel: str = "gaussian
8686
self.best_distance = None
8787

8888
@_fit_context(prefer_skip_nested_validation=True)
89-
def fit(self, X, y, learner_fitted=False):
89+
def fit(self, X, y, learner_fitted=False, cv=5, stratified=True, shuffle=False):
9090
X, y = validate_data(self, X, y, ensure_2d=True, ensure_min_samples=2)
9191
validate_y(self, y)
9292

@@ -100,8 +100,8 @@ def fit(self, X, y, learner_fitted=False):
100100
else:
101101
train_predictions, y_train = apply_cross_validation(
102102
self.learner, X, y,
103-
function=learner_function, cv=5,
104-
stratified=True, shuffle=True
103+
function=learner_function, cv=cv,
104+
stratified=stratified, shuffle=shuffle
105105
)
106106

107107
self.train_predictions = train_predictions

0 commit comments

Comments
 (0)