diff --git a/src/chapkit/api/service_builder.py b/src/chapkit/api/service_builder.py index c3a76ed..0f9a59c 100644 --- a/src/chapkit/api/service_builder.py +++ b/src/chapkit/api/service_builder.py @@ -54,6 +54,8 @@ class MLServiceInfo(ServiceInfo): organization: str | None = None organization_logo_url: HttpUrl | None = None citation_info: str | None = None + allow_free_additional_continuous_covariates: bool = False + required_covariates: List[str] = field(default_factory=list) @dataclass(slots=True) diff --git a/src/chapkit/modules/ml/runner.py b/src/chapkit/modules/ml/runner.py index ad10e3d..17439e3 100644 --- a/src/chapkit/modules/ml/runner.py +++ b/src/chapkit/modules/ml/runner.py @@ -122,6 +122,11 @@ async def on_train( config_file = temp_dir / "config.json" config_file.write_text(json.dumps(config.model_dump(), indent=2)) + # Write config to YAML + config_file_yaml = temp_dir / "config.yaml" + config_file_yaml.write_text(config.model_dump_yaml()) + + # Write training data to CSV data_file = temp_dir / "data.csv" data.to_csv(data_file, index=False) @@ -137,7 +142,7 @@ async def on_train( # Substitute variables in command command = self.train_command.format( - config_file=str(config_file), + config_file=str(config_file_yaml), data_file=str(data_file), model_file=str(model_file), geo_file=str(geo_file) if geo_file else "", @@ -194,6 +199,9 @@ async def on_predict( config_file = temp_dir / "config.json" config_file.write_text(json.dumps(config.model_dump(), indent=2)) + config_file_yaml = temp_dir / "config.yaml" + config_file_yaml.write_text(config.model_dump_yaml()) + # Write model to file model_file = temp_dir / f"model.{self.model_format}" with open(model_file, "wb") as f: @@ -220,7 +228,7 @@ async def on_predict( # Substitute variables in command command = self.predict_command.format( - config_file=str(config_file), + config_file=str(config_file_yaml), model_file=str(model_file), historic_file=str(historic_file) if historic_file else "", future_file=str(future_file),