Update HSSM class for choice-only models#920
Conversation
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…x-likelihood Added softmax likelihood for choice-only models
Added config files for softmax
…date-config-for-choice-only-models
…a/deadline models
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…a/deadline models
…b.com/lnccbrown/HSSM into 919-update-hssm-class-for-choice-only
src/hssm/distribution_utils/dist.py
Outdated
| lapse : optional | ||
| A bmb.Prior object representing the lapse distribution. | ||
| is_choice_only : bool | ||
| Whether the model is a choice-only model. This parameter overrides |
There was a problem hiding this comment.
The last sentence here seems incomplete. What does the parameter override?
There was a problem hiding this comment.
Pull request overview
This PR updates the HSSM class and supporting infrastructure to support choice-only models (models where only response/choice data is collected, without reaction times). The softmax_inv_temperature model family is introduced as the primary use case.
Changes:
- New
softmax_inv_temperaturelikelihood function added toanalytical.py, with config files for 2-choice and 3-choice variants, and a shared factory_softmax_inv_temperature_config.py HSSM.__init__anddist.pyupdated withis_choice_onlyguards to skip RT-specific processing (NDT checks, missing data handling, lapse data slicing) for choice-only modelsDataValidatorMixin._post_check_data_sanityand_handle_missing_data_and_deadlineupdated to skip RT-based checks for choice-only models;_get_design_matricesinregression_param.pychanged from hardcoded"rt ~ "to"response ~ "for compatibility
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
src/hssm/hssm.py |
Added is_choice_only flag, choice-only–aware missing data guard, dummy RV creation, and updated response_c property |
src/hssm/distribution_utils/dist.py |
Added is_choice_only parameter to make_hssm_rv, make_distribution, and make_distribution_for_supported_model; conditionalised lapse data indexing and NDT checks |
src/hssm/data_validator.py |
Early returns in _post_check_data_sanity and _handle_missing_data_and_deadline for choice-only models |
src/hssm/param/regression_param.py |
Changed formula LHS from "rt" to "response" for design matrix construction |
src/hssm/likelihoods/analytical.py |
New softmax_inv_temperature log-likelihood function; Type[...] → type[...] modernization |
src/hssm/modelconfig/_softmax_inv_temperature_config.py |
New shared factory for the softmax inverse temperature model config |
src/hssm/modelconfig/softmax_inv_temperature_2_config.py |
Config for 2-choice softmax inv. temperature model |
src/hssm/modelconfig/softmax_inv_temperature_3_config.py |
Config for 3-choice softmax inv. temperature model |
src/hssm/_types.py |
Added new models to SupportedModels literal |
tests/test_modelconfig.py |
Tests for softmax_inv_temperature_config factory |
tests/test_likelihoods_choice_only.py |
Shape tests for softmax_inv_temperature |
tests/test_hssm.py |
Test for is_choice_only and deadline interaction |
tests/slow/test_choice_only.py |
Integration tests for choice-only models with various sampler/backend combinations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
AlexanderFengler
left a comment
There was a problem hiding this comment.
Key thing for me to still address is the lapse part, apart from that, I think it's fine to relegate further cleanup/improvements to downstream PRs that deal with the class refactor.
Thanks @digicosmos86
| ) | ||
| lapse_logp = lapse_func(data[:, 0].eval()) | ||
| data_for_lapse = data if is_choice_only else data[:, 0] | ||
| lapse_logp = lapse_func(data_for_lapse.eval()) |
There was a problem hiding this comment.
We should be careful here.
The lapse function logic we have is based around rt lapses, so somethine like Uniform(0, 20), which then evaluates to 1/20 for each rt, makes sense.
If a model is choice_only we want a different default lapse distribution, I would suggest simply 1/n_choices.
There was a problem hiding this comment.
@AlexanderFengler I have made the updates to fix this. Can you take another look?
BTW: this is an important implementation detail. In the future, can this type of details be communicated to us early? It's much easier to plan for these details early
There was a problem hiding this comment.
@digicosmos86 the truth is that I didn't think about lapse distributions in this context before that's why I didn't mention it earlier.
When I saw it I realized this needs to be figured out correctly.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…e necessary options
AlexanderFengler
left a comment
There was a problem hiding this comment.
This seems ok now. Eventually we might want to get out of the if-else doom with a refactor that focuses the execution path between choice_only and choice,rt models, but can do that later. Thanks @digicosmos86
This PR updates the HSSM class to be compatible with choice-only models. Changes are:
dist.pywithis_choice_onlytoggle tomake_distributionandmake_hssm_rvto account for choice-only modelsDataValidatorto skip certain checks if the model is choice-onlyregression_param.pyEDIT: I had to update package dependencies to streamline CI workflows based on how
uvworks.pyproject.tomlwas a bit messy:devdependencies (installed withuv add package-name --dev) are automatically installed withuv sync. These should include all dev dependencies required for tests.uv add package-name --group group_name) are not automatically installed viauv syncbut can be requested with--groupflag. These should include specific dependencies required for some workflows. For example,notebookdependency group should only be specifically requested when notebooks need to be executeduv add package-name --extra extra_name) are not installed viauv synceither, but can be requested with--extraflag. These are extra dependencies that can be requested via pip in square brackets (e.g.pip install hssm[cuda12]). These are more user facing and should not be installed on ci unless specific tests require them