Skip to content

Commit 7a5d06e

Browse files
authored
Update loss_functions.py
1 parent e8b2047 commit 7a5d06e

File tree

1 file changed

+5
-10
lines changed

1 file changed

+5
-10
lines changed

machine_learning/loss_functions.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def kullback_leibler_divergence(y_true: np.ndarray, y_pred: np.ndarray) -> float
663663
return np.sum(kl_loss)
664664

665665

666-
def root_mean_squared_error(y_true, y_pred):
666+
def root_mean_squared_error(y_true: np.array, y_pred:np.array) -> float:
667667
"""
668668
Root Mean Squared Error (RMSE)
669669
@@ -682,15 +682,10 @@ def root_mean_squared_error(y_true, y_pred):
682682
Returns:
683683
float: The RMSE Loss function between y_pred and y_true
684684
685-
>>> true_labels = np.array([100, 200, 300])
686-
>>> predicted_probs = np.array([110, 190, 310])
685+
>>> true_labels = np.array([2, 4, 6, 8])
686+
>>> predicted_probs = np.array([3, 5, 7, 10])
687687
>>> root_mean_squared_error(true_labels, predicted_probs)
688-
3.42
689-
690-
>>> true_labels = [2, 4, 6, 8]
691-
>>> predicted_probs = [3, 5, 7, 10]
692-
>>> root_mean_squared_error(true_labels, predicted_probs)
693-
1.2247
688+
1.3228
694689
695690
>>> true_labels = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
696691
>>> predicted_probs = np.array([0.3, 0.8, 0.9, 0.2])
@@ -703,7 +698,7 @@ def root_mean_squared_error(y_true, y_pred):
703698
raise ValueError("Input arrays must have the same length.")
704699
y_true, y_pred = np.array(y_true), np.array(y_pred)
705700

706-
mse = np.mean((y_pred - y_true) ** 2)
701+
mse = np.mean((y_true - y_pred) ** 2)
707702
return np.sqrt(mse)
708703

709704

0 commit comments

Comments
 (0)