Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,27 @@ def propagate_parameter_metadata(root: ParameterNode) -> ParameterNode:

UNPROPAGAGED_METADATA = ["breakdown", "label", "name", "description"]

for parameter in root.get_descendants():
if parameter.metadata.get("propagate_metadata_to_children"):
# Pre-compute all descendants once to avoid O(n²) complexity
all_descendants = list(root.get_descendants())

# Find parameters that need to propagate metadata
propagators = [
p for p in all_descendants
if p.metadata.get("propagate_metadata_to_children")
]

# For each parameter that propagates metadata, update its descendants
for parameter in propagators:
# Get metadata to propagate
metadata_to_propagate = {
key: value
for key, value in parameter.metadata.items()
if key not in UNPROPAGAGED_METADATA
}

# Only call get_descendants() if there's metadata to propagate
if metadata_to_propagate:
for descendant in parameter.get_descendants():
descendant.metadata.update(
{
key: value
for key, value in parameter.metadata.items()
if key not in UNPROPAGAGED_METADATA
}
)
descendant.metadata.update(metadata_to_propagate)

return root
104 changes: 69 additions & 35 deletions policyengine_core/parameters/operations/uprate_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
if hasattr(bracket, allowed_key):
descendants.append(getattr(bracket, allowed_key))

# Global cache for uprating parameter values to avoid redundant lookups
# Key: (uprating_parameter.name, instant), Value: parameter value at that instant
uprating_cache = {}

for parameter in descendants:
if isinstance(parameter, Parameter):
if parameter.metadata.get("uprating") is not None:
Expand Down Expand Up @@ -110,6 +114,7 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
uprating_first_date,
uprating_last_date,
meta,
uprating_cache,
)

# Append uprated data to parameter values list
Expand All @@ -127,7 +132,12 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
# Pre-compute values that don't change in the loop
last_instant_str = str(last_instant)
value_at_start = parameter(last_instant)
uprater_at_start = uprating_parameter(last_instant)

# Use cache for uprating parameter lookup
cache_key = (uprating_parameter.name, last_instant)
if cache_key not in uprating_cache:
uprating_cache[cache_key] = uprating_parameter(last_instant)
uprater_at_start = uprating_cache[cache_key]

