Add Phase 3 GRU ROS projection model #7

Draft

lambertchu wants to merge 4 commits into main from codex/phase3-ros-gru

Conversation

@lambertchu (Owner)

Summary

  • add the Phase 3 GRU-based ROS sequence model and training pipeline
  • reuse the frozen Phase 2 encoder/decoder with weekly sequence features and cutoff datasets
  • wire phase3 into ROS projection generation and benchmarking, with test coverage

Test Plan

  • .venv/bin/python -m pytest tests/ -q
  • .venv/bin/python -m src.models.ros.train --config configs/ros.yaml --smoke --out /tmp/baseball-hydra-ros-sequence-smoke --device cpu


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces Phase 3 of the Rest-of-Season (ROS) projection system, implementing a sequential GRU-based model that builds upon the existing Phase 2 architecture. Key additions include the src/models/ros module for sequence-based forecasting, new weekly feature extraction logic, and updates to the benchmarking and projection scripts to integrate the Phase 3 baseline. The Phase 2 model was also refactored to allow its encoder and decoder components to be reused by the sequential model. Review feedback highlights opportunities to optimize the sequential dataset's performance by avoiding expensive Pandas indexing during training and ensuring that configuration updates for data splits do not inadvertently overwrite existing settings.

Comment thread: scripts/benchmark_ros.py (Outdated)
)
phase3_config = dict(phase3_config)
phase3_config.setdefault("splits", {})
phase3_config["splits"] = {"train_end_season": int(eval_year - 1)}

medium

The assignment phase3_config["splits"] = {"train_end_season": ...} overwrites the entire splits dictionary. If the configuration file contains other keys under splits (such as val_season or test_season), they will be lost, which may lead to unexpected behavior in train_ros_sequence or SplitConfig.build. It is safer to update the existing dictionary.

Suggested change:

- phase3_config["splits"] = {"train_end_season": int(eval_year - 1)}
+ phase3_config["splits"] = {**phase3_config.get("splits", {}), "train_end_season": int(eval_year - 1)}
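The difference is easy to see in isolation. A minimal Python sketch (using a hypothetical sibling key `val_season`, not necessarily what the real YAML contains) shows what the overwrite loses:

```python
# Minimal demonstration that reassigning "splits" drops sibling keys,
# while a merge preserves them; "val_season" is a hypothetical key.
phase3_config = {"splits": {"train_end_season": 2021, "val_season": 2022}}

# Replacing the whole "splits" dict: val_season is lost.
replaced = dict(phase3_config)
replaced["splits"] = {"train_end_season": 2023}

# Merging into the existing dict: val_season survives.
merged = dict(phase3_config)
merged["splits"] = {**merged.get("splits", {}), "train_end_season": 2023}
```
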

Comment thread: scripts/generate_ros_projections.py (Outdated)
phase3_config.setdefault("training", {})["epochs"] = int(args.epochs)
if args.device is not None:
phase3_config.setdefault("training", {})["device"] = args.device
phase3_config["splits"] = {"train_end_season": int(args.year - 1)}

medium

Similar to the benchmark script, this assignment overwrites the splits configuration dictionary. Any pre-existing keys in splits from the YAML config will be removed. It is recommended to update the dictionary instead of replacing it.

Suggested change:

- phase3_config["splits"] = {"train_end_season": int(args.year - 1)}
+ phase3_config["splits"] = {**phase3_config.get("splits", {}), "train_end_season": int(args.year - 1)}

Comment thread: src/models/ros/dataset.py (Outdated), comment on lines +103 to +117
def _history_positions(self, idx: int) -> list[int]:
    row = self.snapshots.iloc[idx]
    key = (int(row["mlbam_id"]), int(row["season"]))
    positions = self._positions_by_group[key]
    cutoff = (
        int(row["iso_year"]),
        int(row["iso_week"]),
    )
    hist = []
    for pos in positions:
        pos_row = self.snapshots.iloc[pos]
        pos_key = (int(pos_row["iso_year"]), int(pos_row["iso_week"]))
        if pos_key <= cutoff:
            hist.append(pos)
    return hist[-self.max_seq_len :]

medium

The implementation of _history_positions is inefficient because it iterates through all positions in a group and performs a Pandas iloc lookup for each one. Since self._positions_by_group already contains indices in chronological order (due to the sort in _build_positions_by_group), you can simply find the index of the current idx in the positions list and take the prefix. This avoids the loop and the expensive iloc calls. Additionally, you can use self.row_keys to avoid the initial iloc at line 104.

Suggested change:

def _history_positions(self, idx: int) -> list[int]:
    mlbam_id, season, _, _ = self.row_keys[idx]
    positions = self._positions_by_group[(mlbam_id, season)]
    try:
        # Since positions is sorted by time, the history is just the prefix up to idx.
        idx_in_group = positions.index(idx)
        return positions[: idx_in_group + 1][-self.max_seq_len :]
    except ValueError:
        return [idx]

Comment thread: src/models/ros/dataset.py (Outdated)
"seq_mask": torch.from_numpy(mask),
"blend_features": torch.from_numpy(blend.astype(np.float32)),
"phase2_x": torch.from_numpy(
    self.phase2_features.iloc[idx].to_numpy(dtype=np.float32).copy()

medium

Accessing self.phase2_features.iloc[idx] inside __getitem__ is slow because Pandas iloc on a DataFrame creates a new Series object for every access. For better training performance, it is recommended to convert phase2_features (and sequence_features) to numpy arrays in __init__ and index into those arrays directly.
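A minimal sketch of that suggestion, assuming `phase2_features` is a purely numeric DataFrame (the class and method names here are illustrative, not the PR's actual dataset code):

```python
import numpy as np
import pandas as pd
import torch

# Hypothetical sketch only: FeatureCache stands in for the dataset
# class, which would do this conversion once in __init__.
class FeatureCache:
    def __init__(self, phase2_features: pd.DataFrame) -> None:
        # One-time conversion to a contiguous float32 ndarray; avoids
        # constructing a new pandas Series on every __getitem__ call.
        self._phase2_np = phase2_features.to_numpy(dtype=np.float32)

    def phase2_x(self, idx: int) -> torch.Tensor:
        # Integer row indexing into an ndarray is cheap compared to
        # DataFrame.iloc; .copy() keeps the tensor's memory independent.
        return torch.from_numpy(self._phase2_np[idx].copy())
```

The same one-time conversion would apply to `sequence_features`, with `__getitem__` indexing the cached array by integer position.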

@lambertchu (Owner, Author)

/gemini review


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the Phase 3 sequential ROS projection model, which incorporates within-season player trajectories using a GRU architecture. The changes include a new Statcast-based fallback for fetching weekly batting statistics, updates to benchmarking and projection scripts, and a refactor of the Phase 2 model to support latent state reuse. Review feedback identifies performance bottlenecks caused by inefficient row iteration, a potential memory exhaustion risk during validation, and a security vulnerability in the model loading logic.

def _running_events_from_statcast(raw: pd.DataFrame) -> pd.DataFrame:
    """Infer weekly SB/CS from Statcast descriptions when runner IDs are present."""
    rows: list[tuple[int, int, int, str]] = []
    for _, row in raw.iterrows():

medium

Using iterrows() for iterating over Statcast data is highly inefficient. Statcast datasets can be quite large (hundreds of thousands of rows), and iterrows() has significant overhead as it creates a Series object for each row. Consider using vectorized pandas operations or itertuples() for a significant performance boost.
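As a rough illustration of the trade-off (the column names below are invented toy data, not the real Statcast schema):

```python
import pandas as pd

# Toy frame standing in for Statcast pitch-level data; the columns
# ("game_pk", "des") are illustrative placeholders.
raw = pd.DataFrame({
    "game_pk": [1, 1, 2],
    "des": ["Runner steals 2nd", "Runner caught stealing", "Single to left"],
})

# itertuples() yields lightweight namedtuples instead of building a
# full Series per row, which is typically much faster than iterrows().
events = [(t.game_pk, t.des) for t in raw.itertuples(index=False)]

# Where the per-row logic reduces to string matching, a vectorized
# mask avoids Python-level iteration entirely.
sb_mask = raw["des"].str.contains("steal", case=False)
```
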

def _runs_from_statcast(pa_rows: pd.DataFrame) -> pd.DataFrame:
    """Attribute runs scored to batter/runner IDs from pre-PA base state."""
    rows: list[tuple[int, int, int, int]] = []
    for _, row in pa_rows.iterrows():

medium

Similar to the usage in _running_events_from_statcast, using iterrows() here to attribute runs is a performance bottleneck for large Statcast datasets. Switching to itertuples() or vectorized operations where possible would improve execution time.

Comment thread: src/models/ros/model.py
    None,
    fit_sequence_scaler=False,
)
val_loader = DataLoader(val_ds, batch_size=max(1, len(val_ds)))

medium

Setting the validation batch size to the entire length of the dataset (len(val_ds)) poses a significant risk of Out-Of-Memory (OOM) errors, especially when running on a GPU with a large validation set (e.g., multiple seasons of weekly snapshots). It is safer to use a fixed, reasonable batch size (like the training batch size) for validation.

Suggested change:

- val_loader = DataLoader(val_ds, batch_size=max(1, len(val_ds)))
+ val_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)
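A hedged sketch of what bounded-batch validation could look like; `model`, `loss_fn`, `batch_size`, and the `validate` helper itself are placeholders, not attributes or functions from the PR:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Sketch of validation with a fixed batch size instead of
# batch_size=len(val_ds), keeping peak memory bounded.
def validate(model, val_ds, batch_size, loss_fn, device="cpu"):
    loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    total, count = 0.0, 0
    model.eval()
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            loss = loss_fn(model(xb), yb)
            # Weight each batch's mean loss by its size so the final
            # average matches the full-dataset mean exactly.
            total += loss.item() * xb.size(0)
            count += xb.size(0)
    return total / count
```

Because each batch's mean loss is re-weighted by its size, the result is identical to evaluating the whole set in one batch, without the OOM risk.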

Comment thread: src/models/ros/model.py, comment on lines +502 to +506
    load_dir / "ros_sequence_forecaster.pt",
    map_location="cpu",
    weights_only=False,
)
instance = cls(base, state.get("config", {}))

medium

Loading a torch checkpoint with weights_only=False is a security risk as it allows for the execution of arbitrary code during unpickling. Since the Phase 2 model loader in src/models/mtl_ros/model.py uses weights_only=True, and the state dictionary here only contains standard types (tensors, dicts, lists, strings), this should be switched to True for consistency and security.

Suggested change:

        state = torch.load(
            load_dir / "ros_sequence_forecaster.pt",
            map_location="cpu",
            weights_only=True,
        )
