Skip to content

Commit 03db394

Browse files
committed
linter allow non-yaml network formats
1 parent 9a5eb5b commit 03db394

3 files changed

Lines changed: 30 additions & 14 deletions

File tree

petab/v2/extensions/sciml.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,22 @@ def from_config(
272272
sciml_config: SciMLConfig = config.extensions[C.EXT_ID_SCIML]
273273

274274
# Neural network classes are constructed via pytorch for now to get
275-
# the proper inputs
276-
neural_networks = [
277-
NNModel.from_pytorch_module(
278-
NNModelStandard.load_data(
279-
_generate_path(
280-
file_path=nn_config.location,
281-
base_path=base_path,
275+
# the proper inputs. Non-YAML formats are opaque — the file is assumed
276+
# to contain a valid model and is not read here.
277+
neural_networks = []
278+
for nn_id, nn_config in (sciml_config.neural_networks or {}).items():
279+
if nn_config.format.lower() == "yaml":
280+
neural_networks.append(
281+
NNModel.from_pytorch_module(
282+
NNModelStandard.load_data(
283+
_generate_path(
284+
file_path=nn_config.location,
285+
base_path=base_path,
286+
)
287+
).to_pytorch_module(),
288+
nn_model_id=nn_id,
282289
)
283-
).to_pytorch_module(),
284-
nn_model_id=nn_id,
285-
)
286-
for nn_id, nn_config in (
287-
sciml_config.neural_networks or {}
288-
).items()
289-
]
290+
)
290291

291292
hybridization_tables = [
292293
HybridizationTable.from_tsv(f, base_path)

petab/v2/extensions/sciml_lint.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ def run(self, problem: core.Problem) -> lint.ValidationIssue | None:
1616
condition_targets = {
1717
c.target_id for ct in problem.conditions for c in ct.changes
1818
}
19+
# Only YAML-format networks are loaded as NNModel objects
1920
nn_input_ids = {
2021
inp.input_id
2122
for nn in problem.extensions.sciml.neural_networks
23+
if hasattr(nn, "inputs")
2224
for inp in nn.inputs
2325
}
2426
hyb_target_ids = {

tests/v2/test_sciml.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,16 @@ def _get_test_problem():
138138
def test_lint():
139139
problem = _get_test_problem()
140140
assert problem.validate() == []
141+
142+
143+
def test_lint_equinox_network_format():
144+
"""Linter accepts non-YAML formats without reading the network file."""
145+
problem = _get_test_problem()
146+
# Replace the YAML network config with equinox format
147+
sciml_cfg = problem.config.extensions["sciml"]
148+
sciml_cfg.neural_networks["net1"] = NeuralNetConfig(
149+
location="net1.py",
150+
pre_initialization=False,
151+
format="equinox",
152+
)
153+
assert problem.validate() == []

0 commit comments

Comments
 (0)