if uprater_at_start is None:
raise ValueError(
Expand All @@ -137,30 +147,38 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
# Pre-compute uprater values for all entries to avoid repeated lookups
has_rounding = "rounding" in meta

# For each defined instant in the uprating parameter
# Pre-compute all uprating values and instants to reduce function calls
uprating_entries = []
for entry in uprating_parameter.values_list[::-1]:
entry_instant = instant(entry.instant_str)
# If the uprater instant is defined after the last parameter instant
if entry_instant > last_instant:
# Apply the uprater and add to the parameter
uprater_at_entry = uprating_parameter(
entry_instant
)
uprater_change = (
uprater_at_entry / uprater_at_start
uprating_entries.append((entry_instant, entry.instant_str))

# Batch lookup of uprating parameter values using global cache
uprater_values = {}
for entry_instant, _ in uprating_entries:
cache_key = (uprating_parameter.name, entry_instant)
if cache_key not in uprating_cache:
uprating_cache[cache_key] = uprating_parameter(entry_instant)
uprater_values[entry_instant] = uprating_cache[cache_key]

# For each defined instant in the uprating parameter
for entry_instant, entry_instant_str in uprating_entries:
# Apply the uprater and add to the parameter
uprater_at_entry = uprater_values[entry_instant]
uprater_change = uprater_at_entry / uprater_at_start
uprated_value = value_at_start * uprater_change
if has_rounding:
uprated_value = round_uprated_value(
meta, uprated_value
)
uprated_value = value_at_start * uprater_change
if has_rounding:
uprated_value = round_uprated_value(
meta, uprated_value
)
parameter.values_list.append(
ParameterAtInstant(
parameter.name,
entry.instant_str,
data=uprated_value,
)
parameter.values_list.append(
ParameterAtInstant(
parameter.name,
entry_instant_str,
data=uprated_value,
)
)
# Whether using cadence or not, sort the parameter values_list
parameter.values_list.sort(
key=lambda x: x.instant_str, reverse=True
Expand Down Expand Up @@ -345,6 +363,7 @@ def uprate_by_cadence(
first_date: datetime,
last_date: datetime,
meta: dict,
uprating_cache: dict,
) -> list[ParameterAtInstant]:
# Determine the frequency module to utilize within rrule
interval = ""
Expand Down Expand Up @@ -374,21 +393,36 @@ def uprate_by_cadence(
# Set a starting reference value to calculate against
reference_value = parameter.get_at_instant(instant(first_date.date()))

# For each entry (corresponding to an enactment date) in the iteration list...
for enactment_date in iterations:
# Calculate the start and end calculation dates
start_calc_date: datetime = enactment_date - enactment_start_offset
end_calc_date: datetime = enactment_date - enactment_end_offset

# Find uprater value at cadence start
start_val = uprating_parameter.get_at_instant(
instant(start_calc_date.date())
)
# Pre-compute all instants and batch lookup uprating parameter values
iteration_list = list(iterations)
calc_dates = []
for enactment_date in iteration_list:
start_calc_date = enactment_date - enactment_start_offset
end_calc_date = enactment_date - enactment_end_offset
calc_dates.append((enactment_date, start_calc_date, end_calc_date))

# Batch lookup all uprating parameter values using global cache
for _, start_calc_date, end_calc_date in calc_dates:
start_instant = instant(start_calc_date.date())
end_instant = instant(end_calc_date.date())
start_key = (uprating_parameter.name, start_instant)
end_key = (uprating_parameter.name, end_instant)
if start_key not in uprating_cache:
uprating_cache[start_key] = uprating_parameter.get_at_instant(start_instant)
if end_key not in uprating_cache:
uprating_cache[end_key] = uprating_parameter.get_at_instant(end_instant)

has_rounding = "rounding" in meta

# Find uprater value at cadence end
end_val = uprating_parameter.get_at_instant(
instant(end_calc_date.date())
)
# For each entry (corresponding to an enactment date) in the iteration list...
for enactment_date, start_calc_date, end_calc_date in calc_dates:
# Get pre-computed uprater values from global cache
start_instant = instant(start_calc_date.date())
end_instant = instant(end_calc_date.date())
start_key = (uprating_parameter.name, start_instant)
end_key = (uprating_parameter.name, end_instant)
start_val = uprating_cache[start_key]
end_val = uprating_cache[end_key]

# Ensure that earliest date exists within uprater
if not start_val:
Expand All @@ -401,7 +435,7 @@ def uprate_by_cadence(

# Uprate value
uprated_value = difference * reference_value
if "rounding" in meta:
if has_rounding:
uprated_value = round_uprated_value(meta, uprated_value)

# Add uprated value to data list
Expand Down
43 changes: 29 additions & 14 deletions policyengine_core/periods/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,50 @@ def instant(instant):
"""
if instant is None:
return None

# Fast path for already-instant objects (most common after string)
if isinstance(instant, periods.Instant):
return instant

# String is the most common input type during parameter loading
if isinstance(instant, str):
return _instant_from_string(instant)

# For other types, create a cache key and check the cache
cache_key = None
# Fast path for datetime.date objects (common in uprating)
if isinstance(instant, datetime.date):
cache_key = (instant.year, instant.month, instant.day)
cached = _instant_cache.get(cache_key)
if cached is not None:
return cached
result = periods.Instant(cache_key)
_instant_cache[cache_key] = result
return result

# Check Period before tuple since Period is a subclass of tuple
if isinstance(instant, periods.Period):
return instant.start
elif isinstance(instant, datetime.date):
cache_key = (instant.year, instant.month, instant.day)
elif isinstance(instant, int):

# Handle int input
if isinstance(instant, int):
cache_key = (instant, 1, 1)
elif isinstance(instant, (tuple, list)):
cached = _instant_cache.get(cache_key)
if cached is not None:
return cached
result = periods.Instant(cache_key)
_instant_cache[cache_key] = result
return result

# Handle tuple/list input
if isinstance(instant, (tuple, list)):
if len(instant) == 1:
cache_key = (instant[0], 1, 1)
elif len(instant) == 2:
cache_key = (instant[0], instant[1], 1)
elif len(instant) == 3:
cache_key = tuple(instant)
else:
raise AssertionError(f"Invalid instant length: {len(instant)}")

if cache_key is not None:
cached = _instant_cache.get(cache_key)
if cached is not None:
return cached
Expand All @@ -78,13 +99,7 @@ def instant(instant):
return result

# Fallback for unexpected types
assert isinstance(instant, tuple), instant
assert 1 <= len(instant) <= 3
if len(instant) == 1:
return periods.Instant((instant[0], 1, 1))
if len(instant) == 2:
return periods.Instant((instant[0], instant[1], 1))
return periods.Instant(instant)
raise TypeError(f"Unexpected instant type: {type(instant)}")


def instant_date(instant):
Expand Down
Loading