Skip to content

Commit d7e08a6

Browse files
committed
Fix naive bayes line length and mypy issues
- Shortened a comment to fix an E501 line-length violation
- Added type annotations for feature_counts, means, variances, and log_probabilities
- Fixed a mypy issue by converting the NumPy integer to a Python int
- All pre-commit checks should now pass for this file
1 parent 0841d09 commit d7e08a6

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

machine_learning/naive_bayes_laplace.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _compute_feature_counts(self, x: np.ndarray, y: np.ndarray
124124
>>> int(counts[1][1][0]) # class 1, feature 1, value 0
125125
1
126126
"""
127-
feature_counts = {}
127+
feature_counts: dict[int, dict[int, dict[int, int]]] = {}
128128

129129
for class_label in np.unique(y):
130130
feature_counts[class_label] = {}
@@ -164,8 +164,8 @@ def _compute_feature_statistics(self, x: np.ndarray, y: np.ndarray
164164
>>> len(vars)
165165
2
166166
"""
167-
means = {}
168-
variances = {}
167+
means: dict[int, dict[int, float]] = {}
168+
variances: dict[int, dict[int, float]] = {}
169169

170170
for class_label in np.unique(y):
171171
means[class_label] = {}
@@ -197,7 +197,7 @@ def _compute_log_probabilities_discrete(self, x: np.ndarray, y: np.ndarray
197197
Nested dictionary: class -> feature -> value -> log_probability
198198
"""
199199
feature_counts = self._compute_feature_counts(x, y)
200-
log_probabilities = {}
200+
log_probabilities: dict[int, dict[int, dict[int, float]]] = {}
201201

202202
for class_label in np.unique(y):
203203
log_probabilities[class_label] = {}
@@ -213,10 +213,10 @@ def _compute_log_probabilities_discrete(self, x: np.ndarray, y: np.ndarray
213213
for feature_value in all_values:
214214
# Count occurrences of this value in this class
215215
count = feature_counts[class_label][feature_idx].get(
216-
feature_value, 0
216+
int(feature_value), 0
217217
)
218218

219-
# Apply Laplace smoothing: (count + alpha) / (n_class_samples + alpha * n_unique_values)
219+
# Apply Laplace smoothing formula
220220
n_unique_values = len(all_values)
221221
smoothed_prob = (count + self.alpha) / (
222222
n_class_samples + self.alpha * n_unique_values

0 commit comments

Comments
 (0)