Add Phase 3 GRU ROS projection model #7
Conversation
Code Review
This pull request introduces Phase 3 of the Rest-of-Season (ROS) projection system, implementing a sequential GRU-based model that builds upon the existing Phase 2 architecture. Key additions include the src/models/ros module for sequence-based forecasting, new weekly feature extraction logic, and updates to the benchmarking and projection scripts to integrate the Phase 3 baseline. The Phase 2 model was also refactored to allow its encoder and decoder components to be reused by the sequential model. Review feedback highlights opportunities to optimize the sequential dataset's performance by avoiding expensive Pandas indexing during training and ensuring that configuration updates for data splits do not inadvertently overwrite existing settings.
```python
)
phase3_config = dict(phase3_config)
phase3_config.setdefault("splits", {})
phase3_config["splits"] = {"train_end_season": int(eval_year - 1)}
```
The assignment `phase3_config["splits"] = {"train_end_season": ...}` overwrites the entire `splits` dictionary. If the configuration file contains other keys under `splits` (such as `val_season` or `test_season`), they will be lost, which may lead to unexpected behavior in `train_ros_sequence` or `SplitConfig.build`. It is safer to update the existing dictionary.
```diff
- phase3_config["splits"] = {"train_end_season": int(eval_year - 1)}
+ phase3_config["splits"] = {**phase3_config.get("splits", {}), "train_end_season": int(eval_year - 1)}
```
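As a minimal illustration of the difference (the extra split keys below are hypothetical), replacing the dictionary drops sibling keys, while a merge keeps them:

```python
# Hypothetical config with extra split keys that must survive the update.
phase3_config = {
    "splits": {"train_end_season": 2021, "val_season": 2022, "test_season": 2023}
}

# Replacing the whole dict silently drops val_season and test_season.
overwritten = dict(phase3_config)
overwritten["splits"] = {"train_end_season": 2023}

# Merging updates only the key that changed and preserves the rest.
merged = dict(phase3_config)
merged["splits"] = {**merged.get("splits", {}), "train_end_season": 2023}
```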
```python
phase3_config.setdefault("training", {})["epochs"] = int(args.epochs)
if args.device is not None:
    phase3_config.setdefault("training", {})["device"] = args.device
phase3_config["splits"] = {"train_end_season": int(args.year - 1)}
```
Similar to the benchmark script, this assignment overwrites the `splits` configuration dictionary. Any pre-existing keys in `splits` from the YAML config will be removed. It is recommended to update the dictionary instead of replacing it.
```diff
- phase3_config["splits"] = {"train_end_season": int(args.year - 1)}
+ phase3_config["splits"] = {**phase3_config.get("splits", {}), "train_end_season": int(args.year - 1)}
```
```python
def _history_positions(self, idx: int) -> list[int]:
    row = self.snapshots.iloc[idx]
    key = (int(row["mlbam_id"]), int(row["season"]))
    positions = self._positions_by_group[key]
    cutoff = (
        int(row["iso_year"]),
        int(row["iso_week"]),
    )
    hist = []
    for pos in positions:
        pos_row = self.snapshots.iloc[pos]
        pos_key = (int(pos_row["iso_year"]), int(pos_row["iso_week"]))
        if pos_key <= cutoff:
            hist.append(pos)
    return hist[-self.max_seq_len :]
```
The implementation of `_history_positions` is inefficient because it iterates through all positions in a group and performs a Pandas `iloc` lookup for each one. Since `self._positions_by_group` already contains indices in chronological order (due to the sort in `_build_positions_by_group`), you can simply find the index of the current `idx` in the positions list and take the prefix. This avoids the loop and the expensive `iloc` calls. Additionally, you can use `self.row_keys` to avoid the initial `iloc` at line 104.
```diff
- def _history_positions(self, idx: int) -> list[int]:
-     row = self.snapshots.iloc[idx]
-     key = (int(row["mlbam_id"]), int(row["season"]))
-     positions = self._positions_by_group[key]
-     cutoff = (
-         int(row["iso_year"]),
-         int(row["iso_week"]),
-     )
-     hist = []
-     for pos in positions:
-         pos_row = self.snapshots.iloc[pos]
-         pos_key = (int(pos_row["iso_year"]), int(pos_row["iso_week"]))
-         if pos_key <= cutoff:
-             hist.append(pos)
-     return hist[-self.max_seq_len :]
+ def _history_positions(self, idx: int) -> list[int]:
+     mlbam_id, season, _, _ = self.row_keys[idx]
+     positions = self._positions_by_group[(mlbam_id, season)]
+     try:
+         # Since positions is sorted by time, the history is just the prefix up to idx.
+         idx_in_group = positions.index(idx)
+         return positions[: idx_in_group + 1][-self.max_seq_len :]
+     except ValueError:
+         return [idx]
```
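The prefix-slice idea can be sketched on its own (the position values below are made up): once a group's row indices are sorted chronologically, a row's history is just the slice ending at that row, with no per-row DataFrame lookups required.

```python
max_seq_len = 3
# Hypothetical row indices for one (player, season) group, already in time order.
positions = [4, 7, 9, 12, 15]

def history(idx: int) -> list[int]:
    # Find the row's place in the sorted group, then keep only the
    # most recent max_seq_len entries of the prefix ending at it.
    i = positions.index(idx)
    return positions[: i + 1][-max_seq_len:]
```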
```python
"seq_mask": torch.from_numpy(mask),
"blend_features": torch.from_numpy(blend.astype(np.float32)),
"phase2_x": torch.from_numpy(
    self.phase2_features.iloc[idx].to_numpy(dtype=np.float32).copy()
```
Accessing `self.phase2_features.iloc[idx]` inside `__getitem__` is slow because Pandas `iloc` on a DataFrame creates a new Series object for every access. For better training performance, it is recommended to convert `phase2_features` (and `sequence_features`) to numpy arrays in `__init__` and index into those arrays directly.
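A minimal sketch of that pre-conversion, using a hypothetical stand-in class (the real dataset holds more state than shown here):

```python
import numpy as np
import pandas as pd

class FeatureStore:
    """Hypothetical stand-in for the dataset's feature handling."""

    def __init__(self, phase2_features: pd.DataFrame) -> None:
        # Convert once up front; __getitem__ then indexes a plain ndarray
        # instead of materializing a pandas Series on every access.
        self._phase2_np = phase2_features.to_numpy(dtype=np.float32)

    def __getitem__(self, idx: int) -> np.ndarray:
        return self._phase2_np[idx]
```

The one-time `to_numpy` cost is amortized over every training step, so `__getitem__` becomes a cheap array index.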
/gemini review |
Code Review
This pull request introduces the Phase 3 sequential ROS projection model, which incorporates within-season player trajectories using a GRU architecture. The changes include a new Statcast-based fallback for fetching weekly batting statistics, updates to benchmarking and projection scripts, and a refactor of the Phase 2 model to support latent state reuse. Review feedback identifies performance bottlenecks caused by inefficient row iteration, a potential memory exhaustion risk during validation, and a security vulnerability in the model loading logic.
```python
def _running_events_from_statcast(raw: pd.DataFrame) -> pd.DataFrame:
    """Infer weekly SB/CS from Statcast descriptions when runner IDs are present."""
    rows: list[tuple[int, int, int, str]] = []
    for _, row in raw.iterrows():
```
Using `iterrows()` for iterating over Statcast data is highly inefficient. Statcast datasets can be quite large (hundreds of thousands of rows), and `iterrows()` has significant overhead as it creates a Series object for each row. Consider using vectorized pandas operations or `itertuples()` for a significant performance boost.
```python
def _runs_from_statcast(pa_rows: pd.DataFrame) -> pd.DataFrame:
    """Attribute runs scored to batter/runner IDs from pre-PA base state."""
    rows: list[tuple[int, int, int, int]] = []
    for _, row in pa_rows.iterrows():
```
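As a rough sketch of the `itertuples()` alternative (the frame and column names below are illustrative, not the real Statcast schema):

```python
import pandas as pd

# Tiny stand-in for a Statcast frame; column names are illustrative only.
raw = pd.DataFrame(
    {
        "des": ["Player steals 2nd", "Runner caught stealing", "Single to left"],
        "on_1b": [605141.0, 545361.0, float("nan")],
    }
)

# itertuples yields lightweight namedtuples rather than building a Series
# per row, which matters on pulls with hundreds of thousands of rows.
events = [
    row.des
    for row in raw.itertuples(index=False)
    if pd.notna(row.on_1b)
]
```

Fully vectorized filtering (e.g. `raw[raw["on_1b"].notna()]`) is faster still when the per-row logic can be expressed as column operations.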
```python
    None,
    fit_sequence_scaler=False,
)
val_loader = DataLoader(val_ds, batch_size=max(1, len(val_ds)))
```
Setting the validation batch size to the entire length of the dataset (`len(val_ds)`) poses a significant risk of out-of-memory (OOM) errors, especially when running on a GPU with a large validation set (e.g., multiple seasons of weekly snapshots). It is safer to use a fixed, reasonable batch size (like the training batch size) for validation.
```diff
- val_loader = DataLoader(val_ds, batch_size=max(1, len(val_ds)))
+ val_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)
```
```python
    load_dir / "ros_sequence_forecaster.pt",
    map_location="cpu",
    weights_only=False,
)
instance = cls(base, state.get("config", {}))
```
Loading a torch checkpoint with `weights_only=False` is a security risk as it allows for the execution of arbitrary code during unpickling. Since the Phase 2 model loader in `src/models/mtl_ros/model.py` uses `weights_only=True`, and the state dictionary here only contains standard types (tensors, dicts, lists, strings), this should be switched to `True` for consistency and security.
```python
state = torch.load(
    load_dir / "ros_sequence_forecaster.pt",
    map_location="cpu",
    weights_only=True,
)
```
Summary

- Integrates phase3 into ROS projection generation and benchmarking, with test coverage

Test Plan

```shell
.venv/bin/python -m pytest tests/ -q
.venv/bin/python -m src.models.ros.train --config configs/ros.yaml --smoke --out /tmp/baseball-hydra-ros-sequence-smoke --device cpu
```