2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
*.md text working-tree-encoding=UTF-8
*.rst text working-tree-encoding=UTF-8
122 changes: 112 additions & 10 deletions mobility/choice_models/population_trips.py
@@ -23,6 +23,14 @@
from mobility.motives import Motive, HomeMotive, OtherMotive
from mobility.transport_modes.transport_mode import TransportMode
from mobility.parsers.mobility_survey import MobilitySurvey
from mobility.choice_models.population_trips_checkpoint import PopulationTripsCheckpointAsset
from mobility.choice_models.population_trips_resume import (
compute_resume_plan,
try_load_checkpoint,
restore_state_or_fresh_start,
prune_tmp_artifacts,
rehydrate_congestion_snapshot,
)

class PopulationTrips(FileAsset):
"""
@@ -336,7 +344,23 @@ def run_model(self, is_weekday):
parameters = self.inputs["parameters"]

cache_path = self.cache_path["weekday_flows"] if is_weekday is True else self.cache_path["weekend_flows"]
tmp_folders = self.prepare_tmp_folders(cache_path)

run_key = self.inputs_hash
resume_plan = compute_resume_plan(
run_key=run_key,
is_weekday=is_weekday,
n_iterations=parameters.n_iterations,
)
if resume_plan.resume_from_iter is None:
logging.info("No checkpoint found for run_key=%s is_weekday=%s. Starting from scratch.", run_key, str(is_weekday))
else:
logging.info(
"Latest checkpoint found for run_key=%s is_weekday=%s: iteration=%s",
run_key,
str(is_weekday),
str(resume_plan.resume_from_iter),
)
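# When resuming, keep the existing per-run tmp folders; artifacts written after
# the checkpointed iteration are pruned further below.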
tmp_folders = self.prepare_tmp_folders(cache_path, resume=(resume_plan.resume_from_iter is not None))

chains_by_motive, chains, demand_groups = self.state_initializer.get_chains(
population,
@@ -369,8 +393,38 @@
)

remaining_sinks = sinks.clone()

for iteration in range(1, parameters.n_iterations+1):
start_iteration = 1

if resume_plan.resume_from_iter is not None:
ckpt = try_load_checkpoint(
run_key=run_key,
is_weekday=is_weekday,
iteration=resume_plan.resume_from_iter,
)
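# The restore helper also receives the RNG so it can reapply the checkpointed
# rng_state (an assumption based on its signature and the checkpoint payload),
# keeping resumed draws consistent with an uninterrupted run.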
current_states, remaining_sinks, restored = restore_state_or_fresh_start(
ckpt=ckpt,
stay_home_state=stay_home_state,
sinks=sinks,
rng=self.rng,
)

if restored:
start_iteration = resume_plan.start_iteration
logging.info(
"Resuming PopulationTrips from checkpoint: run_key=%s is_weekday=%s iteration=%s",
run_key,
str(is_weekday),
str(resume_plan.resume_from_iter),
)
prune_tmp_artifacts(tmp_folders=tmp_folders, keep_up_to_iter=resume_plan.resume_from_iter)
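# Rebuild the congestion-dependent cost state as of the last completed iteration,
# since costs are only refreshed every n_iter_per_cost_update iterations.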
costs = rehydrate_congestion_snapshot(
costs_aggregator=costs_aggregator,
run_key=run_key,
last_completed_iter=resume_plan.resume_from_iter,
n_iter_per_cost_update=parameters.n_iter_per_cost_update,
)

for iteration in range(start_iteration, parameters.n_iterations+1):

logging.info(f"Iteration n°{iteration}")

@@ -423,14 +477,45 @@ def run_model(self, is_weekday):
iteration,
parameters.n_iter_per_cost_update,
current_states_steps,
costs_aggregator
costs_aggregator,
run_key=self.inputs_hash
)

remaining_sinks = self.state_updater.get_new_sinks(
current_states_steps,
sinks,
motives
)

# Save per-iteration checkpoint after all state has been advanced.
try:
PopulationTripsCheckpointAsset(
run_key=run_key,
is_weekday=is_weekday,
iteration=iteration,
current_states=current_states,
remaining_sinks=remaining_sinks,
rng_state=self.rng.getstate(),
).create_and_get_asset()
except Exception:
logging.exception("Failed to save checkpoint for iteration %s.", str(iteration))

# If the loop body never ran (we resumed after all iterations were already completed,
# i.e. start_iteration > n_iterations), rebuild step-level flows from cached artifacts for the final output.
if "current_states_steps" not in locals():
possible_states_steps = self.state_updater.get_possible_states_steps(
current_states,
demand_groups,
chains_by_motive,
costs_aggregator,
remaining_sinks,
motive_dur,
parameters.n_iterations,
motives,
parameters.min_activity_time_constant,
tmp_folders
)
current_states_steps = self.state_updater.get_current_states_steps(current_states, possible_states_steps)


costs = costs_aggregator.get_costs_by_od_and_mode(
@@ -464,9 +549,25 @@ def run_model(self, is_weekday):
)

return current_states_steps, sinks, demand_groups, costs, chains

def remove(self, remove_checkpoints: bool = True):
"""Remove cached outputs for this PopulationTrips run.

By default this also removes any saved checkpoints for this run_key, to avoid
resuming from stale intermediate state after a "clean" remove.
"""
super().remove()

if remove_checkpoints:
run_key = self.inputs_hash
removed = 0
removed += PopulationTripsCheckpointAsset.remove_checkpoints_for_run(run_key=run_key, is_weekday=True)
removed += PopulationTripsCheckpointAsset.remove_checkpoints_for_run(run_key=run_key, is_weekday=False)
if removed > 0:
logging.info("Removed %s checkpoint files for run_key=%s", str(removed), str(run_key))


def prepare_tmp_folders(self, cache_path):
def prepare_tmp_folders(self, cache_path, resume: bool = False):
"""Create per-run temp folders next to the cache path.

Args:
@@ -478,14 +579,15 @@ def prepare_tmp_folders(self, cache_path):

inputs_hash = str(cache_path.stem).split("-")[0]

def rm_then_mkdirs(folder_name):
def ensure_dir(folder_name):
path = cache_path.parent / (inputs_hash + "-" + folder_name)
shutil.rmtree(path, ignore_errors=True)
os.makedirs(path)
if resume is False:
shutil.rmtree(path, ignore_errors=True)
os.makedirs(path, exist_ok=True)
return path

folders = ["spatialized-chains", "modes", "flows", "sequences-index"]
folders = {f: rm_then_mkdirs(f) for f in folders}
folders = {f: ensure_dir(f) for f in folders}

return folders

@@ -583,7 +685,7 @@ def plot_modal_share(self, zone="origin", mode="car", period="weekdays",

if mode == "public_transport":
mode_name = "Public transport"
mode_share["mode"] = mode_share["mode"].replace("\S+\/public_transport\/\S+", "public_transport", regex=True)
mode_share["mode"] = mode_share["mode"].replace(r"\S+/public_transport/\S+", "public_transport", regex=True)
else:
mode_name = mode.capitalize()
mode_share = mode_share[mode_share["mode"] == mode]
194 changes: 194 additions & 0 deletions mobility/choice_models/population_trips_checkpoint.py
@@ -0,0 +1,194 @@
import os
import json
import pickle
import pathlib
import logging
import re

import polars as pl

from mobility.file_asset import FileAsset


class PopulationTripsCheckpointAsset(FileAsset):
"""Per-iteration checkpoint for PopulationTrips to enable crash-safe resume.

The checkpoint is keyed by:
- run_key: PopulationTrips.inputs_hash (includes the seed and all params)
- is_weekday: True/False
- iteration: last completed iteration k

Payload:
- current_states (pl.DataFrame)
- remaining_sinks (pl.DataFrame)
- rng_state (pickle of random.Random.getstate())

Notes:
- We write a JSON meta file last so incomplete checkpoints are ignored.
- This asset is intentionally not part of the main model dependency graph;
it is only used as an optional resume source.
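
Illustrative usage (a sketch of the optional resume lookup; `run_key` here stands
for a PopulationTrips.inputs_hash value):

    latest = PopulationTripsCheckpointAsset.find_latest_checkpoint_iter(
        run_key=run_key, is_weekday=True,
    )
    if latest is not None:
        ckpt = PopulationTripsCheckpointAsset(
            run_key=run_key, is_weekday=True, iteration=latest,
        ).get_cached_asset()
        current_states = ckpt["current_states"]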
"""

SCHEMA_VERSION = 1

def __init__(
self,
*,
run_key: str,
is_weekday: bool,
iteration: int,
current_states: pl.DataFrame | None = None,
remaining_sinks: pl.DataFrame | None = None,
rng_state=None,
):
self._payload_current_states = current_states
self._payload_remaining_sinks = remaining_sinks
self._payload_rng_state = rng_state

inputs = {
"run_key": str(run_key),
"is_weekday": bool(is_weekday),
"iteration": int(iteration),
"schema_version": self.SCHEMA_VERSION,
}

project_folder = pathlib.Path(os.environ["MOBILITY_PROJECT_DATA_FOLDER"])
period = "weekday" if is_weekday else "weekend"
base_dir = project_folder / "population_trips" / period / "checkpoints"

stem = f"checkpoint_{run_key}_iter_{int(iteration)}"
cache_path = {
"current_states": base_dir / f"{stem}_current_states.parquet",
"remaining_sinks": base_dir / f"{stem}_remaining_sinks.parquet",
"rng_state": base_dir / f"{stem}_rng_state.pkl",
"meta": base_dir / f"{stem}.json",
}

super().__init__(inputs, cache_path)

def get_cached_asset(self):
current_states = pl.read_parquet(self.cache_path["current_states"])
remaining_sinks = pl.read_parquet(self.cache_path["remaining_sinks"])
with open(self.cache_path["rng_state"], "rb") as f:
rng_state = pickle.load(f)

meta = {}
try:
with open(self.cache_path["meta"], "r", encoding="utf-8") as f:
meta = json.load(f)
except Exception:
# Meta is only for convenience; payload files are the source of truth.
pass

return {
"current_states": current_states,
"remaining_sinks": remaining_sinks,
"rng_state": rng_state,
"meta": meta,
}

def create_and_get_asset(self):
for p in self.cache_path.values():
pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)

if self._payload_current_states is None or self._payload_remaining_sinks is None or self._payload_rng_state is None:
raise ValueError("Checkpoint payload is missing (current_states, remaining_sinks, rng_state).")

def atomic_write_bytes(final_path: pathlib.Path, data: bytes):
tmp = pathlib.Path(str(final_path) + ".tmp")
with open(tmp, "wb") as f:
f.write(data)
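# os.replace performs an atomic rename on the same filesystem, so readers never observe a half-written file.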
os.replace(tmp, final_path)

def atomic_write_text(final_path: pathlib.Path, text: str):
tmp = pathlib.Path(str(final_path) + ".tmp")
with open(tmp, "w", encoding="utf-8") as f:
f.write(text)
os.replace(tmp, final_path)

# Write payload first
tmp_states = pathlib.Path(str(self.cache_path["current_states"]) + ".tmp")
self._payload_current_states.write_parquet(tmp_states)
os.replace(tmp_states, self.cache_path["current_states"])

tmp_sinks = pathlib.Path(str(self.cache_path["remaining_sinks"]) + ".tmp")
self._payload_remaining_sinks.write_parquet(tmp_sinks)
os.replace(tmp_sinks, self.cache_path["remaining_sinks"])

atomic_write_bytes(self.cache_path["rng_state"], pickle.dumps(self._payload_rng_state, protocol=pickle.HIGHEST_PROTOCOL))

# Meta last, so readers only see complete checkpoints.
meta = {
"run_key": self.inputs["run_key"],
"is_weekday": self.inputs["is_weekday"],
"iteration": self.inputs["iteration"],
"schema_version": self.SCHEMA_VERSION,
}
atomic_write_text(self.cache_path["meta"], json.dumps(meta, sort_keys=True))

logging.info(
"Checkpoint saved: run_key=%s is_weekday=%s iteration=%s",
self.inputs["run_key"],
str(self.inputs["is_weekday"]),
str(self.inputs["iteration"]),
)

return self.get_cached_asset()

@staticmethod
def find_latest_checkpoint_iter(*, run_key: str, is_weekday: bool) -> int | None:
project_folder = pathlib.Path(os.environ["MOBILITY_PROJECT_DATA_FOLDER"])
period = "weekday" if is_weekday else "weekend"
base_dir = project_folder / "population_trips" / period / "checkpoints"
if not base_dir.exists():
return None

# FileAsset prefixes filenames with its own inputs_hash, so we match on the suffix.
pattern = f"*checkpoint_{run_key}_iter_*.json"
candidates = list(base_dir.glob(pattern))
if not candidates:
return None

rx = re.compile(rf"checkpoint_{re.escape(run_key)}_iter_(\d+)\.json$")
best = None
for p in candidates:
m = rx.search(p.name)
if not m:
continue
it = int(m.group(1))
if best is None or it > best:
best = it

return best

@staticmethod
def remove_checkpoints_for_run(*, run_key: str, is_weekday: bool) -> int:
"""Remove all checkpoint files for a given run_key and period.

Returns number of files removed.
"""
project_folder = pathlib.Path(os.environ["MOBILITY_PROJECT_DATA_FOLDER"])
period = "weekday" if is_weekday else "weekend"
base_dir = project_folder / "population_trips" / period / "checkpoints"
if not base_dir.exists():
return 0

# FileAsset prefixes filenames with its own inputs_hash, so just match suffix fragments.
pattern = f"*checkpoint_{run_key}_iter_*"
removed = 0
for p in base_dir.glob(pattern):
try:
p.unlink(missing_ok=True)
removed += 1
except Exception:
logging.exception("Failed to remove checkpoint file: %s", str(p))

# Stray ".tmp" files from interrupted writes are already matched by the trailing
# wildcard in `pattern`, so the loop above removes them as well.

return removed