From 237c182e8bb603cb0547220ed3eff08736e185e0 Mon Sep 17 00:00:00 2001 From: Cen Wang Date: Thu, 24 Jul 2025 15:22:04 -0400 Subject: [PATCH 1/2] Roll back curve functions to fix plotting bug --- simopt/curve.py | 96 ++++++++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 54 deletions(-) diff --git a/simopt/curve.py b/simopt/curve.py index 8c11a4865..3b91ee657 100644 --- a/simopt/curve.py +++ b/simopt/curve.py @@ -3,10 +3,11 @@ from __future__ import annotations import logging -import math from enum import Enum from typing import TYPE_CHECKING +import numpy as np + # Imports exclusively used when type checking # Prevents imports from being executed at runtime if TYPE_CHECKING: @@ -94,19 +95,15 @@ def lookup(self, x_val: float) -> float: Raises: TypeError: If x_val is not numeric. """ - from bisect import bisect_right - - try: - # Return NaN if x_val is out of range (before first or after last x-value) - if x_val < self.x_vals[0] or x_val > self.x_vals[-1]: - return math.nan - - # Use binary search (O(log n)) instead of linear search (O(n)) - idx = bisect_right(self.x_vals, x_val) - 1 - return self.y_vals[idx] + # Type checking + if not isinstance(x_val, (int, float)): + error_msg = "x_val must be a float." + raise TypeError(error_msg) - except TypeError as e: - raise TypeError(f"x_val must be a numeric value: {e}") from e + if x_val < self.x_vals[0]: + return np.nan + idx = np.max(np.where(np.array(self.x_vals) <= x_val)) + return self.y_vals[idx] def compute_crossing_time(self, threshold: float) -> float: """Compute the first time at which a curve drops below a given threshold. @@ -120,21 +117,24 @@ def compute_crossing_time(self, threshold: float) -> float: Raises: TypeError: If threshold is not numeric. """ - from bisect import bisect_right - - try: - # Find the first index where y_vals < threshold using binary search - index = bisect_right(self.y_vals, threshold) - - # If all y-values are above the threshold, return infinity - if index == self.n_points: - return math.inf - - # Return corresponding x-value - return self.x_vals[index] - - except TypeError as e: - raise TypeError(f"Threshold must be a numeric value: {e}") from e + # Type checking + if not isinstance(threshold, (int, float)): + error_msg = "Threshold must be a float." + raise TypeError(error_msg) + + # Use binary search to find the first x-value below threshold. + # TODO: Test this + # index = bisect.bisect_left(self.y_vals, threshold) + # if index == self.n_points: + # return np.inf + # else: + # return self.x_vals[index] + + for i in range(self.n_points): + if self.y_vals[i] < threshold: + return self.x_vals[i] + # If threshold is never crossed, return infinity. + return np.inf def compute_area_under_curve(self) -> float: """Compute the area under a curve. @@ -142,10 +142,8 @@ def compute_area_under_curve(self) -> float: Returns: float: Area under the curve. """ - x_diffs = (x_next - x for x, x_next in zip(self.x_vals[:-1], self.x_vals[1:])) - area_contributions = (y * dx for y, dx in zip(self.y_vals[:-1], x_diffs)) - - return sum(area_contributions) + area = np.dot(self.y_vals[:-1], np.diff(self.x_vals)) + return area def curve_to_mesh(self, mesh: Iterable[float]) -> Curve: """Create a curve defined at equally spaced x values. @@ -159,19 +157,15 @@ def curve_to_mesh(self, mesh: Iterable[float]) -> Curve: Raises: TypeError: If mesh is not an iterable of numeric values. """ - try: - # Ensure mesh contains valid numeric values - mesh_x_vals = tuple(float(x) for x in mesh) - - # Generate corresponding y-values using lookup - mesh_y_vals = tuple(self.lookup(x) for x in mesh_x_vals) + # Type checking + if not isinstance(mesh, list) or not all( + [isinstance(x, (int, float)) for x in mesh] + ): + error_msg = "Mesh must be a list of floats." + raise TypeError(error_msg) - return Curve(x_vals=mesh_x_vals, y_vals=mesh_y_vals) - - except (TypeError, ValueError) as e: - error_msg = "Mesh must be an iterable of numeric values." - logging.error(error_msg) - raise TypeError(error_msg) from e + mesh_curve = Curve(x_vals=mesh, y_vals=[self.lookup(x) for x in mesh]) + return mesh_curve def curve_to_full_curve(self) -> Curve: """Create a curve with duplicate x- and y-values to indicate steps. @@ -179,16 +173,10 @@ def curve_to_full_curve(self) -> Curve: Returns: Curve: A curve with duplicate x- and y-values. """ - from itertools import chain, repeat - - full_curve = Curve( - x_vals=list(chain.from_iterable(repeat(x, 2) for x in self.x_vals)), - y_vals=list(chain.from_iterable(repeat(y, 2) for y in self.y_vals)), - ) - return Curve( - x_vals=list(full_curve.x_vals)[1:], - y_vals=list(full_curve.y_vals)[:-1], - ) + duplicate_x_vals = [x for x in self.x_vals for _ in (0, 1)] + duplicate_y_vals = [y for y in self.y_vals for _ in (0, 1)] + full_curve = Curve(x_vals=duplicate_x_vals[1:], y_vals=duplicate_y_vals[:-1]) + return full_curve def plot( self, From 36a36ef92ef678a8012e2ece69086e6be79115b3 Mon Sep 17 00:00:00 2001 From: William Grochocinski Date: Thu, 24 Jul 2025 20:24:20 -0400 Subject: [PATCH 2/2] unrolled back Curve code (except for compute crossing time), added note to avoid future errors --- simopt/curve.py | 85 +++++++++++++++++++++++++------------------------ 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/simopt/curve.py b/simopt/curve.py index 3b91ee657..0fbdc4dee 100644 --- a/simopt/curve.py +++ b/simopt/curve.py @@ -3,11 +3,10 @@ from __future__ import annotations import logging +import math from enum import Enum from typing import TYPE_CHECKING -import numpy as np - # Imports exclusively used when type checking # Prevents imports from being executed at runtime if TYPE_CHECKING: @@ -95,15 +94,20 @@ def lookup(self, x_val: float) -> float: Raises: TypeError: If x_val is not numeric. """ - # Type checking - if not isinstance(x_val, (int, float)): - error_msg = "x_val must be a float." - raise TypeError(error_msg) + # We can use binary search here since the x-values are sorted. + from bisect import bisect_right + + try: + # Return NaN if x_val is out of range (before first or after last x-value) + if x_val < self.x_vals[0] or x_val > self.x_vals[-1]: + return math.nan + + # Use binary search (O(log n)) instead of linear search (O(n)) + idx = bisect_right(self.x_vals, x_val) - 1 + return self.y_vals[idx] - if x_val < self.x_vals[0]: - return np.nan - idx = np.max(np.where(np.array(self.x_vals) <= x_val)) - return self.y_vals[idx] + except TypeError as e: + raise TypeError(f"x_val must be a numeric value: {e}") from e def compute_crossing_time(self, threshold: float) -> float: """Compute the first time at which a curve drops below a given threshold. @@ -113,28 +117,15 @@ def compute_crossing_time(self, threshold: float) -> float: Returns: float: First time at which the curve drops below the threshold. - - Raises: - TypeError: If threshold is not numeric. """ - # Type checking - if not isinstance(threshold, (int, float)): - error_msg = "Threshold must be a float." - raise TypeError(error_msg) - - # Use binary search to find the first x-value below threshold. - # TODO: Test this - # index = bisect.bisect_left(self.y_vals, threshold) - # if index == self.n_points: - # return np.inf - # else: - # return self.x_vals[index] - + # Linear search for the first crossing time + # NOTE: We can't use binary search because the curve's y-values aren't + # guaranteed to be strictly decreasing. for i in range(self.n_points): if self.y_vals[i] < threshold: return self.x_vals[i] # If threshold is never crossed, return infinity. - return np.inf + return math.inf def compute_area_under_curve(self) -> float: """Compute the area under a curve. @@ -142,8 +133,10 @@ def compute_area_under_curve(self) -> float: Returns: float: Area under the curve. """ - area = np.dot(self.y_vals[:-1], np.diff(self.x_vals)) - return area + x_diffs = (x_next - x for x, x_next in zip(self.x_vals[:-1], self.x_vals[1:])) + area_contributions = (y * dx for y, dx in zip(self.y_vals[:-1], x_diffs)) + + return sum(area_contributions) def curve_to_mesh(self, mesh: Iterable[float]) -> Curve: """Create a curve defined at equally spaced x values. @@ -157,15 +150,19 @@ def curve_to_mesh(self, mesh: Iterable[float]) -> Curve: Raises: TypeError: If mesh is not an iterable of numeric values. """ - # Type checking - if not isinstance(mesh, list) or not all( - [isinstance(x, (int, float)) for x in mesh] - ): - error_msg = "Mesh must be a list of floats." - raise TypeError(error_msg) + try: + # Ensure mesh contains valid numeric values + mesh_x_vals = tuple(float(x) for x in mesh) + + # Generate corresponding y-values using lookup + mesh_y_vals = tuple(self.lookup(x) for x in mesh_x_vals) - mesh_curve = Curve(x_vals=mesh, y_vals=[self.lookup(x) for x in mesh]) - return mesh_curve + return Curve(x_vals=mesh_x_vals, y_vals=mesh_y_vals) + + except (TypeError, ValueError) as e: + error_msg = "Mesh must be an iterable of numeric values." + logging.error(error_msg) + raise TypeError(error_msg) from e def curve_to_full_curve(self) -> Curve: """Create a curve with duplicate x- and y-values to indicate steps. @@ -173,10 +170,16 @@ def curve_to_full_curve(self) -> Curve: Returns: Curve: A curve with duplicate x- and y-values. """ - duplicate_x_vals = [x for x in self.x_vals for _ in (0, 1)] - duplicate_y_vals = [y for y in self.y_vals for _ in (0, 1)] - full_curve = Curve(x_vals=duplicate_x_vals[1:], y_vals=duplicate_y_vals[:-1]) - return full_curve + from itertools import chain, repeat + + full_curve = Curve( + x_vals=list(chain.from_iterable(repeat(x, 2) for x in self.x_vals)), + y_vals=list(chain.from_iterable(repeat(y, 2) for y in self.y_vals)), + ) + return Curve( + x_vals=list(full_curve.x_vals)[1:], + y_vals=list(full_curve.y_vals)[:-1], + ) def plot( self,