diff --git a/src/cli/example.rs b/src/cli/example.rs index 96b8ead29..1132d7153 100644 --- a/src/cli/example.rs +++ b/src/cli/example.rs @@ -100,10 +100,13 @@ fn handle_example_extract_command(name: &str, dest: Option<&Path>, patch: bool) /// If `patch` is `true`, then the corresponding patched example will be extracted. fn extract_example(name: &str, patch: bool, dest: &Path) -> Result<()> { if patch { - let (file_patches, toml_patch) = get_patches(name)?; - - // NB: All patched models are based on `simple`, for now - let example = Example::from_name("simple").unwrap(); + let patch_info = get_patches(name)?; + let (base_example, file_patches, toml_patch) = ( + &patch_info.base_example, + &patch_info.file_patches, + &patch_info.toml_patch, + ); + let example = Example::from_name(base_example)?; // First extract the example to a temp dir let example_tmp = TempDir::new().context("Failed to create temporary directory")?; diff --git a/src/example/patches.rs b/src/example/patches.rs index c25ab3906..393e0fd26 100644 --- a/src/example/patches.rs +++ b/src/example/patches.rs @@ -4,9 +4,44 @@ use crate::patch::FilePatch; use anyhow::{Context, Result}; use std::{collections::BTreeMap, sync::LazyLock}; +/// Holds patch information for a patched example model +pub struct PatchInfo { + /// The base example to patch + pub base_example: &'static str, + /// File patches to apply to the base example + pub file_patches: Vec, + /// An optional TOML patch to apply to the base example + pub toml_patch: Option<&'static str>, +} +impl PatchInfo { + /// Create a new `PatchInfo` with the specified base example, file patches, and optional TOML patch + pub fn new( + base_example: &'static str, + file_patches: Vec, + toml_patch: Option<&'static str>, + ) -> Self { + Self { + base_example, + file_patches, + toml_patch, + } + } + /// Create a new `PatchInfo` with the specified base example, file patches, and TOML patch + pub fn new_with_toml_patch( + base_example: &'static str, + file_patches: Vec, + toml_patch: &'static str, + ) -> Self { + Self { + base_example, + file_patches, + toml_patch: Some(toml_patch), + } + } +} /// Map of patches keyed by name, with the file patches and an optional TOML patch -type PatchMap = BTreeMap<&'static str, (Vec, Option<&'static str>)>; +type PatchMap = BTreeMap<&'static str, PatchInfo>; /// The patches, keyed by name static PATCHES: LazyLock = LazyLock::new(get_all_patches); @@ -17,7 +52,8 @@ fn get_all_patches() -> PatchMap { // The simple example with gas boiler process made divisible ( "simple_divisible", - ( + PatchInfo::new( + "simple", vec![ FilePatch::new("processes.csv") .with_deletion("RGASBR,Gas boiler,all,RSHEAT,2020,2040,1.0,") @@ -29,7 +65,8 @@ fn get_all_patches() -> PatchMap { // The simple example with objective type set to NPV for one agent ( "simple_npv", - ( + PatchInfo::new( + "simple", vec![ FilePatch::new("agent_objectives.csv") .with_deletion("A0_RES,all,lcox,,") @@ -41,7 +78,8 @@ fn get_all_patches() -> PatchMap { ( // The simple example with electricity priced using marginal costs "simple_marginal", - ( + PatchInfo::new( + "simple", vec![FilePatch::new("commodities.csv").with_replacement(&[ "id,description,type,time_slice_level,pricing_strategy,units", "GASPRD,Gas produced,sed,season,shadow,PJ", @@ -56,7 +94,8 @@ fn get_all_patches() -> PatchMap { ( // The simple example with gas commodities priced using full costs "simple_full", - ( + PatchInfo::new( + "simple", vec![FilePatch::new("commodities.csv").with_replacement(&[ "id,description,type,time_slice_level,pricing_strategy,units", "GASPRD,Gas produced,sed,season,full,PJ", @@ -71,7 +110,8 @@ fn get_all_patches() -> PatchMap { ( // The simple example with electricity priced using average marginal costs "simple_marginal_average", - ( + PatchInfo::new( + "simple", vec![FilePatch::new("commodities.csv").with_replacement(&[ "id,description,type,time_slice_level,pricing_strategy,units", "GASPRD,Gas produced,sed,season,shadow,PJ", @@ -86,7 +126,8 @@ fn get_all_patches() -> PatchMap { ( // The simple example with gas commodities priced using average full costs "simple_full_average", - ( + PatchInfo::new( + "simple", vec![FilePatch::new("commodities.csv").with_replacement(&[ "id,description,type,time_slice_level,pricing_strategy,units", "GASPRD,Gas produced,sed,season,full_average,PJ", @@ -101,7 +142,7 @@ fn get_all_patches() -> PatchMap { // The simple example with the ironing-out loop turned on ( "simple_ironing_out", - (vec![], Some("max_ironing_out_iterations = 10")), + PatchInfo::new("simple", vec![], Some("max_ironing_out_iterations = 10")), ), ] .into_iter() @@ -114,7 +155,7 @@ pub fn get_patch_names() -> impl Iterator { } /// Get patches for the named patched example -pub fn get_patches(name: &str) -> Result<&'static (Vec, Option<&'static str>)> { +pub fn get_patches(name: &str) -> Result<&'static PatchInfo> { PATCHES .get(name) .with_context(|| format!("Patched example '{name}' not found")) diff --git a/src/fixture.rs b/src/fixture.rs index cabbf410a..0cefa572b 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -47,7 +47,14 @@ pub(crate) use assert_error; /// /// If the patched model cannot be built, for whatever reason, this function will panic. pub(crate) fn build_patched_simple_tempdir(file_patches: Vec) -> tempfile::TempDir { - ModelPatch::from_example("simple") + build_patched_tempdir("simple", file_patches) +} + +pub(crate) fn build_patched_tempdir( + base_example: &str, + file_patches: Vec, +) -> tempfile::TempDir { + ModelPatch::from_example(base_example) .with_file_patches(file_patches) .build_to_tempdir() .unwrap()