Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
b42e4f1
feat: added softmax likelihood
digicosmos86 Feb 26, 2026
85f928c
Update src/hssm/likelihoods/analytical.py
digicosmos86 Feb 26, 2026
a2c8fbc
Update src/hssm/likelihoods/analytical.py
digicosmos86 Feb 26, 2026
3b5349a
feat: added a general function for creating softmax family of models
digicosmos86 Feb 26, 2026
6fa01e1
feat: added 2 modelconfigs for softmax family of models
digicosmos86 Feb 26, 2026
b685fa3
feat: added softmax family of models with 2 and 3 logits to Supported…
digicosmos86 Feb 26, 2026
3684230
test: testing updated model configs
digicosmos86 Feb 26, 2026
ce64257
fix: failing tests
digicosmos86 Feb 26, 2026
c91a557
fix: incorporated PR feedback
digicosmos86 Mar 5, 2026
2ea683e
fix: tests for `softmax_inv_temperature_config`
digicosmos86 Mar 5, 2026
8194f3b
fix: update softmax configuration to use n_choices instead of n_logits
digicosmos86 Mar 5, 2026
378b02c
Merge pull request #909 from lnccbrown/908-implement-a-general-softma…
digicosmos86 Mar 5, 2026
7905116
Merge pull request #911 from lnccbrown/910-add-config-files-for-softmax
digicosmos86 Mar 5, 2026
b635889
Merge branch '908-implement-a-general-softmax-likelihood' into 906-up…
digicosmos86 Mar 5, 2026
ec55838
Merge branch 'main' into 906-update-config-for-choice-only-models
digicosmos86 Mar 5, 2026
71dc8bf
merge main
AlexanderFengler Mar 6, 2026
ab45781
fix: make `make_distribution` and `make_hssm_rv` compatible with choi…
digicosmos86 Mar 6, 2026
02d7eed
fix: update DataValidatorMixin to be compatible with choice only models
digicosmos86 Mar 6, 2026
c486e35
feat: update HSSM class to be compatible with choice-only models
digicosmos86 Mar 6, 2026
e469cfb
tests: added tests for compatibility with choice-only and missing dat…
digicosmos86 Mar 6, 2026
cefdfb7
fix: a dummy formula in regression_param.py
digicosmos86 Mar 6, 2026
0711a5f
tests: added simple sampling tests for choice-only models
digicosmos86 Mar 6, 2026
943e263
feat: added softmax likelihood
digicosmos86 Feb 26, 2026
738dd33
Update src/hssm/likelihoods/analytical.py
digicosmos86 Feb 26, 2026
1aeb8a8
Update src/hssm/likelihoods/analytical.py
digicosmos86 Feb 26, 2026
f1d5dab
fix: make `make_distribution` and `make_hssm_rv` compatible with choi…
digicosmos86 Mar 6, 2026
ba31e9f
fix: update DataValidatorMixin to be compatible with choice only models
digicosmos86 Mar 6, 2026
e24cd95
feat: update HSSM class to be compatible with choice-only models
digicosmos86 Mar 6, 2026
9ca3a24
tests: added tests for compatibility with choice-only and missing dat…
digicosmos86 Mar 6, 2026
1153b0f
fix: a dummy formula in regression_param.py
digicosmos86 Mar 6, 2026
1302d3a
tests: added simple sampling tests for choice-only models
digicosmos86 Mar 6, 2026
5b4a66c
Merge branch '919-update-hssm-class-for-choice-only' of https://githu…
digicosmos86 Mar 6, 2026
6a9584b
fix: apply stricter missing data check
digicosmos86 Mar 9, 2026
d10aaf1
fix: type in docstrings
digicosmos86 Mar 9, 2026
a2f1db8
fix: revert the docstring for params_only argument
digicosmos86 Mar 9, 2026
86155df
fix: added a separate default lapse distribution for choice-only models
digicosmos86 Mar 9, 2026
60abaea
fix: stop cuda from being installed in CI
digicosmos86 Mar 9, 2026
c632dd4
fix: stop cuda from being installed on CI
digicosmos86 Mar 9, 2026
f2a9055
fix: add no-sync tags to avoid cuda installation
digicosmos86 Mar 9, 2026
dc954b2
fix: update fast test command to include bayesflow and cuda12 options
digicosmos86 Mar 9, 2026
6b6be36
fix: adjust test commands to ensure proper execution order and includ…
digicosmos86 Mar 9, 2026
7695ce8
fix: dependencies
digicosmos86 Mar 9, 2026
8c39e50
fix: remove pre-commit from notebook dependency group
digicosmos86 Mar 9, 2026
a49c2ea
fix: move onnx-runtime to dev
digicosmos86 Mar 9, 2026
6570991
fix: how lapse is specified
digicosmos86 Mar 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/setup-env-notebooks/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ runs:

