From 70da5df33124ccc2168deb40630a681ece71aeba Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 12 Sep 2025 16:20:48 -0400 Subject: [PATCH 01/40] Rewrite __main__/cli so argparse uses sub-parsers To add a `configfile` command, we can no longer assume that we have just two arguments, `command` and `configfile`. Instead we need to move to sub-parsers, even though must of the sub-parsers will just be the same named commands as before, and will only have a single positional argument of `configfile`. As of now, only the new `configfile` command will be different, although I haven't figured out what its args will be yet. To do this I ended up representing cli commands as a dataclass inside the `vak.cli.cli` module, with attributes `name`, `help`, `func`, and `add_parser_args_func`. I am choosing to send each sub-parser to `dest='command'`, and then dispatching based on the value of `command` inside the `vak.cli.cli.cli` function, instead of letting `argparse` do all the work, because I want to be able to raise a somewhat friendly error message. Maybe this is not needed since argparse validates arguments anyways. But this requires the least rewiring of what we already had in place. --- src/vak/__main__.py | 31 +--------- src/vak/cli/cli.py | 141 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 119 insertions(+), 53 deletions(-) diff --git a/src/vak/__main__.py b/src/vak/__main__.py index a25d3f833..35a4bfd69 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -2,38 +2,9 @@ Invokes __main__ when the module is run as a script. 
Example: python -m vak --help """ - -import argparse -from pathlib import Path - from .cli import cli -def get_parser(): - """returns ArgumentParser instance used by main()""" - parser = argparse.ArgumentParser( - prog="vak", - description="vak command-line interface", - formatter_class=argparse.RawTextHelpFormatter, - ) - parser.add_argument( - "command", - type=str, - metavar="command", - choices=cli.CLI_COMMANDS, - help="Command to run, valid options are:\n" - f"{cli.CLI_COMMANDS}\n" - "$ vak train ./configs/config_2018-12-17.toml", - ) - parser.add_argument( - "configfile", - type=Path, - help="name of config.toml file to use \n" - "$ vak train ./configs/config_2018-12-17.toml", - ) - return parser - - def main(args=None): """Main function called when run as script or through command-line interface @@ -44,7 +15,7 @@ def main(args=None): ``args`` is used for unit testing only """ if args is None: - parser = get_parser() + parser = cli.get_parser() args = parser.parse_args() cli.cli(command=args.command, config_file=args.configfile) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index d6d2eaca3..efbd5dcb6 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -1,46 +1,141 @@ -def eval(toml_path): +"""Implements the vak command-line interface""" +import argparse +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + + +def eval(args): from .eval import eval - eval(toml_path=toml_path) + eval(toml_path=args.config_file) -def train(toml_path): +def train(args): from .train import train - train(toml_path=toml_path) + train(toml_path=args.config_file) -def learncurve(toml_path): +def learncurve(args): from .learncurve import learning_curve - learning_curve(toml_path=toml_path) + learning_curve(toml_path=args.config_file) -def predict(toml_path): +def predict(args): from .predict import predict - predict(toml_path=toml_path) + predict(toml_path=args.config_file) -def prep(toml_path): +def prep(args): from .prep import 
prep - prep(toml_path=toml_path) + prep(toml_path=args.config_file) -COMMAND_FUNCTION_MAP = { - "prep": prep, - "train": train, - "eval": eval, - "predict": predict, - "learncurve": learncurve, +@dataclass +class CLICommand: + """Dataclass representing a cli command + + Attributes + ---------- + name : str + Name of the command, that gets added to the CLI as a sub-parser + help : str + Help for the command, that gets added to the CLI as a sub-parser + func : Callable + Function to call for command + add_parser_args_func: Callable + Function to call to add arguments to sub-parser representing command + """ + name: str + help: str + func: Callable + add_parser_args_func : Callable + + +def add_configfile_arg( + cli_command, + cli_command_parser +): + cli_command_parser.add_argument( + "configfile", + type=Path, + help="name of TOML configuration file to use \n" + f"$ vak {cli_command.name} ./configs/config_rat01337.toml", + ) + + +CLI_COMMANDS = [ + CLICommand( + name='prep', + help='prepare a dataset', + func=prep, + add_parser_args_func=add_configfile_arg, + ), + CLICommand( + name='train', + help='train a model', + func=train, + add_parser_args_func=add_configfile_arg, + ), + CLICommand( + name='eval', + help='evaluate a trained model', + func=eval, + add_parser_args_func=add_configfile_arg, + ), + CLICommand( + name='predict', + help='generate predictions from trained model', + func=predict, + add_parser_args_func=add_configfile_arg, + ), + CLICommand( + name='learncurve', + help='run a learning curve', + func=learncurve, + add_parser_args_func=add_configfile_arg, + ), +] + + +def get_parser(): + """returns ArgumentParser instance used by main()""" + parser = argparse.ArgumentParser( + prog="vak", + description="vak command-line interface", + formatter_class=argparse.RawTextHelpFormatter, + ) + + # create sub-parser + sub_parsers = parser.add_subparsers( + help='Commands for vak command-line interface', + dest="command", + ) + + for cli_command in CLI_COMMANDS: + 
cli_command_parser = sub_parsers.add_parser( + cli_command.name, + help=cli_command.help + ) + cli_command.add_parser_args_func( + cli_command, + cli_command_parser + ) + + return parser + + +CLI_COMMAND_FUNCTION_MAP = { + cli_command.name: cli_command.func + for cli_command in CLI_COMMANDS } -CLI_COMMANDS = tuple(COMMAND_FUNCTION_MAP.keys()) - - -def cli(command, config_file): +def cli(args): """Execute the commands of the command-line interface. Parameters @@ -50,7 +145,7 @@ def cli(command, config_file): config_file : str, Path path to a config.toml file """ - if command in COMMAND_FUNCTION_MAP: - COMMAND_FUNCTION_MAP[command](toml_path=config_file) + if args.command in CLI_COMMAND_FUNCTION_MAP: + CLI_COMMAND_FUNCTION_MAP[args.command](args) else: - raise ValueError(f"command not recognized: {command}") + raise ValueError(f"command not recognized: {args.command}") From 71c1a2ca1fe99c63a798094b22895e9aa27a864c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 12 Sep 2025 23:12:18 -0400 Subject: [PATCH 02/40] Add 'template' TOML configs in src/vak/config --- src/vak/config/config_learncurve.toml | 106 +++++++++++++++++++++++++ src/vak/config/configfile_eval.toml | 86 ++++++++++++++++++++ src/vak/config/configfile_predict.toml | 80 +++++++++++++++++++ src/vak/config/configfile_train.toml | 101 +++++++++++++++++++++++ 4 files changed, 373 insertions(+) create mode 100644 src/vak/config/config_learncurve.toml create mode 100644 src/vak/config/configfile_eval.toml create mode 100644 src/vak/config/configfile_predict.toml create mode 100644 src/vak/config/configfile_train.toml diff --git a/src/vak/config/config_learncurve.toml b/src/vak/config/config_learncurve.toml new file mode 100644 index 000000000..825a8d54a --- /dev/null +++ b/src/vak/config/config_learncurve.toml @@ -0,0 +1,106 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = 
"frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/Users/davidnicholson/Documents/repos/vocalpy/vak/tests/scripts/vaktestdata/../../data_for_tests/generated/spect-output-dir/audio_cbin_annot_notmat/gy6or6/032312" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "./tests/data_for_tests/generated/prep/learncurve/audio_cbin_annot_notmat/TweetyNet" +# audio_format: format of audio, either wav or cbin +spect_format = "npz" +# annot_format: format of annotations +annot_format = "notmat" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 50 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 15 +# test_dur: duration of test split in dataset, in seconds +test_dur = 30 +train_set_durs = [ 4, 6,] +num_replicates = 2 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 +# qualitatively, we find that log transforming the spectrograms improves performance; +# think of this as increasing the contrast between high power and low power regions +transform_type = "log_spect" +# specifying cutoff frequencies of the spectrogram can (1) make the model more +# computationally efficient and (2) improve performance by only fitting the model +# to parts of the spectrum that are relevant for sounds of interest. 
+# Note these cutoffs are applied by computing the whole spectrogram first +# and then throwing away frequencies above and below the cutoffs; +# we do not apply a bandpass filter to the audio. +freq_cutoffs = [ 500, 10000,] +# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN +# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the +# hidden size of the RNN. If you observe impaired performance of TweetyNet after applying the frequency cutoffs, +# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). + +# learncurve: options for running the learning curve +# that estimates model performance +# as a function of training set size +[vak.learncurve] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 11 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 50 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 200 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 4 
+# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 + +[vak.learncurve.post_tfm_kwargs] +majority_vote = true +min_segment_dur = 0.02 + +[vak.learncurve.dataset] +# params : parameters that configure the `vak.datapipes` or `vak.datasets` class +# for a frame classification model, we use dataset classes with a specific `window_size` +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +params = { window_size = 88 } +# path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. 
`TweetyNet.optimizer` +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.learncurve.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.learncurve.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/configfile_eval.toml b/src/vak/config/configfile_eval.toml new file mode 100644 index 000000000..0f6bad0b4 --- /dev/null +++ b/src/vak/config/configfile_eval.toml @@ -0,0 +1,86 @@ +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gy6or6/032212" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/train" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# annot_format: format of annotations +annot_format = "simple-seq" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 50 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 15 + +# SPECT_PARAMS: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 
64 + +# EVAL: options for evaluating a trained model. This is done using the "test" split. +[vak.eval] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 11 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/eval" +# dataset_path : path to dataset created by prep +# ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) + +# EVAL.post_tfm_kwargs: options for post-processing +[vak.eval.post_tfm_kwargs] +# both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). 
+majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. +min_segment_dur = 0.02 + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# We put this table though vak knows which model we are using +[vak.eval.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.eval.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/configfile_predict.toml b/src/vak/config/configfile_predict.toml new file mode 100644 index 000000000..b82cf048c --- /dev/null +++ b/src/vak/config/configfile_predict.toml @@ -0,0 +1,80 @@ +# PREP: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gy6or6/032312" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/predict" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# note that for predictions we don't need to 
specify labelset or annot_format +# note also that we do not specify train_dur / val_dur / test_dur; +# all data found in `data_dir` will be assigned to a "predict split" instead + +# SPECT_PARAMS: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 + +# PREDICT: options for generating predictions with a trained model +[vak.predict] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 1 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/predict" +# annot_csv_filename +annot_csv_filename = "gy6or6.032312.annot.csv" +# The next two options are for post-processing transforms. +# Both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. 
+# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. +min_segment_dur = 0.01 +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# We put this table though vak knows which model we are using +[vak.predict.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.predict.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/configfile_train.toml b/src/vak/config/configfile_train.toml new file mode 100644 index 000000000..1aa0dc45f --- /dev/null +++ b/src/vak/config/configfile_train.toml @@ -0,0 +1,101 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") 
+input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gyor6/032212" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/train" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# annot_format: format of annotations +annot_format = "simple-seq" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 2000 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 170 +# test_dur: duration of test split in dataset, in seconds +test_dur = 350 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 64 +# qualitatively, we find that log transforming the spectrograms improves performance; +# think of this as increasing the contrast between high power and low power regions +transform_type = "log_spect" +# specifying cutoff frequencies of the spectrogram can (1) make the model more +# computationally efficient and (2) improve performance by only fitting the model +# to parts of the spectrum that are relevant for sounds of interest. +# Note these cutoffs are applied by computing the whole spectrogram first +# and then throwing away frequencies above and below the cutoffs; +# we do not apply a bandpass filter to the audio. +freq_cutoffs = [500, 8000] +# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN +# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the +# hidden size of the RNN. 
If you observe impaired performance of TweetyNet after applying the frequency cutoffs, +# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). + +# [vak.train]: options for training model +[vak.train] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "/PATH/TO/FOLDER/results/train" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 8 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 1000 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 500 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 6 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] +# Bigger windows work better. 
+# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +window_size = 2000 + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.train.model.TweetyNet] +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.train.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.train.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] From 8114a685ab572dcc28bce4dadb8772bccd372320 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 12 Sep 2025 23:12:36 -0400 Subject: [PATCH 03/40] Fix how we call cli.cli in __main__ --- src/vak/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/__main__.py b/src/vak/__main__.py index 35a4bfd69..2d5440349 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -17,7 +17,7 @@ def main(args=None): if args is None: parser = cli.get_parser() args = parser.parse_args() - cli.cli(command=args.command, config_file=args.configfile) + cli.cli(args) if __name__ == "__main__": From b999985bed32c1a5e81fede8fdecd04a7ed05f6e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 12 Sep 2025 23:13:09 
-0400 Subject: [PATCH 04/40] Add cli command 'configfile' in vak.cli.cli --- src/vak/cli/cli.py | 103 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 91 insertions(+), 12 deletions(-) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index efbd5dcb6..523c0a472 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -1,5 +1,6 @@ """Implements the vak command-line interface""" import argparse +import pathlib from dataclasses import dataclass from pathlib import Path from typing import Callable @@ -35,6 +36,26 @@ def prep(args): prep(toml_path=args.config_file) +def configfile(args): + print( + f"Generating TOML configuration file of kind: {args.kind}" + ) + if args.add_prep: + print( + f"Will add prep table" + ) + else: + print( + f"Will not add prep table" + ) + from ..config.generate import generate + generate( + kind=args.kind, + add_prep=args.add_prep, + dst=args.dst, + ) + + @dataclass class CLICommand: """Dataclass representing a cli command @@ -56,16 +77,68 @@ class CLICommand: add_parser_args_func : Callable -def add_configfile_arg( +def add_single_arg_configfile_to_command( cli_command, cli_command_parser ): - cli_command_parser.add_argument( - "configfile", - type=Path, - help="name of TOML configuration file to use \n" - f"$ vak {cli_command.name} ./configs/config_rat01337.toml", - ) + """Most of the CLICommands call this function + to add arguments to their sub-parser. + It adds a single positional argument, `configfile`. 
+ Not to be confused with the *command* configfile, + that adds different arguments + """ + cli_command_parser.add_argument( + "configfile", + type=Path, + help="name of TOML configuration file to use \n" + f"$ vak {cli_command.name} ./configs/config_rat01337.toml", + ) + + +KINDS_OF_CONFIG_FILES = [ + # FIXME: there's no way to have a stand-alone prep file right now + # we need to add a `purpose` key-value pair to the file format + # to make this possible + # "prep", + "train", + "eval", + "predict", + "learncurve", +] + + +def add_args_to_configfile_command( + cli_command, + cli_command_parser +): + """This is the function that gets called + to add arguments to the sub-parser + for the configfile command + """ + cli_command_parser.add_argument( + "kind", + type=str, + choices=KINDS_OF_CONFIG_FILES, + help="kind: the kind of TOML configuration file to generate" + ) + cli_command_parser.add_argument( + "--add-prep", + action=argparse.BooleanOptionalAction, + default=False, + help="Adding this option will add a 'prep' table to the TOML configuration file. Default is False." + ) + cli_command_parser.add_argument( + "-dst", + type=pathlib.Path, + default=pathlib.Path.cwd(), + help="Destination, where TOML configuration file should be generated. Default is current working directory." 
+ ) + # TODO: add this option + # cli_command_parser.add_argument( + # "--from", + # type=pathlib.Path, + # help="Path to another configuration file that this file should be generated from\n" + # ) CLI_COMMANDS = [ @@ -73,31 +146,37 @@ def add_configfile_arg( name='prep', help='prepare a dataset', func=prep, - add_parser_args_func=add_configfile_arg, + add_parser_args_func=add_single_arg_configfile_to_command, ), CLICommand( name='train', help='train a model', func=train, - add_parser_args_func=add_configfile_arg, + add_parser_args_func=add_single_arg_configfile_to_command, ), CLICommand( name='eval', help='evaluate a trained model', func=eval, - add_parser_args_func=add_configfile_arg, + add_parser_args_func=add_single_arg_configfile_to_command, ), CLICommand( name='predict', help='generate predictions from trained model', func=predict, - add_parser_args_func=add_configfile_arg, + add_parser_args_func=add_single_arg_configfile_to_command, ), CLICommand( name='learncurve', help='run a learning curve', func=learncurve, - add_parser_args_func=add_configfile_arg, + add_parser_args_func=add_single_arg_configfile_to_command, + ), + CLICommand( + name='configfile', + help='generate a TOML configuration file for vak', + func=configfile, + add_parser_args_func=add_args_to_configfile_command, ), ] From ff95d4ac95c457e60f7b3492280d38aab95cbd84 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Fri, 12 Sep 2025 23:13:29 -0400 Subject: [PATCH 05/40] Add src/vak/config/generate.py --- src/vak/config/generate.py | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/vak/config/generate.py diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py new file mode 100644 index 000000000..7848f443f --- /dev/null +++ b/src/vak/config/generate.py @@ -0,0 +1,50 @@ +import importlib.resources +import pathlib +import shutil + +CONFIGFILE_KIND_FILENAME_MAP = { + "train": "configfile_train.toml", + "eval": "configfile_eval.toml", + 
"predict": "configfile_predict.toml", +    "learncurve": "configfile_learncurve.toml", +} + + +def generate( +    kind: str, +    add_prep: bool = False, +    dst: str | pathlib.Path = pathlib.Path.cwd(), +) -> None: +    """Generate a TOML configuration file + +    This is the function called by +    :func:`vak.cli.cli.configfile` +    when a user runs the command ``vak configfile`` +    using the command-line interface. + +    Parameters +    ---------- +    kind : str +        The kind of TOML configuration file to generate. +        One of: ``{'train', 'eval', 'predict', 'learncurve'}`` +    add_prep : bool +        If True, add a ``[vak.prep]`` table to the +        TOML configuration file. +    dst : string, pathlib.Path +        Destination for the generated configuration file. +        Either a full path including filename, +        or a directory, in which case a default filename +        will be used. +        The default `dst` is the current working directory. +    """ +    dst = pathlib.Path(dst) +    if not dst.is_dir() and dst.exists(): +        raise ValueError( +            f"Destination for generated config file `dst` is already a file that exists:\n{dst}\n" +            "Please specify a value for the `--dst` argument that will not overwrite an existing file."
+ ) + filename = CONFIGFILE_KIND_FILENAME_MAP[kind] + src = pathlib.Path( + importlib.resources.files("vak.config").joinpath(filename) + ) + shutil.copy(src, dst) From fa3f333c47eca6657d6ae52d8e9dd655972583f5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 13 Sep 2025 09:57:46 -0400 Subject: [PATCH 06/40] Revise module-level docstring in src/vak/config/load.py --- src/vak/config/load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/load.py b/src/vak/config/load.py index 3134dc85e..6d4597490 100644 --- a/src/vak/config/load.py +++ b/src/vak/config/load.py @@ -1,4 +1,4 @@ -"""Functions to parse toml config files.""" +"""Functions to load TOML configuration files.""" from __future__ import annotations From 427b75d4b622008be51309e7e498c919c4da8644 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 13 Sep 2025 09:58:02 -0400 Subject: [PATCH 07/40] Fix link in a comment in src/vak/cli/prep.py --- src/vak/cli/prep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/cli/prep.py b/src/vak/cli/prep.py index d86c4c0a9..de40ae810 100644 --- a/src/vak/cli/prep.py +++ b/src/vak/cli/prep.py @@ -51,7 +51,7 @@ def purpose_from_toml( # note NO LOGGING -- we configure logger inside `core.prep` # so we can save log file inside dataset directory -# see https://github.com/NickleDave/vak/issues/334 +# see https://github.com/vocalpy/vak/issues/334 TABLES_PREP_SHOULD_PARSE = "prep" From 85aaeac377fece6a532bf7ee2cc47c8796109b7a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 13 Sep 2025 11:29:48 -0400 Subject: [PATCH 08/40] Move/add files -> src/vak/config/_toml_config_templates --- .../configfile_eval.toml | 58 ++++++++++ .../configfile_eval_prep.toml} | 13 +-- .../configfile_learncurve.toml | 59 ++++++++++ .../configfile_learncurve_prep.toml} | 2 +- .../configfile_predict.toml | 57 ++++++++++ .../configfile_predict_prep.toml} | 0 .../configfile_train.toml | 0 .../configfile_train_prep.toml | 101 
++++++++++++++++++ 8 files changed, 283 insertions(+), 7 deletions(-) create mode 100644 src/vak/config/_toml_config_templates/configfile_eval.toml rename src/vak/config/{configfile_eval.toml => _toml_config_templates/configfile_eval_prep.toml} (87%) create mode 100644 src/vak/config/_toml_config_templates/configfile_learncurve.toml rename src/vak/config/{config_learncurve.toml => _toml_config_templates/configfile_learncurve_prep.toml} (98%) create mode 100644 src/vak/config/_toml_config_templates/configfile_predict.toml rename src/vak/config/{configfile_predict.toml => _toml_config_templates/configfile_predict_prep.toml} (100%) rename src/vak/config/{ => _toml_config_templates}/configfile_train.toml (100%) create mode 100644 src/vak/config/_toml_config_templates/configfile_train_prep.toml diff --git a/src/vak/config/_toml_config_templates/configfile_eval.toml b/src/vak/config/_toml_config_templates/configfile_eval.toml new file mode 100644 index 000000000..229a5d53d --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_eval.toml @@ -0,0 +1,58 @@ +# [vak.eval]: options for evaluating a trained model. This is done using the "test" split in a dataset by default. 
+[vak.eval] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 11 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/eval" +# dataset_path : path to dataset created by prep +# ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) + +# [vak.eval.post_tfm_kwargs]: options for post-processing +[vak.eval.post_tfm_kwargs] +# both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. 
+# Only applied if this option is specified. +min_segment_dur = 0.02 + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.eval.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# [vak.eval.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] +[vak.eval.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# [vak.eval.trainer]: this sub-table configures the `lightning.pytorch.Trainer` +[vak.eval.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/configfile_eval.toml b/src/vak/config/_toml_config_templates/configfile_eval_prep.toml similarity index 87% rename from src/vak/config/configfile_eval.toml rename to src/vak/config/_toml_config_templates/configfile_eval_prep.toml index 0f6bad0b4..ca8398281 100644 --- a/src/vak/config/configfile_eval.toml +++ b/src/vak/config/_toml_config_templates/configfile_eval_prep.toml @@ -1,3 +1,4 @@ +# [vak.prep]: options for preparing dataset [vak.prep] # dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" dataset_type = "frame classification" @@ -18,7 +19,7 @@ train_dur = 50 # val_dur: duration of validation split in dataset, in seconds val_dur = 15 -# SPECT_PARAMS: parameters for computing spectrograms +# [vak.prep.spect_params]: parameters for computing spectrograms [vak.prep.spect_params] # fft_size: size of window used for Fast 
Fourier Transform, in number of samples fft_size = 512 @@ -26,7 +27,7 @@ fft_size = 512 # also known as hop size step_size = 64 -# EVAL: options for evaluating a trained model. This is done using the "test" split. +# [vak.eval]: options for evaluating a trained model. This is done using the "test" split in a dataset by default. [vak.eval] # checkpoint_path: path to saved model checkpoint checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" @@ -49,7 +50,7 @@ output_dir = "/PATH/TO/FOLDER/results/eval" # dataset_path : path to dataset created by prep # ADD THE dataset_path OPTION FROM THE TRAIN FILE HERE (we already created a test split when we ran `vak prep` with that config) -# EVAL.post_tfm_kwargs: options for post-processing +# [vak.eval.post_tfm_kwargs]: options for post-processing [vak.eval.post_tfm_kwargs] # both these transforms require that there is an "unlabeled" label, # and they will only be applied to segments that are bordered on both sides @@ -70,15 +71,15 @@ min_segment_dur = 0.02 path = "/copy/path/from/train/config/here" params = { window_size = 176 } -# We put this table though vak knows which model we are using +# [vak.eval.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] [vak.eval.model.TweetyNet.network] # hidden_size: the number of elements in the hidden state in the recurrent layer of the network # we trained with hidden size = 256 so we need to evaluate with the same hidden size; # otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint hidden_size = 256 - -# this sub-table configures the `lightning.pytorch.Trainer` +# [vak.eval.trainer]: this sub-table configures the `lightning.pytorch.Trainer` [vak.eval.trainer] # setting to 'gpu' means "train models on 'gpu' (not 'cpu')" accelerator = "gpu" diff --git 
a/src/vak/config/_toml_config_templates/configfile_learncurve.toml b/src/vak/config/_toml_config_templates/configfile_learncurve.toml new file mode 100644 index 000000000..a3d5294f4 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_learncurve.toml @@ -0,0 +1,59 @@ +# [vak.learncurve]: options for running the learning curve +# that estimates model performance +# as a function of training set size +[vak.learncurve] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "./tests/data_for_tests/generated/results/learncurve/audio_cbin_annot_notmat/TweetyNet" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 11 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 50 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 200 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 4 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 16 + +[vak.learncurve.post_tfm_kwargs] +majority_vote = true +min_segment_dur = 0.02 + +[vak.learncurve.dataset] +# params : parameters that configure the `vak.datapipes` or `vak.datasets` class +# for a frame classification model, we 
use dataset classes with a specific `window_size` +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. +# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +params = { window_size = 88 } +# path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.learncurve.model.TweetyNet.optimizer] +# vak.learncurve.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.learncurve.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.learncurve.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/config_learncurve.toml b/src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml similarity index 98% rename from src/vak/config/config_learncurve.toml rename to src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml index 825a8d54a..43ca6080f 100644 --- a/src/vak/config/config_learncurve.toml +++ b/src/vak/config/_toml_config_templates/configfile_learncurve_prep.toml @@ -45,7 +45,7 @@ freq_cutoffs = [ 500, 10000,] # hidden size of the RNN.
If you observe impaired performance of TweetyNet after applying the frequency cutoffs, # consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). -# learncurve: options for running the learning curve +# [vak.learncurve]: options for running the learning curve # that estimates model performance # as a function of training set size [vak.learncurve] diff --git a/src/vak/config/_toml_config_templates/configfile_predict.toml b/src/vak/config/_toml_config_templates/configfile_predict.toml new file mode 100644 index 000000000..6303c6537 --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_predict.toml @@ -0,0 +1,57 @@ +# [vak.predict]: options for generating predictions with a trained model +[vak.predict] +# checkpoint_path: path to saved model checkpoint +checkpoint_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/TweetyNet/checkpoints/max-val-acc-checkpoint.pt" +# labelmap_path: path to file that maps from outputs of model (integers) to text labels in annotations; +# this is used when generating predictions +labelmap_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/labelmap.json" +# frames_standardizer_path: path to file containing SpectScaler that was fit to training set +# We want to transform the data we predict on in the exact same way +frames_standardizer_path = "/PATH/TO/FOLDER/results/train/RESULTS_TIMESTAMP/StandardizeSpect" +# batch_size +# for predictions with a frame classification model, this should always be 1 +# and will be ignored if it's not +batch_size = 1 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# output_dir: directory where output should be saved, as a sub-directory within `output_dir` +output_dir = "/PATH/TO/FOLDER/results/predict" +# annot_csv_filename +annot_csv_filename = "gy6or6.032312.annot.csv" +# The next two options are for post-processing transforms. 
+# Both these transforms require that there is an "unlabeled" label, +# and they will only be applied to segments that are bordered on both sides +# by the "unlabeled" label. +# Such a label class is added by default by vak. +# majority_vote: post-processing transformation that takes majority vote within segments that +# do not have the 'unlabeled' class label. Only applied if `majority_vote` is `true` +# (default is false). +majority_vote = true +# min_segment_dur: post-processing transformation removes any segments +# with a duration shorter than `min_segment_dur` that do not have the 'unlabeled' class. +# Only applied if this option is specified. +min_segment_dur = 0.01 +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.predict.dataset] +path = "/copy/path/from/train/config/here" +params = { window_size = 176 } + +# [vak.predict.model.TweetyNet]: We put this table so vak knows which model we are using +# We then add additional sub-tables to configure the model, e.g., [vak.eval.model.TweetyNet.network] +[vak.predict.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +# we trained with hidden size = 256 so we need to evaluate with the same hidden size; +# otherwise we'll get an error about "shapes do not match" when torch tries to load the checkpoint +hidden_size = 256 + +# [vak.predict.trainer]: this sub-table configures the `lightning.pytorch.Trainer` +[vak.predict.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] diff --git a/src/vak/config/configfile_predict.toml b/src/vak/config/_toml_config_templates/configfile_predict_prep.toml similarity index 100% rename from 
src/vak/config/configfile_predict.toml rename to src/vak/config/_toml_config_templates/configfile_predict_prep.toml diff --git a/src/vak/config/configfile_train.toml b/src/vak/config/_toml_config_templates/configfile_train.toml similarity index 100% rename from src/vak/config/configfile_train.toml rename to src/vak/config/_toml_config_templates/configfile_train.toml diff --git a/src/vak/config/_toml_config_templates/configfile_train_prep.toml b/src/vak/config/_toml_config_templates/configfile_train_prep.toml new file mode 100644 index 000000000..1aa0dc45f --- /dev/null +++ b/src/vak/config/_toml_config_templates/configfile_train_prep.toml @@ -0,0 +1,101 @@ +# [vak.prep]: options for preparing dataset +[vak.prep] +# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" +dataset_type = "frame classification" +# input_type: input to model, either audio ("audio") or spectrogram ("spect") +input_type = "spect" +# data_dir: directory with data to use when preparing dataset +data_dir = "/PATH/TO/FOLDER/gyor6/032212" +# output_dir: directory where dataset will be created (as a sub-directory within output_dir) +output_dir = "/PATH/TO/FOLDER/prep/train" +# audio_format: format of audio, either wav or cbin +audio_format = "wav" +# annot_format: format of annotations +annot_format = "simple-seq" +# labelset: string or array with unique set of labels used in annotations +labelset = "iabcdefghjk" +# train_dur: duration of training split in dataset, in seconds +train_dur = 2000 +# val_dur: duration of validation split in dataset, in seconds +val_dur = 170 +# test_dur: duration of test split in dataset, in seconds +test_dur = 350 + +# [vak.prep.spect_params]: parameters for computing spectrograms +[vak.prep.spect_params] +# fft_size: size of window used for Fast Fourier Transform, in number of samples +fft_size = 512 +# step_size: size of step to take when computing spectra with FFT for spectrogram +# also known as hop size +step_size = 
64 +# qualitatively, we find that log transforming the spectrograms improves performance; +# think of this as increasing the contrast between high power and low power regions +transform_type = "log_spect" +# specifying cutoff frequencies of the spectrogram can (1) make the model more +# computationally efficient and (2) improve performance by only fitting the model +# to parts of the spectrum that are relevant for sounds of interest. +# Note these cutoffs are applied by computing the whole spectrogram first +# and then throwing away frequencies above and below the cutoffs; +# we do not apply a bandpass filter to the audio. +freq_cutoffs = [500, 8000] +# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN +# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the +# hidden size of the RNN. If you observe impaired performance of TweetyNet after applying the frequency cutoffs, +# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). 
+ +# [vak.train]: options for training model +[vak.train] +# root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` +root_results_dir = "/PATH/TO/FOLDER/results/train" +# batch_size: number of samples from dataset per batch fed into network +batch_size = 8 +# num_epochs: number of training epochs, where an epoch is one iteration through all samples in training split +num_epochs = 2 +# standardize_frames: if true, standardize (normalize) frames (input to neural network) per frequency bin, so mean of each is 0.0 and std is 1.0 +# across the entire training split +standardize_frames = true +# val_step: step number on which to compute metrics with validation set, every time step % val_step == 0 +# (a step is one batch fed through the network) +# saves a checkpoint if the monitored evaluation metric improves (which is model specific) +val_step = 1000 +# ckpt_step: step number on which to save a checkpoint (as a backup, regardless of validation metrics) +ckpt_step = 500 +# patience: number of validation steps to wait before stopping training early +# if the monitored evaluation metrics does not improve after `patience` validation steps, +# then we stop training +patience = 6 +# num_workers: number of workers to use when loading data with multiprocessing +num_workers = 4 +# device: name of device to run model on, one of "cuda", "cpu" + +# dataset_path : path to dataset created by prep. This will be added when you run `vak prep`, you don't have to add it + +# dataset.params = parameters used for datasets +# for a frame classification model, we use dataset classes with a specific `window_size` +[vak.train.dataset.params] +# Bigger windows work better. +# For frame classification models, prefer smaller batch sizes with bigger windows +# Intuitively, bigger windows give the model more "contexts" for each frame per batch. 
+# See https://github.com/vocalpy/Nicholson-Cohen-SfN-2023-poster for more detail +window_size = 2000 + +# TweetyNet.network: we specify options for the model's network in this table +# To indicate the model to train, we use a "dotted key" with `model` followed by the string name of the model. +# This name must be a name within `vak.models` or added e.g. with `vak.model.decorators.model` +# We use another dotted key to indicate options for configuring the model, e.g. `TweetyNet.optimizer` +[vak.train.model.TweetyNet] +[vak.train.model.TweetyNet.optimizer] +# vak.train.model.TweetyNet.optimizer: we specify options for the model's optimizer in this table +# lr: the learning rate +lr = 0.001 + +[vak.train.model.TweetyNet.network] +# hidden_size: the number of elements in the hidden state in the recurrent layer of the network +hidden_size = 256 + +# this sub-table configures the `lightning.pytorch.Trainer` +[vak.train.trainer] +# setting to 'gpu' means "train models on 'gpu' (not 'cpu')" +accelerator = "gpu" +# use the first GPU (numbering starts from 0) +devices = [0] From e734c66f67e3bc81a91228eb411b3e7ee4635280 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 14 Sep 2025 11:14:26 -0400 Subject: [PATCH 09/40] Fixup vak.config.generate --- src/vak/config/generate.py | 41 +++++++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 5 deletions(-) diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py index 7848f443f..c46c3629c 100644 --- a/src/vak/config/generate.py +++ b/src/vak/config/generate.py @@ -1,6 +1,8 @@ import importlib.resources import pathlib -import shutil + +import tomlkit + CONFIGFILE_KIND_FILENAME_MAP = { "train": "configfile_train.toml", @@ -9,6 +11,11 @@ "learncurve": "configfile_learncurve.toml", } +# next line: can't use `.items()`, we'll get `RuntimeError` about dictionary changed sized during iteration +for key in list(CONFIGFILE_KIND_FILENAME_MAP.keys()): + val = CONFIGFILE_KIND_FILENAME_MAP[key] + 
CONFIGFILE_KIND_FILENAME_MAP[f"{key}_prep"] = val.replace(key, f"{key}_prep") + def generate( kind: str, @@ -43,8 +50,32 @@ def generate( f"Destination for generated config file `dst` is already a file that exists:\n{dst}\n" "Please specify a value for the `--dst` argument that will not overwrite an existing file." ) - filename = CONFIGFILE_KIND_FILENAME_MAP[kind] - src = pathlib.Path( - importlib.resources.files("vak.config").joinpath(filename) + + # for now, we "add a prep section" by using a naming convention + # and loading an existing toml file that has a `[vak.prep]` table + if add_prep: + kind = f"{kind}_prep" + + try: + src_filename = CONFIGFILE_KIND_FILENAME_MAP[kind] + except KeyError: + raise ValueError( + f"Invalid kind: {kind}" + ) + + src_path = pathlib.Path( + importlib.resources.files("vak.config._toml_config_templates").joinpath(src_filename) ) - shutil.copy(src, dst) + # even though we are loading an existing file, + # we use tomlkit to load and dump. + # TODO: add "interactive" arg and use tomlkit with `input` to interactively build config file + with src_path.open("r") as fp: + tomldoc = tomlkit.load(fp) + + if dst.is_dir(): + dst_path = dst / src_filename + else: + dst_path = dst + + with dst_path.open("w") as fp: + tomlkit.dump(tomldoc, fp) From eb7ca317f818afa33bc7009cbb1b5e13ea0077ad Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 14 Sep 2025 11:14:49 -0400 Subject: [PATCH 10/40] Import generate function in vak.config.__init__.py --- src/vak/config/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index c1828aff9..8e073827a 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -17,6 +17,7 @@ from .config import Config from .dataset import DatasetConfig from .eval import EvalConfig +from .generate import generate from .learncurve import LearncurveConfig from .model import ModelConfig from .predict import PredictConfig @@ -29,6 +30,7 @@ 
"config", "dataset", "eval", + "generate", "learncurve", "model", "load", From 203d7f3193f6a57d9d63b82382556280c61f7ea6 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 17 Sep 2025 12:57:14 -0400 Subject: [PATCH 11/40] Add tests/fixture/parser.py with get_parser fixture --- tests/fixtures/__init__.py | 1 + tests/fixtures/parser.py | 10 ++++++++++ 2 files changed, 11 insertions(+) create mode 100644 tests/fixtures/parser.py diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 18d506be7..e0899b8ae 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -8,6 +8,7 @@ from .device import * from .trainer import * from .model import * +from .parser import * from .path import * from .source_files import * from .spect import * diff --git a/tests/fixtures/parser.py b/tests/fixtures/parser.py new file mode 100644 index 000000000..81d48ab6e --- /dev/null +++ b/tests/fixtures/parser.py @@ -0,0 +1,10 @@ +import pytest + +import vak.cli.cli + + +@pytest.fixture +def parser(): + """Return an instance of the parser used by the command-line interface, + by calling :func:`vak.cli.cli.get_parser`""" + return vak.cli.cli.get_parser() From 12b31f3bce77ecd250fdd9fc7334b00c1366b14e Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 17 Sep 2025 12:57:46 -0400 Subject: [PATCH 12/40] Add tests/test_cli/test_cli.py with unit tests moved from tests/test__main__.py --- tests/test___main__.py | 52 -------------------------------------- tests/test_cli/test_cli.py | 50 ++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 52 deletions(-) create mode 100644 tests/test_cli/test_cli.py diff --git a/tests/test___main__.py b/tests/test___main__.py index eb94797a5..2d67561c3 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -6,58 +6,6 @@ import vak -@pytest.fixture -def parser(): - return vak.__main__.get_parser() - - -def test_parser_usage(parser, - capsys): - with pytest.raises(SystemExit): - 
parser.parse_args(args=['']) - captured = capsys.readouterr() - assert captured.err.startswith( - "usage: vak [-h] command configfile" - ) - - -def test_parser_help(parser, - capsys): - with pytest.raises(SystemExit): - parser.parse_args(['-h']) - captured = capsys.readouterr() - assert captured.out.startswith( - "usage: vak [-h] command configfile" - ) - - -DUMMY_CONFIGFILE = './configs/config_2018-12-17.toml' - - -@pytest.mark.parametrize( - 'command, raises', - [ - ('prep', False), - ('train', False), - ('learncurve', False), - ('eval', False), - ('predict', False), - ('not-a-valid-command', True), - ] -) -def test_parser(command, - raises, - parser, - capsys): - if raises: - with pytest.raises(SystemExit): - parser.parse_args([command, DUMMY_CONFIGFILE]) - else: - args = parser.parse_args([command, DUMMY_CONFIGFILE]) - assert args.command == command - assert args.configfile == pathlib.Path(DUMMY_CONFIGFILE) - - @pytest.mark.parametrize( 'command', [ diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py new file mode 100644 index 000000000..f4dc06b4a --- /dev/null +++ b/tests/test_cli/test_cli.py @@ -0,0 +1,50 @@ +import pathlib + +import pytest + + +def test_parser_usage(parser, + capsys): + with pytest.raises(SystemExit): + parser.parse_args(args=['']) + captured = capsys.readouterr() + assert captured.err.startswith( + "usage: vak [-h] {prep,train,eval,predict,learncurve,configfile} ..." + ) + + +def test_parser_help(parser, + capsys): + with pytest.raises(SystemExit): + parser.parse_args(['-h']) + captured = capsys.readouterr() + assert captured.out.startswith( + "usage: vak [-h] {prep,train,eval,predict,learncurve,configfile} ..." 
+ ) + + +DUMMY_CONFIGFILE = './configs/config_2018-12-17.toml' + + +@pytest.mark.parametrize( + 'command, raises', + [ + ('prep', False), + ('train', False), + ('learncurve', False), + ('eval', False), + ('predict', False), + ('not-a-valid-command', True), + ] +) +def test_parser(command, + raises, + parser, + capsys): + if raises: + with pytest.raises(SystemExit): + parser.parse_args([command, DUMMY_CONFIGFILE]) + else: + args = parser.parse_args([command, DUMMY_CONFIGFILE]) + assert args.command == command + assert args.configfile == pathlib.Path(DUMMY_CONFIGFILE) \ No newline at end of file From 30d791e2440135a2979346627abd89b55613d8ba Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 20 Sep 2025 10:47:36 -0400 Subject: [PATCH 13/40] Have __main__ call parser.print_help() if no args are passed to cli --- src/vak/__main__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/vak/__main__.py b/src/vak/__main__.py index 2d5440349..f17629753 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -2,6 +2,8 @@ Invokes __main__ when the module is run as a script. 
Example: python -m vak --help """ +import sys + from .cli import cli @@ -14,9 +16,16 @@ def main(args=None): ``args`` is used for unit testing only """ + parser = cli.get_parser() + + if len(sys.argv) < 2: + parser.print_help() + sys.exit(1) + if args is None: - parser = cli.get_parser() args = parser.parse_args() + else: + args = parser.parse_args(args) cli.cli(args) From a117d0e9d5ef2308649ced28ed60a9ebb91713cc Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 20 Sep 2025 10:48:34 -0400 Subject: [PATCH 14/40] Set required=True when we add_subparsers in cli.get_parser, and add title+description so we get more info from parser.print_help --- src/vak/cli/cli.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index 523c0a472..a26ac4e1a 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -185,14 +185,16 @@ def get_parser(): """returns ArgumentParser instance used by main()""" parser = argparse.ArgumentParser( prog="vak", - description="vak command-line interface", + description="Vak command-line interface", formatter_class=argparse.RawTextHelpFormatter, ) # create sub-parser sub_parsers = parser.add_subparsers( - help='Commands for vak command-line interface', + title="Command", + description="Commands for the vak command-line interface", dest="command", + required=True, ) for cli_command in CLI_COMMANDS: From b6fef396d1b6d7621b820215a33c3164381f6cc8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sat, 20 Sep 2025 10:50:32 -0400 Subject: [PATCH 15/40] Add smoke test for vak.cli.cli.get_parser --- tests/test_cli/test_cli.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index f4dc06b4a..f9b7268c9 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -1,7 +1,16 @@ +import argparse import pathlib import pytest +import vak.cli.cli + + +def test_get_parser(): + """Smoke test that just makes sure we 
get back a parser as expected""" + parser = vak.cli.cli.get_parser() + assert isinstance(parser, argparse.ArgumentParser) + def test_parser_usage(parser, capsys): From 5fffe81116b561a83eb72b429f34b3b6a647a96c Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 10:22:35 -0400 Subject: [PATCH 16/40] Rewrite __main__.main to take `args_list` for clarity --- src/vak/__main__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/vak/__main__.py b/src/vak/__main__.py index f17629753..e9f1c3aca 100644 --- a/src/vak/__main__.py +++ b/src/vak/__main__.py @@ -7,14 +7,14 @@ from .cli import cli -def main(args=None): +def main(args_list:list[str] | None = None): """Main function called when run as script or through command-line interface called when package is run with `python -m vak` or alternatively just calling `vak` at the command line (because this function is installed under just `vak` as a console script) - ``args`` is used for unit testing only + ``args_list`` is used for unit testing only """ parser = cli.get_parser() @@ -22,10 +22,10 @@ def main(args=None): parser.print_help() sys.exit(1) - if args is None: + if args_list is None: args = parser.parse_args() else: - args = parser.parse_args(args) + args = parser.parse_args(args_list) cli.cli(args) From d46931aa8648979beaec6f1e9c47fe50e29afff7 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 10:24:17 -0400 Subject: [PATCH 17/40] In vak/cli/cli.py, fix args attribute name config_file -> configfile, typehints args parameter as arpgarse.Namespace, and spell out pathlib.Path instead of importing Path from pathlib --- src/vak/cli/cli.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index a26ac4e1a..ea534b951 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -2,38 +2,37 @@ import argparse import pathlib from dataclasses import dataclass -from pathlib import Path 
from typing import Callable def eval(args): from .eval import eval - eval(toml_path=args.config_file) + eval(toml_path=args.configfile) def train(args): from .train import train - train(toml_path=args.config_file) + train(toml_path=args.configfile) def learncurve(args): from .learncurve import learning_curve - learning_curve(toml_path=args.config_file) + learning_curve(toml_path=args.configfile) def predict(args): from .predict import predict - predict(toml_path=args.config_file) + predict(toml_path=args.configfile) def prep(args): from .prep import prep - prep(toml_path=args.config_file) + prep(toml_path=args.configfile) def configfile(args): @@ -89,7 +88,7 @@ def add_single_arg_configfile_to_command( """ cli_command_parser.add_argument( "configfile", - type=Path, + type=pathlib.Path, help="name of TOML configuration file to use \n" f"$ vak {cli_command.name} ./configs/config_rat01337.toml", ) @@ -128,7 +127,7 @@ def add_args_to_configfile_command( help="Adding this option will add a 'prep' table to the TOML configuration file. Default is False." ) cli_command_parser.add_argument( - "-dst", + "--dst", type=pathlib.Path, default=pathlib.Path.cwd(), help="Destination, where TOML configuration file should be generated. Default is current working directory." @@ -216,15 +215,15 @@ def get_parser(): } -def cli(args): +def cli(args: argparse.Namespace): """Execute the commands of the command-line interface. Parameters ---------- - command : string - One of {'prep', 'train', 'eval', 'predict', 'learncurve'} - config_file : str, Path - path to a config.toml file + args : argparse.Namespace + Result of calling :meth:`ArgumentParser.parse_args` + on the :class:`ArgumentParser` instance returned by + :func:`vak.cli.cli.get_parser`. 
""" if args.command in CLI_COMMAND_FUNCTION_MAP: CLI_COMMAND_FUNCTION_MAP[args.command](args) From cdd079c5cccbcb3533430e2e3669f0ffbc9812ad Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 10:24:55 -0400 Subject: [PATCH 18/40] Rewrite unit tests in test___main__.py: test main with arg_list, and test that calling 'vak' without args results in parser.print_help() --- tests/test___main__.py | 45 ++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/tests/test___main__.py b/tests/test___main__.py index 2d67561c3..3a6cc9e9f 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -1,4 +1,4 @@ -import pathlib +import subprocess from unittest import mock import pytest @@ -6,21 +6,40 @@ import vak +DUMMY_CONFIGFILE = './configs/config_2018-12-17.toml' + + @pytest.mark.parametrize( - 'command', + 'args_list', [ - 'prep', - 'train', - 'learncurve', - 'eval', - 'predict', + ['prep', DUMMY_CONFIGFILE], + ['train', DUMMY_CONFIGFILE], + ['learncurve', DUMMY_CONFIGFILE], + ['eval', DUMMY_CONFIGFILE], + ['predict', DUMMY_CONFIGFILE], + ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE] ] ) -def test_main(command, - parser): - args = parser.parse_args([command, DUMMY_CONFIGFILE]) +def test_main(args_list): + """Test that :func:`vak.__main__.main` calls the function we expect through :func:`vak.cli.cli`""" + command = args_list[0] mock_cli_function = mock.Mock(name=f'mock_{command}') - with mock.patch.dict(vak.cli.cli.COMMAND_FUNCTION_MAP, - {command: mock_cli_function}) as mock_command_function_map: - vak.__main__.main(args) + with mock.patch.dict( + vak.cli.cli.CLI_COMMAND_FUNCTION_MAP, {command: mock_cli_function} + ): + # we can't do this with `subprocess` since the function won't be mocked in the subprocess, + # so we need to test indirectly with `arg_list` passed into `main` + vak.__main__.main(args_list) mock_cli_function.assert_called() + + +def 
test___main__prints_help_with_no_args(parser, capsys): + """Test that if we don't pass in any args, we get the parser's help printed""" + parser.print_help() + expected_output = capsys.readouterr().out.rstrip() + + # doing this by calling a `subprocess` is slow but lets us test the CLI directly + result = subprocess.run("vak", capture_output=True, text=True) # call `vak` at CLI with no args + output = result.stdout.rstrip() + + assert output == expected_output From 0e0ef5dd50b9d350ce77d4fb8da47787db476f78 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 10:55:48 -0400 Subject: [PATCH 19/40] WIP: Rewrite unit tests in tests/test_cli/test_cli.py --- tests/test_cli/test_cli.py | 118 +++++++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 38 deletions(-) diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index f9b7268c9..753fc2699 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -1,59 +1,101 @@ import argparse import pathlib +from unittest import mock import pytest import vak.cli.cli -def test_get_parser(): - """Smoke test that just makes sure we get back a parser as expected""" +DUMMY_CONFIGFILE_STR = './configs/config_2018-12-17.toml' +DUMMY_CONFIGFILE_PATH = pathlib.Path(DUMMY_CONFIGFILE_STR) + + +@pytest.mark.parametrize( + 'args_list, expected_attributes', + [ + ( + ['prep', DUMMY_CONFIGFILE_STR], + dict(command="prep", configfile=DUMMY_CONFIGFILE_PATH) + ), + ( + ['train', DUMMY_CONFIGFILE_STR], + dict(command="train", configfile=DUMMY_CONFIGFILE_PATH) + ), + ( + ['learncurve', DUMMY_CONFIGFILE_STR], + dict(command="learncurve", configfile=DUMMY_CONFIGFILE_PATH) + ), + ( + ['eval', DUMMY_CONFIGFILE_STR], + dict(command="eval", configfile=DUMMY_CONFIGFILE_PATH) + ), + ( + ['predict', DUMMY_CONFIGFILE_STR], + dict(command="predict", configfile=DUMMY_CONFIGFILE_PATH) + ), + ( + ['configfile', 'train'], + dict(command="configfile", kind="train", add_prep=False, dst=pathlib.Path.cwd()) + ), + ( + ['configfile', 
'eval'], + dict(command="configfile", kind="eval", add_prep=False, dst=pathlib.Path.cwd()) + ), + ( + ['configfile', 'train', "--add-prep"], + dict(command="configfile", kind="train", add_prep=True, dst=pathlib.Path.cwd()) + ) + ] +) +def test_parser_commands_with_configfile(args_list, expected_attributes): + """Test that calling parser.parse_args gives us a Namespace with the expected args""" parser = vak.cli.cli.get_parser() assert isinstance(parser, argparse.ArgumentParser) + args = parser.parse_args(args_list) + assert isinstance(args, argparse.Namespace) -def test_parser_usage(parser, - capsys): - with pytest.raises(SystemExit): - parser.parse_args(args=['']) - captured = capsys.readouterr() - assert captured.err.startswith( - "usage: vak [-h] {prep,train,eval,predict,learncurve,configfile} ..." - ) + for attr_name, expected_value in expected_attributes.items(): + assert hasattr(args, attr_name) + assert getattr(args, attr_name) == expected_value -def test_parser_help(parser, - capsys): - with pytest.raises(SystemExit): - parser.parse_args(['-h']) - captured = capsys.readouterr() - assert captured.out.startswith( - "usage: vak [-h] {prep,train,eval,predict,learncurve,configfile} ..." 
- ) - -DUMMY_CONFIGFILE = './configs/config_2018-12-17.toml' +def test_parser_raises(parser): + """Test that an invalid command passed into our ArgumentParser raises a SystemExit""" + with pytest.raises(SystemExit): + parser.parse_args(["not-a-valid-command", DUMMY_CONFIGFILE_STR]) @pytest.mark.parametrize( - 'command, raises', + 'args_list', [ - ('prep', False), - ('train', False), - ('learncurve', False), - ('eval', False), - ('predict', False), - ('not-a-valid-command', True), + ['prep', DUMMY_CONFIGFILE_STR], + ['train', DUMMY_CONFIGFILE_STR], + ['learncurve', DUMMY_CONFIGFILE_STR], + ['eval', DUMMY_CONFIGFILE_STR], + ['predict', DUMMY_CONFIGFILE_STR], + ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE_STR] ] ) -def test_parser(command, - raises, - parser, - capsys): - if raises: - with pytest.raises(SystemExit): - parser.parse_args([command, DUMMY_CONFIGFILE]) - else: - args = parser.parse_args([command, DUMMY_CONFIGFILE]) - assert args.command == command - assert args.configfile == pathlib.Path(DUMMY_CONFIGFILE) \ No newline at end of file +def test_cli( + args_list, parser, +): + """Test that :func:`vak.cli.cli.cli` calls the functions we expect""" + args = parser.parse_args(args_list) + + command = args_list[0] + mock_cli_function = mock.Mock(name=f'mock_{command}') + with mock.patch.dict( + vak.cli.cli.CLI_COMMAND_FUNCTION_MAP, {command: mock_cli_function} + ): + # we can't do this with `subprocess` since the function won't be mocked in the subprocess, + # so we need to test indirectly with `arg_list` passed into `main` + vak.cli.cli.cli(args) + mock_cli_function.assert_called() + + +def test_configfile(): + # FIXME test that configfile works the way we expect + assert False \ No newline at end of file From 1482ad314fd5519d022ec0da94061064147d34e6 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 23:17:20 -0400 Subject: [PATCH 20/40] Finish re-writing unit tests in tests/test_cli/test_cli.py --- tests/test_cli/test_cli.py 
| 66 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 63 insertions(+), 3 deletions(-) diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index 753fc2699..da81cd45d 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -96,6 +96,66 @@ def test_cli( mock_cli_function.assert_called() -def test_configfile(): - # FIXME test that configfile works the way we expect - assert False \ No newline at end of file +@pytest.mark.parametrize( + 'args_list, cli_helper_function, module, function_name', + [ + ( + ['prep', DUMMY_CONFIGFILE_STR], + vak.cli.cli.prep, + vak.cli.prep, + "prep", + ), + ( + ['train', DUMMY_CONFIGFILE_STR], + vak.cli.cli.train, + vak.cli.train, + "train", + ), + ( + ['learncurve', DUMMY_CONFIGFILE_STR], + vak.cli.cli.learncurve, + vak.cli.learncurve, + "learning_curve", + ), + ( + ['eval', DUMMY_CONFIGFILE_STR], + vak.cli.cli.eval, + vak.cli.eval, + "eval", + ), + ( + ['predict', DUMMY_CONFIGFILE_STR], + vak.cli.cli.predict, + vak.cli.predict, + "predict", + ), + ( + ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE_STR], + vak.cli.cli.configfile, + vak.config.generate, + "generate", + ), + ] +) +def test_cli_helper_functions( + args_list, cli_helper_function, module, function_name, parser +): + """Test that helper functions we use to map commands to actual functions in the cli module + or elsewhere call those functions as expected""" + # this feels like I'm testing low-level implementation details + # but I feel like I should have some unit tests for what's happening in this module + args = parser.parse_args(args_list) + + with mock.patch.object(module, function_name, autospec=True) as mock_cli_function: + cli_helper_function(args) + + if args.command == "configfile": + mock_call = mock.call( + kind=args.kind, add_prep=args.add_prep, dst=args.dst + ) + else: + mock_call = mock.call(toml_path=args.configfile) + + assert mock_cli_function.mock_calls == [ + mock_call + ] From 
d77ae28ec87802cfdb4665939fc1e47425c65a6a Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 23:17:38 -0400 Subject: [PATCH 21/40] Import generate *module* not function in src/vak/config/__init__.py --- src/vak/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 8e073827a..5318c8d37 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -4,6 +4,7 @@ config, dataset, eval, + generate, learncurve, load, model, @@ -17,7 +18,6 @@ from .config import Config from .dataset import DatasetConfig from .eval import EvalConfig -from .generate import generate from .learncurve import LearncurveConfig from .model import ModelConfig from .predict import PredictConfig From 0b314ba5319b2f6268208194cd39210874f254f5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Sun, 21 Sep 2025 23:18:05 -0400 Subject: [PATCH 22/40] Remove print statements from vak.cli.cli.configfile helper function --- src/vak/cli/cli.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index ea534b951..468db7972 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -36,17 +36,6 @@ def prep(args): def configfile(args): - print( - f"Generating TOML configuration file of kind: {args.kind}" - ) - if args.add_prep: - print( - f"Will add prep table" - ) - else: - print( - f"Will not add prep table" - ) from ..config.generate import generate generate( kind=args.kind, From 1b14a341bdcaf2dfa445d665a919a8e48baf9017 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 22 Sep 2025 11:57:24 -0400 Subject: [PATCH 23/40] Fix how we get command names to validate in config.validators.are_tables_valid --- src/vak/config/validators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/validators.py b/src/vak/config/validators.py index f349db746..f4f519286 100644 --- a/src/vak/config/validators.py +++ 
b/src/vak/config/validators.py @@ -79,7 +79,7 @@ def are_tables_valid(config_dict, toml_path=None): from ..cli.cli import CLI_COMMANDS # avoid circular import cli_commands_besides_prep = [ - command for command in CLI_COMMANDS if command != "prep" + command.name for command in CLI_COMMANDS if command.name != "prep" ] tables_that_are_commands_besides_prep = [ table for table in tables if table in cli_commands_besides_prep From 1ee8cd1122eddcb5758a39b1ca999caf10740e78 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 22 Sep 2025 11:58:26 -0400 Subject: [PATCH 24/40] Fix how we set default for dst in vak.config.generate --- src/vak/config/generate.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py index c46c3629c..6daf6c57c 100644 --- a/src/vak/config/generate.py +++ b/src/vak/config/generate.py @@ -20,14 +20,9 @@ def generate( kind: str, add_prep: bool = False, - dst: str | pathlib.Path = pathlib.Path.cwd(), + dst: str | pathlib.Path | None = None, ) -> None: - """Generate a TOML configuration file - - This is the function called by - :func:`vak.cli.cli.generate` - when a user runs the command ``vak configfile`` - using the command-line interface. + """Generate a TOML configuration file for :mod:`vak` Parameters ---------- @@ -43,7 +38,20 @@ def generate( or a directory, in which case a default filename will be used. The default `dst` is the current working directory. + + Notes + ----- + This is the function called by + :func:`vak.cli.cli.generate` + when a user runs the command ``vak configfile`` + using the command-line interface. 
+ """ + if dst is None: + # we can't make this the default value of the parameter in the function signature + # since it would get the value at import time, and we need the value at runtime + dst = pathlib.Path.cwd() + dst = pathlib.Path(dst) if not dst.is_dir() and dst.exists(): raise ValueError( From 6963d691e78464507ca2e8e3ee3bbd122ba1ed4b Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 22 Sep 2025 11:59:23 -0400 Subject: [PATCH 25/40] Remove unit test in tests/test_cli/test_cli.py that was too tightly coupled with implementation details --- tests/test_cli/test_cli.py | 65 -------------------------------------- 1 file changed, 65 deletions(-) diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index da81cd45d..f2cac331c 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -94,68 +94,3 @@ def test_cli( # so we need to test indirectly with `arg_list` passed into `main` vak.cli.cli.cli(args) mock_cli_function.assert_called() - - -@pytest.mark.parametrize( - 'args_list, cli_helper_function, module, function_name', - [ - ( - ['prep', DUMMY_CONFIGFILE_STR], - vak.cli.cli.prep, - vak.cli.prep, - "prep", - ), - ( - ['train', DUMMY_CONFIGFILE_STR], - vak.cli.cli.train, - vak.cli.train, - "train", - ), - ( - ['learncurve', DUMMY_CONFIGFILE_STR], - vak.cli.cli.learncurve, - vak.cli.learncurve, - "learning_curve", - ), - ( - ['eval', DUMMY_CONFIGFILE_STR], - vak.cli.cli.eval, - vak.cli.eval, - "eval", - ), - ( - ['predict', DUMMY_CONFIGFILE_STR], - vak.cli.cli.predict, - vak.cli.predict, - "predict", - ), - ( - ['configfile', 'train', '--add-prep', '--dst', DUMMY_CONFIGFILE_STR], - vak.cli.cli.configfile, - vak.config.generate, - "generate", - ), - ] -) -def test_cli_helper_functions( - args_list, cli_helper_function, module, function_name, parser -): - """Test that helper functions we use to map commands to actual functions in the cli module - or elsewhere call those functions as expected""" - # this feels like I'm testing 
low-level implementation details - # but I feel like I should have some unit tests for what's happening in this module - args = parser.parse_args(args_list) - - with mock.patch.object(module, function_name, autospec=True) as mock_cli_function: - cli_helper_function(args) - - if args.command == "configfile": - mock_call = mock.call( - kind=args.kind, add_prep=args.add_prep, dst=args.dst - ) - else: - mock_call = mock.call(toml_path=args.configfile) - - assert mock_cli_function.mock_calls == [ - mock_call - ] From f2fbf8da9adc00def5381d059ee3b77857f8231d Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Mon, 22 Sep 2025 11:59:38 -0400 Subject: [PATCH 26/40] WIP: Add tests/test_config/test_generate.py --- tests/test_config/test_generate.py | 52 ++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/test_config/test_generate.py diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py new file mode 100644 index 000000000..edcc872d9 --- /dev/null +++ b/tests/test_config/test_generate.py @@ -0,0 +1,52 @@ +import os + +import pytest + +import vak.config.generate + + +@pytest.mark.parametrize( + 'kind, add_prep, dst_name', + [ + ( + "train", + False, + None + ) + ] +) +def test_generate(kind, add_prep, dst_name, tmp_path): + """Test :func:`vak.config.generate.generate`""" + # FIXME: handle case where `dst` is a filename -- handle .toml extension + if dst_name is None: + dst = tmp_path / "tmp-dst-None" + else: + dst = tmp_path / dst_name + dst.mkdir() + + if dst_name is None: + os.chdir(dst) + vak.config.generate.generate(kind=kind, add_prep=add_prep) + else: + dst = tmp_path / dst + vak.config.generate.generate(kind=kind, add_prep=add_prep, dst=dst) + + if dst.is_dir(): + # we need to get the actual generated TOML + generated_toml_path = sorted(dst.glob("*toml")) + assert len(generated_toml_path) == 1 + generated_toml_path = generated_toml_path[0] + else: + generated_toml_path = dst + + cfg = 
vak.config.Config.from_toml_path(generated_toml_path) + assert hasattr(cfg, kind) + if add_prep: + assert hasattr(cfg, "prep") + else: + assert not hasattr(cfg, "prep") + + +def test_generate_raises(): + # FIXME: test we raise error if dst already exists + assert False \ No newline at end of file From ec78ccc0bb2ca3f9d79401a644163d1d3c286ded Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 19:25:55 -0400 Subject: [PATCH 27/40] WIP: add test_configfile_command to tests/test__main__.py --- tests/test___main__.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/test___main__.py b/tests/test___main__.py index 3a6cc9e9f..00b2d2304 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -21,13 +21,23 @@ ] ) def test_main(args_list): - """Test that :func:`vak.__main__.main` calls the function we expect through :func:`vak.cli.cli`""" + """Test that :func:`vak.__main__.main` calls the function we expect through :func:`vak.cli.cli` + + Notes + ----- + We mock these and call it a unit test + because actually calling and running :func:vak.cli.prep` + would be expensive. + + The exception is `vak configfile` + that we test directly (in other test functions below). 
+ """ command = args_list[0] mock_cli_function = mock.Mock(name=f'mock_{command}') with mock.patch.dict( vak.cli.cli.CLI_COMMAND_FUNCTION_MAP, {command: mock_cli_function} ): - # we can't do this with `subprocess` since the function won't be mocked in the subprocess, + # wAFAICT e can't do this with `subprocess` since the function won't be mocked in the subprocess, # so we need to test indirectly with `arg_list` passed into `main` vak.__main__.main(args_list) mock_cli_function.assert_called() @@ -43,3 +53,10 @@ def test___main__prints_help_with_no_args(parser, capsys): output = result.stdout.rstrip() assert output == expected_output + + +def test_configfile_command(): + # FIXME: copy whatever unit tests we write for `vak.config.generate.generate` + # FIXME: except we change the actual part of the test where we call the function + # FIXME: and we're going to use an `args_list` instead of providing parameters directly + assert False \ No newline at end of file From 9658c0f97dbd211165df132eaa74eb9419ea2fff Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:00:33 -0400 Subject: [PATCH 28/40] Remove prep section from a 'template' configfile that's not supposed to have it --- .../configfile_train.toml | 45 ------------------- 1 file changed, 45 deletions(-) diff --git a/src/vak/config/_toml_config_templates/configfile_train.toml b/src/vak/config/_toml_config_templates/configfile_train.toml index 1aa0dc45f..a6300d1d3 100644 --- a/src/vak/config/_toml_config_templates/configfile_train.toml +++ b/src/vak/config/_toml_config_templates/configfile_train.toml @@ -1,48 +1,3 @@ -# [vak.prep]: options for preparing dataset -[vak.prep] -# dataset_type: corresponds to the model family such as "frame classification" or "parametric umap" -dataset_type = "frame classification" -# input_type: input to model, either audio ("audio") or spectrogram ("spect") -input_type = "spect" -# data_dir: directory with data to use when preparing dataset -data_dir = 
"/PATH/TO/FOLDER/gyor6/032212" -# output_dir: directory where dataset will be created (as a sub-directory within output_dir) -output_dir = "/PATH/TO/FOLDER/prep/train" -# audio_format: format of audio, either wav or cbin -audio_format = "wav" -# annot_format: format of annotations -annot_format = "simple-seq" -# labelset: string or array with unique set of labels used in annotations -labelset = "iabcdefghjk" -# train_dur: duration of training split in dataset, in seconds -train_dur = 2000 -# val_dur: duration of validation split in dataset, in seconds -val_dur = 170 -# test_dur: duration of test split in dataset, in seconds -test_dur = 350 - -# [vak.prep.spect_params]: parameters for computing spectrograms -[vak.prep.spect_params] -# fft_size: size of window used for Fast Fourier Transform, in number of samples -fft_size = 512 -# step_size: size of step to take when computing spectra with FFT for spectrogram -# also known as hop size -step_size = 64 -# qualitatively, we find that log transforming the spectrograms improves performance; -# think of this as increasing the contrast between high power and low power regions -transform_type = "log_spect" -# specifying cutoff frequencies of the spectrogram can (1) make the model more -# computationally efficient and (2) improve performance by only fitting the model -# to parts of the spectrum that are relevant for sounds of interest. -# Note these cutoffs are applied by computing the whole spectrogram first -# and then throwing away frequencies above and below the cutoffs; -# we do not apply a bandpass filter to the audio. -freq_cutoffs = [500, 8000] -# Note that for the TweetyNet model, the default is to set the hidden_size of the RNN -# equal to the input_size, so if you reduce the size of the spectrogram, this will reduce the -# hidden size of the RNN. 
If you observe impaired performance of TweetyNet after applying the frequency cutoffs, -# consider manually specifying a larger hidden (see `[vak.train.model.TweetyNet]` table below). - # [vak.train]: options for training model [vak.train] # root_results_dir: directory where results should be saved, as a sub-directory within `root_results_dir` From 0836858c4505ba55ef7eb0ebae9d2bf5b12106af Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:21:15 -0400 Subject: [PATCH 29/40] Finish writing unit tests in tests/test_config/test_generate.py --- tests/test_config/test_generate.py | 158 +++++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 10 deletions(-) diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py index edcc872d9..cfb6feea0 100644 --- a/tests/test_config/test_generate.py +++ b/tests/test_config/test_generate.py @@ -1,4 +1,5 @@ import os +import tempfile import pytest @@ -8,11 +9,130 @@ @pytest.mark.parametrize( 'kind, add_prep, dst_name', [ + # ---- train ( "train", False, None - ) + ), + ( + "train", + True, + None + ), + ( + "train", + False, + "configs-dir" + ), + ( + "train", + True, + "configs-dir" + ), + ( + "train", + False, + "configs-dir/config.toml" + ), + ( + "train", + True, + "configs-dir/config.toml" + ), + # ---- eval + ( + "eval", + False, + None + ), + ( + "eval", + True, + None + ), + ( + "eval", + False, + "configs-dir" + ), + ( + "eval", + True, + "configs-dir" + ), + ( + "eval", + False, + "configs-dir/config.toml" + ), + ( + "eval", + True, + "configs-dir/config.toml" + ), + # ---- predict + ( + "predict", + False, + None + ), + ( + "predict", + True, + None + ), + ( + "predict", + False, + "configs-dir" + ), + ( + "predict", + True, + "configs-dir" + ), + ( + "predict", + False, + "configs-dir/config.toml" + ), + ( + "predict", + True, + "configs-dir/config.toml" + ), + # ---- learncurve + ( + "learncurve", + False, + None + ), + ( + "learncurve", + True, + None + ), + ( + 
"learncurve", + False, + "configs-dir" + ), + ( + "learncurve", + True, + "configs-dir" + ), + ( + "learncurve", + False, + "configs-dir/config.toml" + ), + ( + "learncurve", + True, + "configs-dir/config.toml" + ), ] ) def test_generate(kind, add_prep, dst_name, tmp_path): @@ -22,13 +142,17 @@ def test_generate(kind, add_prep, dst_name, tmp_path): dst = tmp_path / "tmp-dst-None" else: dst = tmp_path / dst_name - dst.mkdir() + if dst.suffix == ".toml": + # if dst ends with a toml extension + # then its *parent* is the dir we need to make + dst.parent.mkdir() + else: + dst.mkdir() if dst_name is None: os.chdir(dst) vak.config.generate.generate(kind=kind, add_prep=add_prep) else: - dst = tmp_path / dst vak.config.generate.generate(kind=kind, add_prep=add_prep, dst=dst) if dst.is_dir(): @@ -38,15 +162,29 @@ def test_generate(kind, add_prep, dst_name, tmp_path): generated_toml_path = generated_toml_path[0] else: generated_toml_path = dst + # next line: the rest of the assertions would fail if this one did + # but we're being super explicit here: + # if we specified a file name for dst then it should exist as a file + assert generated_toml_path.exists() - cfg = vak.config.Config.from_toml_path(generated_toml_path) - assert hasattr(cfg, kind) + # we can't load with `vak.config.Config.from_toml_path` + # because the generated config doesn't have a [vak.dataset.path] key-value pair yet, + # and the corresponding attrs class that represents that table will throw an error. + # So we load as a Python dict and check the expected keys are there. + # I don't have any better ideas at the moment for how to test + cfg_dict = vak.config.load._load_toml_from_path(generated_toml_path) + # N.B. 
that `vak.config.load._load_toml_from_path` accesses the top-level key "vak" + # and returns the result of that, so we don't need to do something like `cfg_dict["vak"]["prep"]` + assert kind in cfg_dict if add_prep: - assert hasattr(cfg, "prep") + assert "prep" in cfg_dict else: - assert not hasattr(cfg, "prep") + assert "prep" not in cfg_dict -def test_generate_raises(): - # FIXME: test we raise error if dst already exists - assert False \ No newline at end of file +def test_generate_raises(tmp_path): + dst = tmp_path / "fake.config.toml" + with dst.open("w") as fp: + fp.write("[fake.config]") + with pytest.raises(FileExistsError): + vak.config.generate.generate("train", add_prep=True, dst=dst) From 11fc46e79354f3e93ad5a3ba04785f98a108c6c5 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:21:38 -0400 Subject: [PATCH 30/40] Raise FileExistsError in vak.config.generate, not ValueError, if dst already exists --- src/vak/config/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py index 6daf6c57c..db10d0a3b 100644 --- a/src/vak/config/generate.py +++ b/src/vak/config/generate.py @@ -54,7 +54,7 @@ def generate( dst = pathlib.Path(dst) if not dst.is_dir() and dst.exists(): - raise ValueError( + raise FileExistsError( f"Destination for generated config file `dst` is already a file that exists:\n{dst}\n" "Please specify a value for the `--dst` argument that will not overwrite an existing file." 
) From 524d96526d7d6a01bc636f1a663641eeba1174b1 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:28:20 -0400 Subject: [PATCH 31/40] Validate extension of `dst` in `vak.config.generate.generate` -- make sure it is .toml --- src/vak/config/generate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py index db10d0a3b..f171cf331 100644 --- a/src/vak/config/generate.py +++ b/src/vak/config/generate.py @@ -22,7 +22,7 @@ def generate( add_prep: bool = False, dst: str | pathlib.Path | None = None, ) -> None: - """Generate a TOML configuration file for :mod:`vak` + """Generate a TOML configuration file for :mod:`vak`. Parameters ---------- @@ -59,6 +59,11 @@ def generate( "Please specify a value for the `--dst` argument that will not overwrite an existing file." ) + if not dst.is_dir() and dst.suffix != ".toml": + raise ValueError( + f"If `dst` is a path that ends in a filename, not a directory, then the extension must be '.toml', but was: {dst.suffix}" + ) + # for now, we "add a prep section" by using a naming convention # and loading an existing toml file that has a `[vak.prep]` table if add_prep: From 1d1eddbb957f7c7899d7e9a1e63062a6ae2c8d19 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:28:46 -0400 Subject: [PATCH 32/40] Add unit test to test that vak.config.generate.generate raises ValueError when dst is a path to a file but extension is not '.toml' --- tests/test_config/test_generate.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py index cfb6feea0..87ef97a5d 100644 --- a/tests/test_config/test_generate.py +++ b/tests/test_config/test_generate.py @@ -182,9 +182,22 @@ def test_generate(kind, add_prep, dst_name, tmp_path): assert "prep" not in cfg_dict -def test_generate_raises(tmp_path): +def 
test_generate_raises_FileExistsError(tmp_path): + """Test that func:`vak.config.generate.generate` raises + a FileExistsError if `dst` already exists""" dst = tmp_path / "fake.config.toml" with dst.open("w") as fp: fp.write("[fake.config]") with pytest.raises(FileExistsError): vak.config.generate.generate("train", add_prep=True, dst=dst) + + + +def test_generate_raises_ValueError(tmp_path): + """Test that :func:`vak.config.generate.generate` raises + a ValueError if `dst` is a path to a filename but the extension is not '.toml'""" + dst = tmp_path / "fake.config.json" + with dst.open("w") as fp: + fp.write("[fake.config]") + with pytest.raises(FileExistsError): + vak.config.generate.generate("train", add_prep=True, dst=dst) From 852116c6fde6124e34c1f56cc0646b37dcfdc985 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:33:57 -0400 Subject: [PATCH 33/40] Add examples section to docstring of vak.config.generate.generate --- src/vak/config/generate.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/vak/config/generate.py b/src/vak/config/generate.py index f171cf331..df57ffe31 100644 --- a/src/vak/config/generate.py +++ b/src/vak/config/generate.py @@ -39,13 +39,32 @@ def generate( will be used. The default `dst` is the current working directory. + Examples + -------- + + Generate a TOML configuration file in the current working directory to prepare a dataset and train a model. + + >>> vak.config.generate.generate("train", add_prep=True) + + Generate a TOML configuration file in a specified directory to train a model, e.g. on an existing dataset. + + >>> import pathlib + >>> dst = pathlib.Path("./data/configs") + >>> vak.config.generate.generate("train", add_prep=True, dst=dst) + + Generate a TOML configuration file with a specific file name to train a model, e.g. on an existing dataset. 
+ + >>> import pathlib + >>> dst = pathlib.Path("./data/configs/train-bfsongrepo.toml") + >>> vak.config.generate.generate("train", add_prep=True, dst=dst) + + Notes ----- This is the function called by :func:`vak.cli.cli.generate` when a user runs the command ``vak configfile`` using the command-line interface. - """ if dst is None: # we can't make this the default value of the parameter in the function signature From de1ebfdcfdd278953a43f5eaac3e5e929a1836ce Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:36:35 -0400 Subject: [PATCH 34/40] Rename -> vak/config/_generate.py --- src/vak/config/{generate.py => _generate.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/vak/config/{generate.py => _generate.py} (100%) diff --git a/src/vak/config/generate.py b/src/vak/config/_generate.py similarity index 100% rename from src/vak/config/generate.py rename to src/vak/config/_generate.py From 5248c36bc6812944e33a59d18790080c56e2d751 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:36:50 -0400 Subject: [PATCH 35/40] Import generate function from _generate module in vak/config/__init__.py --- src/vak/config/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 5318c8d37..4c8ab1947 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -4,7 +4,6 @@ config, dataset, eval, - generate, learncurve, load, model, @@ -18,6 +17,7 @@ from .config import Config from .dataset import DatasetConfig from .eval import EvalConfig +from ._generate import generate from .learncurve import LearncurveConfig from .model import ModelConfig from .predict import PredictConfig From 393ef215a9faa7aab47d577128eaabab6d3e1cbf Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:38:39 -0400 Subject: [PATCH 36/40] Rewrite examples in vak.config._generate.generate docstring to use shorter vak.config.generate --- 
src/vak/config/_generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/vak/config/_generate.py b/src/vak/config/_generate.py index df57ffe31..4393cb3e2 100644 --- a/src/vak/config/_generate.py +++ b/src/vak/config/_generate.py @@ -44,19 +44,19 @@ def generate( Generate a TOML configuration file in the current working directory to prepare a dataset and train a model. - >>> vak.config.generate.generate("train", add_prep=True) + >>> vak.config.generate("train", add_prep=True) Generate a TOML configuration file in a specified directory to train a model, e.g. on an existing dataset. >>> import pathlib >>> dst = pathlib.Path("./data/configs") - >>> vak.config.generate.generate("train", add_prep=True, dst=dst) + >>> vak.config.generate("train", add_prep=True, dst=dst) Generate a TOML configuration file with a specific file name to train a model, e.g. on an existing dataset. >>> import pathlib >>> dst = pathlib.Path("./data/configs/train-bfsongrepo.toml") - >>> vak.config.generate.generate("train", add_prep=True, dst=dst) + >>> vak.config.generate("train", add_prep=True, dst=dst) Notes From af35c1af2930bae731beb8f1f7d73d9c892a4828 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Wed, 24 Sep 2025 20:38:52 -0400 Subject: [PATCH 37/40] Fix name -> vak.config.generate in test_config/test_generate.py --- tests/test_config/test_generate.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py index 87ef97a5d..96cd18a65 100644 --- a/tests/test_config/test_generate.py +++ b/tests/test_config/test_generate.py @@ -3,7 +3,7 @@ import pytest -import vak.config.generate +import vak.config @pytest.mark.parametrize( @@ -151,9 +151,9 @@ def test_generate(kind, add_prep, dst_name, tmp_path): if dst_name is None: os.chdir(dst) - vak.config.generate.generate(kind=kind, add_prep=add_prep) + vak.config.generate(kind=kind, add_prep=add_prep) else: - 
vak.config.generate.generate(kind=kind, add_prep=add_prep, dst=dst) + vak.config.generate(kind=kind, add_prep=add_prep, dst=dst) if dst.is_dir(): # we need to get the actual generated TOML @@ -189,7 +189,7 @@ def test_generate_raises_FileExistsError(tmp_path): with dst.open("w") as fp: fp.write("[fake.config]") with pytest.raises(FileExistsError): - vak.config.generate.generate("train", add_prep=True, dst=dst) + vak.config.generate("train", add_prep=True, dst=dst) @@ -200,4 +200,4 @@ # NOTE: dst must not exist -- otherwise generate raises FileExistsError # before it validates the '.toml' extension with pytest.raises(ValueError): - vak.config.generate.generate("train", add_prep=True, dst=dst) + vak.config.generate("train", add_prep=True, dst=dst) From 5974364c61c9d1a4a138304e9273597c42229073 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 25 Sep 2025 09:44:34 -0400 Subject: [PATCH 38/40] Remove #FIXME comment that I fixed in test_generate.py --- tests/test_config/test_generate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_config/test_generate.py b/tests/test_config/test_generate.py index 96cd18a65..23ebfcc16 100644 --- a/tests/test_config/test_generate.py +++ b/tests/test_config/test_generate.py @@ -137,7 +137,6 @@ ) def test_generate(kind, add_prep, dst_name, tmp_path): """Test :func:`vak.config.generate.generate`""" - # FIXME: handle case where `dst` is a filename -- handle .toml extension if dst_name is None: dst = tmp_path / "tmp-dst-None" else: From 1b260e1b24378409c262a49ea786e13139d0f8f8 Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 25 Sep 2025 09:50:01 -0400 Subject: [PATCH 39/40] Fix import in vak/cli/cli.py --- src/vak/cli/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py index 468db7972..ef46489aa 100644 --- a/src/vak/cli/cli.py +++ b/src/vak/cli/cli.py @@ -36,8 +36,8 @@ def prep(args): def configfile(args): - from ..config.generate import generate - generate( + 
from .. import config + config.generate( kind=args.kind, add_prep=args.add_prep, dst=args.dst, From ab06a698e3c44e1a42ff549b91d567134671728f Mon Sep 17 00:00:00 2001 From: David Nicholson Date: Thu, 25 Sep 2025 09:51:07 -0400 Subject: [PATCH 40/40] Write unit test: tests/test___main__.py:test_configfile_command --- tests/test___main__.py | 183 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 178 insertions(+), 5 deletions(-) diff --git a/tests/test___main__.py b/tests/test___main__.py index 00b2d2304..7d9737d06 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -1,3 +1,4 @@ +import os import subprocess from unittest import mock @@ -55,8 +56,180 @@ def test___main__prints_help_with_no_args(parser, capsys): assert output == expected_output -def test_configfile_command(): - # FIXME: copy whatever unit tests we write for `vak.config.generate.generate` - # FIXME: except we change the actual part of the test where we call the function - # FIXME: and we're going to use an `args_list` instead of providing parameters directly - assert False \ No newline at end of file +@pytest.mark.parametrize( + 'kind, add_prep, dst_name', + [ + # ---- train + ( + "train", + False, + None + ), + ( + "train", + True, + None + ), + ( + "train", + False, + "configs-dir" + ), + ( + "train", + True, + "configs-dir" + ), + ( + "train", + False, + "configs-dir/config.toml" + ), + ( + "train", + True, + "configs-dir/config.toml" + ), + # ---- eval + ( + "eval", + False, + None + ), + ( + "eval", + True, + None + ), + ( + "eval", + False, + "configs-dir" + ), + ( + "eval", + True, + "configs-dir" + ), + ( + "eval", + False, + "configs-dir/config.toml" + ), + ( + "eval", + True, + "configs-dir/config.toml" + ), + # ---- predict + ( + "predict", + False, + None + ), + ( + "predict", + True, + None + ), + ( + "predict", + False, + "configs-dir" + ), + ( + "predict", + True, + "configs-dir" + ), + ( + "predict", + False, + "configs-dir/config.toml" + ), + ( + "predict", + 
True, + "configs-dir/config.toml" + ), + # ---- learncurve + ( + "learncurve", + False, + None + ), + ( + "learncurve", + True, + None + ), + ( + "learncurve", + False, + "configs-dir" + ), + ( + "learncurve", + True, + "configs-dir" + ), + ( + "learncurve", + False, + "configs-dir/config.toml" + ), + ( + "learncurve", + True, + "configs-dir/config.toml" + ), + ] +) +def test_configfile_command(kind, add_prep, dst_name, tmp_path): + """Test :func:`vak.config.generate.generate`""" + if dst_name is None: + dst = tmp_path / "tmp-dst-None" + else: + dst = tmp_path / dst_name + if dst.suffix == ".toml": + # if dst ends with a toml extension + # then its *parent* is the dir we need to make + dst.parent.mkdir() + else: + dst.mkdir() + + if dst_name is None: + os.chdir(dst) + + args = ["vak", "configfile", kind] + if add_prep: + args = args + ["--add-prep"] + if dst_name is not None: + args = args + ["--dst", str(dst)] + subprocess.run(args) + + if dst.is_dir(): + # we need to get the actual generated TOML + generated_toml_path = sorted(dst.glob("*toml")) + assert len(generated_toml_path) == 1 + generated_toml_path = generated_toml_path[0] + else: + generated_toml_path = dst + # next line: the rest of the assertions would fail if this one did + # but we're being super explicit here: + # if we specified a file name for dst then it should exist as a file + assert generated_toml_path.exists() + + # we can't load with `vak.config.Config.from_toml_path` + # because the generated config doesn't have a [vak.dataset.path] key-value pair yet, + # and the corresponding attrs class that represents that table will throw an error. + # So we load as a Python dict and check the expected keys are there. + # I don't have any better ideas at the moment for how to test + cfg_dict = vak.config.load._load_toml_from_path(generated_toml_path) + # N.B. 
that `vak.config.load._load_toml_from_path` accesses the top-level key "vak" + # and returns the result of that, so we don't need to do something like `cfg_dict["vak"]["prep"]` + assert kind in cfg_dict + if add_prep: + assert "prep" in cfg_dict + else: + assert "prep" not in cfg_dict