Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions distclassipy/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ def initialize_metric_function(metric):
elif isinstance(metric, str):
metric_str_lowercase = metric.lower()
metric_found = False

# Map DistClassiPy metric names to SciPy equivalents where possible
scipy_metric_mapping = {
'squared_euclidean': 'sqeuclidean',
'jensenshannon_divergence': 'jensenshannon',
}

for package_str, source in METRIC_SOURCES_.items():

# Don't use scipy for jaccard as their implementation only works with
Expand All @@ -82,6 +89,24 @@ def initialize_metric_function(metric):
):
continue

# Check for direct mapping to SciPy equivalents
if package_str == "scipy.spatial.distance":
scipy_metric_name = scipy_metric_mapping.get(metric_str_lowercase, metric_str_lowercase)
if hasattr(source, scipy_metric_name):
if metric_str_lowercase == 'jensenshannon_divergence':
# Special handling for Jensen-Shannon divergence
# We need to wrap it to square the result
import functools
base_fn = getattr(source, scipy_metric_name)
metric_fn_ = lambda u, v: base_fn(u, v) ** 2
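# NOTE(review): the SciPy metric name passed to cdist computes the
# Jensen-Shannon *distance* (square root of the divergence), not the squared
# value returned by metric_fn_ above — confirm the cdist output is squared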
metric_arg_ = scipy_metric_name
else:
metric_fn_ = getattr(source, scipy_metric_name)
metric_arg_ = scipy_metric_name
metric_found = True
break

if hasattr(source, metric_str_lowercase):
metric_fn_ = getattr(source, metric_str_lowercase)
metric_found = True
Expand Down
56 changes: 30 additions & 26 deletions distclassipy/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,13 @@ def clark(u, v):
1(4), 300-307.
"""
u, v = np.asarray(u), np.asarray(v)
with np.errstate(divide="ignore", invalid="ignore"):
return np.sqrt(np.nansum(np.power(np.abs(u - v) / (u + v), 2)))
diff = np.abs(u - v)
sum_uv = u + v
# Divide only where u + v is nonzero; entries with a zero denominator contribute 0
nonzero_mask = sum_uv != 0
result = np.zeros_like(diff)
result[nonzero_mask] = (diff[nonzero_mask] / sum_uv[nonzero_mask]) ** 2
return np.sqrt(np.sum(result))


def hellinger(u, v):
Expand Down Expand Up @@ -386,11 +391,13 @@ def hellinger(u, v):
1(4), 300-307.
"""
u, v = np.asarray(u), np.asarray(v)
# Clip negative values to zero for valid sqrt
with np.errstate(divide="ignore", invalid="ignore"):
u = np.clip(u, a_min=0, a_max=None)
v = np.clip(v, a_min=0, a_max=None)
return np.sqrt(2 * np.sum((np.sqrt(u) - np.sqrt(v)) ** 2))
# Clip negative values to zero so the square roots below are well-defined
u = np.clip(u, 0, None)
v = np.clip(v, 0, None)
sqrt_u = np.sqrt(u)
sqrt_v = np.sqrt(v)
diff = sqrt_u - sqrt_v
return np.sqrt(2 * np.dot(diff, diff))


def jaccard(u, v):
Expand Down Expand Up @@ -442,8 +449,8 @@ def lorentzian(u, v):
eschew the log of zero.
"""
u, v = np.asarray(u), np.asarray(v)
with np.errstate(divide="ignore", invalid="ignore"):
return np.sum(np.log(np.abs(u - v) + 1))
abs_diff = np.abs(u - v)
return np.sum(np.log1p(abs_diff)) # log1p(x) = log(1 + x) is more accurate


def marylandbridge(u, v):
Expand Down Expand Up @@ -548,7 +555,9 @@ def soergel(u, v):
1(4), 300-307.
"""
u, v = np.asarray(u), np.asarray(v)
return np.sum(np.abs(u - v)) / np.sum(np.maximum(u, v))
abs_diff = np.abs(u - v)
max_uv = np.maximum(u, v)
return np.sum(abs_diff) / np.sum(max_uv)


def wave_hedges(u, v):
Expand All @@ -570,10 +579,13 @@ def wave_hedges(u, v):
1(4), 300-307
"""
u, v = np.asarray(u), np.asarray(v)
with np.errstate(divide="ignore", invalid="ignore"):
u_v = abs(u - v)
uvmax = np.maximum(u, v)
return np.sum(np.where(((u_v != 0) & (uvmax != 0)), u_v / uvmax, 0))
abs_diff = np.abs(u - v)
max_uv = np.maximum(u, v)
# Divide only where both |u - v| and max(u, v) are nonzero; all other entries contribute 0
nonzero_mask = (abs_diff != 0) & (max_uv != 0)
result = np.zeros_like(abs_diff)
result[nonzero_mask] = abs_diff[nonzero_mask] / max_uv[nonzero_mask]
return np.sum(result)


def kulczynski(u, v):
Expand Down Expand Up @@ -907,17 +919,9 @@ def jensenshannon_divergence(u, v):
return np.sum(el1 - el2 * el3)
"""
u, v = np.asarray(u), np.asarray(v)
with np.errstate(divide="ignore", invalid="ignore"):
# Clip negative values to zero for valid log
u[u == 0] = EPSILON
v[v == 0] = EPSILON

term1 = np.clip(2 * u / (u + v), a_min=EPSILON, a_max=None)
term2 = np.clip(2 * v / (u + v), a_min=EPSILON, a_max=None)

dl = u * np.log(term1)
dr = v * np.log(term2)
return (np.sum(dl) + np.sum(dr)) / 2
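# SciPy's jensenshannon returns the Jensen-Shannon distance, i.e. the square
# root of the divergence, so square the result to recover the divergence.
# NOTE(review): SciPy normalizes u and v to sum to 1 before computing —
# confirm this matches the expected behaviour for unnormalized inputs.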
return scipy.spatial.distance.jensenshannon(u, v) ** 2


def jensen_difference(u, v):
Expand Down Expand Up @@ -1207,7 +1211,7 @@ def squared_euclidean(u, v):
Equals to squared Euclidean distance.
"""
u, v = np.asarray(u), np.asarray(v)
return np.dot((u - v), (u - v))
return scipy.spatial.distance.sqeuclidean(u, v)


def taneja(u, v):
Expand Down