- name: Install hssm
if: steps.cache.outputs.cache-hit != 'true'
run: uv sync --group test --group notebook
run: uv sync --group notebook
shell: bash
2 changes: 1 addition & 1 deletion .github/setup-env/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ runs:

- name: Install hssm
if: steps.cache.outputs.cache-hit != 'true'
run: uv sync --group test
run: uv sync
shell: bash
2 changes: 1 addition & 1 deletion .github/workflows/build_and_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ jobs:
uv publish --token ${{ secrets.PYPI_TOKEN }}

- name: Build and publish docs
run: uv run mkdocs gh-deploy --force
run: uv run --group notebook --group docs mkdocs gh-deploy --force
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ jobs:
python-version: ${{ matrix.python-version }}

- name: Run mkdocs
run: uv run mkdocs gh-deploy --force
run: uv run --group notebook --group docs mkdocs gh-deploy --force
2 changes: 1 addition & 1 deletion .github/workflows/check_notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ jobs:

echo "Cleaning $notebook"
uv run nb-clean clean -o "$notebook"
if ! uv run jupyter nbconvert --ExecutePreprocessor.timeout=10000 --to notebook --execute "$notebook"; then
if ! uv run --group notebook jupyter nbconvert --ExecutePreprocessor.timeout=10000 --to notebook --execute "$notebook"; then
echo "::error::Failed to execute notebook: $notebook"
EXIT_CODE=1
fi
Expand Down
36 changes: 18 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ dependencies = [
"arviz>=0.22.0",
"bambi>=0.17.2",
"cloudpickle>=3.0.0",
"formulae>=0.6.0",
"hddm-wfpt>=0.1.6",
"huggingface-hub>=0.34.0",
"jaxonnxruntime>=0.3.0",
Expand All @@ -52,39 +51,40 @@ bayesflow = ["bayesflow", "keras>=3"]
[dependency-groups]
dev = [
"coverage>=7.6.4",
"mypy>=1.11.1",
"pytest-cov>=6.0.0",
"pytest-xdist>=3.6.1",
"pytest>=8.3.1",
"pytest-random-order>=1.1.1",
"pytest-rerunfailures>=15.0",
"ruff>=0.15.0",
"pre-commit>=4.1.0",
"onnxruntime>=1.17.1",
]

notebook = [
"graphviz>=0.20.3",
"ipykernel>=6.29.5",
"ipython>=8.31.0",
"ipywidgets>=8.1.2",
"jupyterlab>=4.2.4",
"mistune>=3.0.2",
"mkdocs-material>=9.5.21",
"mkdocs>=1.6.0",
"mkdocs-jupyter>=0.25.1",
"mkdocstrings-python>=1.10.0",
"nbconvert>=7.16.5",
"onnxruntime>=1.17.1",
"nb-clean>=4.0.1",
"nbval>=0.11.0",
"pre-commit>=2.20.0",
"ptpython>=3.0.29",
"pyarrow>=20.0.0",
"lanfactory>=0.5.3",
"HSSM[test]",
"zeus-mcmc>=2.5.4",
]

test = [
"mypy>=1.11.1",
"pytest-cov>=6.0.0",
"pytest-xdist>=3.6.1",
"pytest>=8.3.1",
"ruff>=0.15.0",
"pytest-random-order>=1.1.1",
"pytest-rerunfailures>=15.0",
docs = [
"mkdocs-material>=9.5.21",
"mkdocs>=1.6.0",
"mkdocs-jupyter>=0.25.1",
"mkdocstrings-python>=1.10.0",
]

notebook = ["ipykernel>=6.29.5", "zeus-mcmc>=2.5.4"]

[tool.ruff]
line-length = 88
target-version = "py310"
Expand Down
7 changes: 4 additions & 3 deletions src/hssm/data_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def _pre_check_data_sanity(self):

def _post_check_data_sanity(self):
"""Check if the data is clean enough for the model."""
if self.is_choice_only:
return
if self.deadline or self.missing_data:
if -999.0 not in self.data["rt"].unique():
raise ValueError(
Expand Down Expand Up @@ -106,9 +108,8 @@ def _handle_missing_data_and_deadline(self):
if not self.missing_data and not self.deadline:
# In the case of choice only model, we don't need to do anything with the
# data.
# TODO: commented out for now for tests to pass
# if self.is_choice_only:
# return
if self.is_choice_only:
return
# In the case where missing_data is set to False, we need to drop the
# cases where rt = na_value
if pd.isna(self.missing_data_value):
Expand Down
52 changes: 37 additions & 15 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def make_hssm_rv(
simulator_fun: Callable | str,
list_params: list[str],
lapse: bmb.Prior | None = None,
is_choice_only: bool = False,
) -> type[RandomVariable]:
"""Build a RandomVariable Op according to the list of parameters.

Expand All @@ -202,6 +203,8 @@ def make_hssm_rv(
A list of str of all parameters for this `RandomVariable`.
lapse : optional
A bmb.Prior object representing the lapse distribution.
is_choice_only : bool
Whether the model is a choice-only model.

Returns
-------
Expand All @@ -225,7 +228,12 @@ class HSSMRV(RandomVariable):
# parameter is a scalar. The string to the right of the
# `->` sign describes the output signature, which is `(2)`, which means the
# random variable is a length-2 array.
signature: str = f"{','.join(['()'] * len(list_params))}->({obs_dim_int})"

# Override the output from ssm_simulator based on whether the model is
# choice-only.
output = "()" if is_choice_only else f"({obs_dim_int})"
signature: str = f"{','.join(['()'] * len(list_params))}->{output}"

dtype: str = "floatX"
_print_name: tuple[str, str] = ("SSM", "\\operatorname{SSM}")
_list_params = list_params
Expand Down Expand Up @@ -385,10 +393,11 @@ def make_distribution(
loglik: LogLikeFunc | pytensor.graph.Op,
list_params: list[str],
bounds: dict | None = None,
lapse: bmb.Prior | None = None,
lapse: float | bmb.Prior | None = None,
extra_fields: list[np.ndarray] | None = None,
fixed_vector_params: dict[str, np.ndarray] | None = None,
params_is_trialwise: list[bool] | None = None,
is_choice_only: bool = False,
) -> type[pm.Distribution]:
"""Make a `pymc.Distribution`.

Expand Down Expand Up @@ -418,7 +427,7 @@ def make_distribution(
A dictionary with parameters as keys (a string) and its boundaries as values.
Example: {"parameter": (lower_boundary, upper_boundary)}.
lapse : optional
A bmb.Prior object representing the lapse distribution.
A float or bmb.Prior object representing the lapse distribution.
extra_fields : optional
An optional list of arrays that are stored in the class created and will be
used in likelihood calculation. Defaults to None.
Expand All @@ -437,6 +446,8 @@ def make_distribution(
that vmapped JAX log-likelihoods receive consistently shaped inputs,
regardless of whether Bambi produces ``(1,)`` or ``(n_obs,)`` tensors.
When ``None``, no graph-level broadcasting is applied.
is_choice_only : optional
Whether the model is a choice-only model.

Returns
-------
Expand Down Expand Up @@ -481,15 +492,18 @@ def make_distribution(
if list_params[-1] != "p_outlier":
list_params.append("p_outlier")

data_vector = pt.dvector()
lapse_logp = pm.logp(
get_distribution_from_prior(lapse).dist(**lapse.args),
data_vector,
)
lapse_func = pytensor.function(
[data_vector],
lapse_logp,
)
if isinstance(lapse, float):
lapse_func = lambda data: np.full_like(data, lapse)
else:
data_vector = pt.dvector()
lapse_logp = pm.logp(
get_distribution_from_prior(lapse).dist(**lapse.args),
data_vector,
)
lapse_func = pytensor.function(
[data_vector],
lapse_logp,
)
else:
lapse_func = None

Expand Down Expand Up @@ -561,12 +575,15 @@ def logp(data, *dist_params): # pylint: disable=E0213
"lapse_func is not defined. "
"Make sure lapse is properly initialized."
)
lapse_logp = lapse_func(data[:, 0].eval())
data_for_lapse = data if is_choice_only else data[:, 0]
lapse_logp = lapse_func(data_for_lapse.eval())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be careful here.

The lapse function logic we have is based around rt lapses, so somethine like Uniform(0, 20), which then evaluates to 1/20 for each rt, makes sense.

If a model is choice_only we want a different default lapse distribution, I would suggest simply 1/n_choices.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexanderFengler I have made the updates to fix this. Can you take another look?

BTW: this is an important implementation detail. In the future, can this type of details be communicated to us early? It's much easier to plan for these details early

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@digicosmos86 the truth is that I didn't think about lapse distributions in this context before that's why I didn't mention it earlier.

When I saw it I realized this needs to be figured out correctly.


# AF-TODO potentially apply clipping here
logp = loglik(data, *dist_params, *extra_fields)
# Ensure that non-decision time is always smaller than rt.
# Assuming that the non-decision time parameter is always named "t".
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
if not is_choice_only:
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
logp = pt.log(
(1.0 - p_outlier) * pt.exp(logp)
+ p_outlier * pt.exp(lapse_logp)
Expand All @@ -575,7 +592,8 @@ def logp(data, *dist_params): # pylint: disable=E0213
else:
logp = loglik(data, *dist_params, *extra_fields)
# Ensure that non-decision time is always smaller than rt.
logp = ensure_positive_ndt(data, logp, list_params, dist_params)
if not is_choice_only:
logp = ensure_positive_ndt(data, logp, list_params, dist_params)

if bounds is not None:
logp = apply_param_bounds_to_loglik(
Expand All @@ -593,6 +611,7 @@ def make_distribution_for_supported_model(
backend: Literal["pytensor", "jax", "other"] = "pytensor",
reg_params: list[str] | None = None,
lapse: bmb.Prior | None = None,
is_choice_only: bool = False,
) -> type[pm.Distribution]:
"""Make a pm.Distribution class for a supported model.

Expand All @@ -614,6 +633,8 @@ class that can be used for PyMC modeling.
parameters are assumed.
lapse : optional
A bmb.Prior object representing the lapse distribution.
is_choice_only : optional
Whether the model is a choice-only model.
"""
supported_models = get_args(SupportedModels)
if model not in supported_models:
Expand Down Expand Up @@ -643,6 +664,7 @@ class that can be used for PyMC modeling.
list_params=config.list_params,
bounds=config.bounds,
lapse=lapse,
is_choice_only=is_choice_only,
)


Expand Down
Loading