Conversation
There was a problem hiding this comment.
Pull request overview
Adds first-class RL + SSM (RLSSM) support to HSSM by introducing a new RLSSM model that builds a differentiable PyTensor Op from an annotated JAX SSM log-likelihood and plugs it into the existing distribution-building pipeline.
Changes:
- Introduces
RLSSMmodel class plus RL utilityvalidate_balanced_panel. - Extends configuration via
RLSSMConfig.ssm_logp_funcand exposes RLSSM in the public API. - Adds test coverage for RLSSM initialization/model build and updates RLSSMConfig validation tests.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
src/hssm/rl/rlssm.py |
New RLSSM model implementation integrating RL likelihood Op into HSSMBase. |
src/hssm/rl/utils.py |
Adds balanced-panel validation helper for RLSSM datasets. |
src/hssm/rl/__init__.py |
RL subpackage exports for RLSSM and utilities. |
src/hssm/config.py |
Adds ssm_logp_func to RLSSMConfig and validates presence. |
src/hssm/__init__.py |
Exposes RLSSM / RLSSMConfig at top-level. |
tests/test_rlssm.py |
New end-to-end-ish RLSSM tests (init, model build, balanced panel, smoke sampling). |
tests/test_rlssm_config.py |
Updates RLSSMConfig tests to include the new required field. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| f"same number of trials. Observed trial counts: {dict(counts)}" | ||
| ) | ||
|
|
||
| return int(len(counts)), int(counts.iloc[0]) |
There was a problem hiding this comment.
validate_balanced_panel only checks equal trial counts, but the RL likelihood builder reshapes the row order into (n_participants, n_trials, ...) (see make_rl_logp_func), which assumes each participant’s trials are in one contiguous block (and usually in-trial order). With interleaved participants, the panel can be “balanced” yet produce a silently incorrect likelihood. Consider validating contiguity (each participant appears in exactly one run of length n_trials) and/or sorting by participant_col (+ an optional trial_col if present) before returning (n_participants, n_trials).
| return int(len(counts)), int(counts.iloc[0]) | |
| # Ensure that each participant's trials form a single contiguous block | |
| # of rows of length n_trials. This is required because downstream code | |
| # reshapes the data into (n_participants, n_trials, ...) based on row | |
| # order, assuming no interleaving across participants. | |
| n_trials = int(counts.iloc[0]) | |
| # Identify contiguous "blocks" of identical participant IDs. | |
| blocks = data[participant_col].ne(data[participant_col].shift()).cumsum() | |
| block_counts = data.groupby([participant_col, blocks]).size() | |
| # Each participant must appear in exactly one block, and that block | |
| # must have length n_trials. | |
| blocks_per_participant = block_counts.groupby(level=0).size() | |
| invalid_multi_blocks = blocks_per_participant[blocks_per_participant != 1] | |
| invalid_block_sizes = block_counts[block_counts != n_trials] | |
| if not invalid_multi_blocks.empty or not invalid_block_sizes.empty: | |
| raise ValueError( | |
| "Data must be ordered so that each participant's trials appear in " | |
| "a single contiguous block of rows of length n_trials. " | |
| "Participants with non-contiguous or incorrectly sized blocks " | |
| f"were found. Consider sorting your data by '{participant_col}' " | |
| "and, if available, by a trial index column before building the " | |
| "RL likelihood." | |
| ) | |
| return int(len(counts)), n_trials |
There was a problem hiding this comment.
@cpaniaguam this one may have merit actually. Can we unpack this on tues (Apr. 21st)?
| "Please provide the correct participant column name via " | ||
| "`participant_col`." | ||
| ) | ||
|
|
There was a problem hiding this comment.
groupby(participant_col) drops NaN participant IDs by default, which can make n_participants/n_trials incorrect without an explicit error. Consider adding a check like data[participant_col].isna().any() and raising a clear ValueError if participant IDs are missing.
| # Ensure there are no missing participant IDs, since groupby will drop NaNs | |
| # silently, which would make n_participants / n_trials incorrect. | |
| if data[participant_col].isna().any(): | |
| raise ValueError( | |
| f"Column '{participant_col}' contains missing values. " | |
| "Please fill or remove rows with missing participant IDs before " | |
| "calling validate_balanced_panel." | |
| ) |
There was a problem hiding this comment.
Are we checking for NaN participant id's in the data validator?
If yes this is fine here no?
…ble and has required attributes
… callable and properly annotated
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ) | ||
|
|
||
| # Rearrange data so missing rows come first (no-op when missing_data=False). | ||
| self.data = _rearrange_data(self.data) | ||
|
|
There was a problem hiding this comment.
_rearrange_data(self.data) changes row order, but the RL logp Op reshapes trials purely by row order into (n_participants, n_trials, ...). If any rows are moved (e.g., when missing_data=True and rt == -999), this will break per-participant trial sequences and invalidate the RL learning dynamics. Since missing-data networks are not supported for RLSSM, consider raising an explicit error when missing_data/deadline handling is requested (or implement a participant-wise rearrangement that preserves within-subject order).
| counts = data.groupby(participant_col).size() | ||
| if counts.nunique() != 1: | ||
| raise ValueError( | ||
| "Data must form balanced panels: all participants must have the " | ||
| f"same number of trials. Observed trial counts: {dict(counts)}" | ||
| ) | ||
|
|
||
| return int(len(counts)), int(counts.iloc[0]) |
There was a problem hiding this comment.
validate_balanced_panel() only checks equal trial counts via groupby().size(), but it does not validate that rows are ordered/grouped by participant. The RL likelihood builder (make_rl_logp_func) reshapes arrays with .reshape(n_participants, n_trials, -1) based purely on row order, so interleaved participant rows will silently mix subjects/trials and produce an incorrect likelihood. Consider either (a) enforcing contiguous blocks per participant (and optionally stable-sorting by participant_col + a trial index column if available) or (b) returning a sorted copy of the data and using that downstream.
…preserve trial sequence integrity
… RLSSM initialization
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # All RLSSM parameters are treated as trialwise: the Op expects arrays of | ||
| # length n_total_trials for every parameter, and make_distribution.logp | ||
| # broadcasts scalar / (1,)-shaped tensors up to (n_obs,) accordingly. | ||
| params_is_trialwise = [ | ||
| True for param_name in self.params if param_name != "p_outlier" | ||
| ] | ||
|
|
||
| extra_fields_data = ( | ||
| None | ||
| if not self.extra_fields | ||
| else [deepcopy(self.data[field].values) for field in self.extra_fields] | ||
| ) | ||
|
|
||
| assert self.list_params is not None, "list_params should be set" | ||
| # self.loglik was set to the pytensor Op built in __init__; cast to | ||
| # narrow the inherited union type so make_distribution's type-checker | ||
| # accepts it without a runtime penalty. | ||
| loglik_op = cast("Callable[..., Any] | Op", self.loglik) | ||
| return make_distribution( | ||
| rv=self.model_name, | ||
| loglik=loglik_op, | ||
| list_params=self.list_params, | ||
| bounds=self.bounds, | ||
| lapse=self.lapse, | ||
| extra_fields=extra_fields_data, | ||
| params_is_trialwise=params_is_trialwise, | ||
| ) |
There was a problem hiding this comment.
params_is_trialwise is derived from self.params (excluding p_outlier), but it is passed alongside list_params=self.list_params. If self.list_params includes p_outlier (common in HSSMBase), this makes params_is_trialwise shorter and potentially misaligned with list_params, which can cause incorrect broadcasting or length-check failures in make_distribution. Build params_is_trialwise from self.list_params in the same order, marking p_outlier as non-trialwise.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…ary assertion for list_params
… for independent copies
…ror for unsupported usage
…asses-basemodelconfig-only-dict-supported' into inject-RLSSMConfig-directly-into-HSSMBase
| data: pd.DataFrame, | ||
| rlssm_config: RLSSMConfig, | ||
| participant_col: str = "participant_id", | ||
| include: list[dict[str, Any] | Any] | None = None, |
There was a problem hiding this comment.
What would you call it instead here?
We would want to make that change globally not just for this class I guess.
Either way, would do that as a separate PR.
| ) | ||
| if deadline is not False: | ||
| raise ValueError( | ||
| "RLSSM does not support `deadline` handling. " |
There was a problem hiding this comment.
@krishnbera do we actually have a solution for this?
| """ | ||
| # Start with defaults | ||
| config = cls.config_class.from_defaults(model, loglik_kind) | ||
| # get_config_class is provided by Config/RLSSMConfig mixin through MRO |
There was a problem hiding this comment.
why does RLSSMConfig show up here in this file?
…ed data constants
…LSSMConfig and related tests
…conditional checks
…emoving conditional checks" This reverts commit 7cf8bca.
…to-HSSMBase Inject rlssm config directly into hssm base
…y-injection-into-model-classes-basemodelconfig-only-dict-supported Handle configs via dependency injection
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
…oglik_kind key in RLSSMConfig; update model instantiation parameter name
AlexanderFengler
left a comment
There was a problem hiding this comment.
few substantive comments, but this should be very very close.
Thanks @cpaniaguam ( @krishnbera for visibility )
| for models that couple a reinforcement learning (RL) learning process with a | ||
| sequential sampling decision model (SSM). | ||
|
|
||
| The key difference from :class:`HSSM` is the likelihood: |
There was a problem hiding this comment.
I just want to flag that we should bring this logic up tomorrow (Apr 21st) in the team meeting once more explicitly.
For now I want to make sure this works and we can add models, but reflecting on it a bit, I wonder if we leave some room for harmonization on the table.
| # Raise early so the user gets a clear message before model construction. | ||
| if missing_data is not False: | ||
| raise ValueError( | ||
| "RLSSM does not support `missing_data` handling. " |
| ) | ||
| if deadline is not False: | ||
| raise ValueError( | ||
| "RLSSM does not support `deadline` handling. " |
| # Build the differentiable pytensor Op from the annotated SSM function. | ||
| # This Op supersedes the loglik/loglik_kind workflow: it is passed as | ||
| # `loglik` to HSSMBase so Config.validate() is satisfied, and | ||
| # _make_model_distribution() uses it directly without any further wrapping. | ||
| # | ||
| # Pass copies of list_params / extra_fields so the closure inside | ||
| # make_rl_logp_func captures its own isolated list objects. HSSMBase will | ||
| # later append "p_outlier" to self.list_params (which is the SAME list | ||
| # object as `list_params` above), and that mutation must NOT be visible to | ||
| # the Op's _validate_args_length check at sampling time. | ||
| loglik_op = make_rl_logp_op( | ||
| ssm_logp_func=rlssm_config.ssm_logp_func, | ||
| n_participants=n_participants, | ||
| n_trials=n_trials, | ||
| data_cols=list(data_cols), | ||
| list_params=list(list_params), | ||
| extra_fields=list(extra_fields), | ||
| ) |
| params_is_trialwise=params_is_trialwise, | ||
| ) | ||
|
|
||
| def _get_prefix(self, name_str: str) -> str: |
There was a problem hiding this comment.
this utility might be more generically useful than placing it in this .py file?
| "Please provide the correct participant column name via " | ||
| "`participant_col`." | ||
| ) | ||
|
|
There was a problem hiding this comment.
Are we checking for NaN participant id's in the data validator?
If yes this is fine here no?
| f"same number of trials. Observed trial counts: {dict(counts)}" | ||
| ) | ||
|
|
||
| return int(len(counts)), int(counts.iloc[0]) |
There was a problem hiding this comment.
@cpaniaguam this one may have merit actually. Can we unpack this on tues (Apr. 21st)?
| hooks: | ||
| - id: ruff | ||
| args: [--fix, --exit-non-zero-on-fix] | ||
| exclude: ^tests/ |
There was a problem hiding this comment.
why actually exclude tests from ruff?
| hooks: | ||
| - id: mypy | ||
| args: [--no-strict-optional, --ignore-missing-imports] | ||
| exclude: ^tests/ |
There was a problem hiding this comment.
same here. just asking, is that actually a pattern people follow? Hadn't seen that before.
There was a problem hiding this comment.
correct to take this one mostly as copy/paste from our original HSSM class?
This pull request introduces reinforcement learning sequential sampling model (RLSSM) support to the HSSM package. It adds a new
RLSSMclass, supporting configuration, likelihood construction, and data validation for RL+SSM models, and refines the configuration workflow to require a fully annotated log-likelihood function. The changes also improve pre-commit configuration and update the package's public API.Major features and changes:
1. RLSSM Model Integration
RLSSMclass insrc/hssm/rl/rlssm.pyto support models that combine reinforcement learning processes with sequential sampling models. This class builds a differentiable pytensor Op from an annotated JAX log-likelihood function and enforces strict data requirements for balanced panels.validate_balanced_panelinsrc/hssm/rl/utils.pyto ensure input data forms a balanced panel, which is required for RLSSM models.2. Configuration Enhancements
RLSSMConfiginsrc/hssm/config.pyto require anssm_logp_func(an annotated JAX SSM log-likelihood function), replacing the previousloglik/loglik_kindworkflow. Added runtime validation to ensure this function is callable and properly annotated. [1] [2] [3]from_rlssm_dictto accept a config dictionary and extractssm_logp_funcandmodel_namedirectly from it, simplifying model instantiation.3. Public API and Package Structure
RLSSMandRLSSMConfigin the package's public API viasrc/hssm/__init__.pyand created a newsrc/hssm/rl/__init__.pyfor RL-related exports. [1] [2] [3]4. Developer Experience
.pre-commit-config.yamlto exclude thetests/directory fromruffandmypychecks, streamlining development workflows.