|
25 | 25 | ) |
26 | 26 | from typing_extensions import Self |
27 | 27 |
|
| 28 | +from ..v1.distributions import * |
28 | 29 | from ..v1.lint import is_valid_identifier |
29 | 30 | from ..v1.math import petab_math_str, sympify_petab |
30 | 31 | from . import C, get_observable_df |
@@ -150,6 +151,26 @@ class PriorDistribution(str, Enum): |
150 | 151 | f"{set(C.PRIOR_DISTRIBUTIONS)} vs { {e.value for e in PriorDistribution} }" |
151 | 152 | ) |
152 | 153 |
|
| 154 | +_prior_to_cls = { |
| 155 | + PriorDistribution.CAUCHY: Cauchy, |
| 156 | + PriorDistribution.CHI_SQUARED: ChiSquare, |
| 157 | + PriorDistribution.EXPONENTIAL: Exponential, |
| 158 | + PriorDistribution.GAMMA: Gamma, |
| 159 | + PriorDistribution.LAPLACE: Laplace, |
| 160 | + PriorDistribution.LOG10_NORMAL: Normal, |
| 161 | + PriorDistribution.LOG_LAPLACE: Laplace, |
| 162 | + PriorDistribution.LOG_NORMAL: Normal, |
| 163 | + PriorDistribution.LOG_UNIFORM: Uniform, |
| 164 | + PriorDistribution.NORMAL: Normal, |
| 165 | + PriorDistribution.RAYLEIGH: Rayleigh, |
| 166 | + PriorDistribution.UNIFORM: Uniform, |
| 167 | +} |
| 168 | + |
| 169 | +assert not (_mismatch := set(PriorDistribution) ^ set(_prior_to_cls)), ( |
| 170 | + "PriorDistribution enum does not match _prior_to_cls. " |
| 171 | + f"Mismatches: {_mismatch}" |
| 172 | +) |
| 173 | + |
153 | 174 |
|
154 | 175 | class Observable(BaseModel): |
155 | 176 | """Observable definition.""" |
@@ -929,6 +950,37 @@ def _validate(self) -> Self: |
929 | 950 |
|
930 | 951 | return self |
931 | 952 |
|
| 953 | + @property |
| 954 | + def prior_dist(self) -> Distribution: |
| 955 | + """Get the pior distribution of the parameter.""" |
| 956 | + if self.estimate is False: |
| 957 | + raise ValueError(f"Parameter `{self.id}' is not estimated.") |
| 958 | + |
| 959 | + if self.prior_distribution is None: |
| 960 | + return Uniform(self.lb, self.ub) |
| 961 | + |
| 962 | + if not (cls := _prior_to_cls.get(self.prior_distribution)): |
| 963 | + raise ValueError( |
| 964 | + f"Prior distribution `{self.prior_distribution}' not " |
| 965 | + "supported." |
| 966 | + ) |
| 967 | + |
| 968 | + if str(self.prior_distribution).startswith("log-"): |
| 969 | + log = True |
| 970 | + elif str(self.prior_distribution).startswith("log10-"): |
| 971 | + log = 10 |
| 972 | + else: |
| 973 | + log = False |
| 974 | + |
| 975 | + if cls == Exponential: |
| 976 | + if log is not False: |
| 977 | + raise ValueError( |
| 978 | + "Exponential distribution does not support log " |
| 979 | + "transformation." |
| 980 | + ) |
| 981 | + return cls(*self.prior_parameters, trunc=[self.lb, self.ub]) |
| 982 | + return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub]) |
| 983 | + |
932 | 984 |
|
933 | 985 | class ParameterTable(BaseModel): |
934 | 986 | """PEtab parameter table.""" |
|
0 commit comments