diff --git a/encoderfile-runtime/src/main.rs b/encoderfile-runtime/src/main.rs index 1e61b689..003d7d42 100644 --- a/encoderfile-runtime/src/main.rs +++ b/encoderfile-runtime/src/main.rs @@ -12,8 +12,7 @@ use encoderfile::{ ModelType, model_type::{Embedding, SentenceEmbedding, SequenceClassification, TokenClassification}, }, - format::codec::EncoderfileCodec, - runtime::{EncoderfileLoader, EncoderfileState}, + runtime::{EncoderfileLoader, EncoderfileState, load_assets}, transport::cli::Cli, }; @@ -77,10 +76,3 @@ async fn entrypoint<'a, R: Read + Seek>(loader: &mut EncoderfileLoader<'a, R>) - ), } } - -fn load_assets<'a, R: Read + Seek>(file: &'a mut R) -> Result> { - let encoderfile = EncoderfileCodec::read(file)?; - let loader = EncoderfileLoader::new(encoderfile, file); - - Ok(loader) -} diff --git a/encoderfile/Cargo.toml b/encoderfile/Cargo.toml index 447f9c2e..c9579d76 100644 --- a/encoderfile/Cargo.toml +++ b/encoderfile/Cargo.toml @@ -52,6 +52,11 @@ name = "test_build_encoderfile" path = "tests/integration/test_build.rs" required-features = ["cli", "dev-utils"] +[[test]] +name = "test_inspect_encoderfile" +path = "tests/integration/test_inspect.rs" +required-features = ["cli", "dev-utils"] + [package] name = "encoderfile" version = "0.4.0-rc.1" diff --git a/encoderfile/src/build_cli/cli/inspect.rs b/encoderfile/src/build_cli/cli/inspect.rs new file mode 100644 index 00000000..ce3cebc2 --- /dev/null +++ b/encoderfile/src/build_cli/cli/inspect.rs @@ -0,0 +1,35 @@ +use std::fs::File; +use std::io::BufReader; +use std::path::Path; + +use serde::Serialize; +use serde_json::to_string_pretty; + +use anyhow::Result; + +use crate::{ + common::{Config, ModelConfig}, + runtime::load_assets, +}; + +// inspect struct with info + +#[derive(Debug, Serialize)] +pub struct InspectInfo { + pub model_config: ModelConfig, + pub encoderfile_config: Config, +} + +pub fn inspect_encoderfile(path_str: &String) -> Result { + let file = File::open(Path::new(&path_str))?; + let mut file = BufReader::new(file); + let mut loader = load_assets(&mut file)?; + + let config = loader.encoderfile_config()?; + let model_config = loader.model_config()?; + + Ok(to_string_pretty(&InspectInfo { + model_config, + encoderfile_config: config, + })?) +} diff --git a/encoderfile/src/build_cli/cli/mod.rs b/encoderfile/src/build_cli/cli/mod.rs index 1752e78c..46d45c27 100644 --- a/encoderfile/src/build_cli/cli/mod.rs +++ b/encoderfile/src/build_cli/cli/mod.rs @@ -4,11 +4,14 @@ use std::path::PathBuf; use clap_derive::{Args, Parser, Subcommand}; mod build; +mod inspect; mod runtime; #[cfg(feature = "dev-utils")] pub use build::test_build_args; +pub use inspect::inspect_encoderfile; + #[derive(Debug, Parser)] pub struct Cli { #[command(subcommand)] @@ -48,6 +51,11 @@ pub enum Commands { #[arg(short = 'm', long = "model-type", help = "Model type")] model_type: String, }, + #[command(about = "Inspect the metadata of an encoderfile.")] + Inspect { + #[arg(required = true, help = "Path to encoderfile.")] + path: String, + }, } impl Commands { @@ -60,6 +68,10 @@ impl Commands { } Self::Runtime(r) => r.execute(global), Self::NewTransform { model_type } => super::transforms::new_transform(model_type), + Self::Inspect { path } => { + println!("{}", inspect::inspect_encoderfile(&path)?); + Ok(()) + } } } } diff --git a/encoderfile/src/format/footer.rs b/encoderfile/src/format/footer.rs index fc4c68dc..9ef20388 100644 --- a/encoderfile/src/format/footer.rs +++ b/encoderfile/src/format/footer.rs @@ -5,7 +5,7 @@ use std::io::{Read, Seek, SeekFrom, Write}; pub const FLAG_METADATA_PROTOBUF: u32 = 1 << 0; #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, serde::Serialize)] pub struct EncoderfileFooter { pub magic: [u8; 8], pub format_version: u32, diff --git a/encoderfile/src/runtime/loader.rs b/encoderfile/src/runtime/loader.rs index 5ed53a83..c88598c9 100644 --- a/encoderfile/src/runtime/loader.rs +++ b/encoderfile/src/runtime/loader.rs @@ -6,7 +6,7 @@ use ort::session::Session; use crate::{ common::{Config, ModelConfig, ModelType}, - format::{assets::AssetKind, container::Encoderfile}, + format::{assets::AssetKind, codec::EncoderfileCodec, container::Encoderfile}, generated::manifest::TransformType, runtime::TokenizerService, }; @@ -113,3 +113,10 @@ impl<'a, R: Read + Seek> EncoderfileLoader<'a, R> { } } } + +pub fn load_assets<'a, R: Read + Seek>(file: &'a mut R) -> Result> { + let encoderfile = EncoderfileCodec::read(file)?; + let loader = EncoderfileLoader::new(encoderfile, file); + + Ok(loader) +} diff --git a/encoderfile/src/runtime/mod.rs b/encoderfile/src/runtime/mod.rs index d77ae5fa..41d2bf86 100644 --- a/encoderfile/src/runtime/mod.rs +++ b/encoderfile/src/runtime/mod.rs @@ -5,7 +5,7 @@ mod loader; mod state; mod tokenizer; -pub use loader::EncoderfileLoader; +pub use loader::{EncoderfileLoader, load_assets}; pub use state::{AppState, EncoderfileState}; pub use tokenizer::TokenizerService; diff --git a/encoderfile/tests/integration/test_build.rs b/encoderfile/tests/integration/test_build.rs index f6a9577d..5e87ae06 100644 --- a/encoderfile/tests/integration/test_build.rs +++ b/encoderfile/tests/integration/test_build.rs @@ -12,11 +12,11 @@ use tempfile::tempdir; const BINARY_NAME: &str = "test.encoderfile"; -fn config(model_path: &Path, output_path: &Path) -> String { +fn config(model_name: &String, model_path: &Path, output_path: &Path) -> String { format!( r##" encoderfile: - name: test-model + name: {:?} path: {:?} model_type: token_classification output_path: {:?} @@ -36,7 +36,7 @@ encoderfile: return arr:softmax(3) end "##, - model_path, output_path + model_name, model_path, output_path ) } @@ -76,7 +76,11 @@ fn test_build_encoderfile() -> Result<()> { .expect("Failed to canonicalize base binary path"); // write encoderfile config - let config = config(tmp_model_path.as_path(), encoderfile_path.as_path()); + let config = config( + &String::from("test-model"), + tmp_model_path.as_path(), + encoderfile_path.as_path(), + ); fs::write(ef_config_path.as_path(), config.as_bytes()) .expect("Failed to write encoderfile config"); diff --git a/encoderfile/tests/integration/test_inspect.rs b/encoderfile/tests/integration/test_inspect.rs new file mode 100644 index 00000000..e4f8bc2f --- /dev/null +++ b/encoderfile/tests/integration/test_inspect.rs @@ -0,0 +1,167 @@ +use anyhow::{Context, Result, bail}; + +use encoderfile::build_cli::cli::{GlobalArguments, inspect_encoderfile}; +use std::{ + fs, + path::Path, + process::{Command, Output}, +}; +use tempfile::tempdir; + +const BINARY_NAME: &str = "test.encoderfile"; + +fn config(model_name: &String, model_path: &Path, output_path: &Path) -> String { + format!( + r##" +encoderfile: + name: {:?} + path: {:?} + model_type: token_classification + output_path: {:?} + transform: | + --- Applies a softmax across token classification logits. + --- Each token classification is normalized independently. + --- + --- Args: + --- arr (Tensor): A tensor of shape [batch_size, n_tokens, n_labels]. + --- The softmax is applied along the third axis (n_labels). + --- + --- Returns: + --- Tensor: The input tensor with softmax-normalized embeddings. + ---@param arr Tensor + ---@return Tensor + function Postprocess(arr) + return arr:softmax(3) + end + "##, + model_name, model_path, output_path + ) +} + +const MODEL_ASSETS_PATH: &str = "../models/token_classification"; + +#[test] +fn test_inspect_encoderfile() -> Result<()> { + let dir = tempdir()?; + let path = dir + .path() + .canonicalize() + .expect("Failed to canonicalize temp path"); + + let tmp_model_path = path.join("models").join("token_classification"); + + let ef_config_path = path.join("encoderfile.yml"); + let encoderfile_path = path.join(BINARY_NAME); + let model_name = String::from("some-custom-name"); + + // copy model assets to temp dir + copy_dir_all(MODEL_ASSETS_PATH, tmp_model_path.as_path()) + .expect("Failed to copy model assets to temp directory"); + + if !tmp_model_path.join("model.onnx").exists() { + bail!( + "Path {:?} does not exist", + tmp_model_path.join("model.onnx") + ); + } + + // compile base binary and copy to temp dir + let _ = Command::new("cargo") + .args(["build"]) + .status() + .expect("Failed to build encoderfile-runtime"); + + let base_binary_path = fs::canonicalize("../target/debug/encoderfile-runtime") + .expect("Failed to canonicalize base binary path"); + + let ef_binary_path = fs::canonicalize("../target/debug/encoderfile") + .expect("Failed to canonicalize base binary path"); + + // write encoderfile config + let config = config( + &model_name, + tmp_model_path.as_path(), + encoderfile_path.as_path(), + ); + + fs::write(ef_config_path.as_path(), config.as_bytes()) + .expect("Failed to write encoderfile config"); + + let build_args = + encoderfile::build_cli::cli::test_build_args(ef_config_path.as_path(), base_binary_path); + + // build encoderfile + let global_args = GlobalArguments::default(); + + build_args + .run(&global_args) + .context("Failed to build encoderfile")?; + + let ef_path_str = String::from( + encoderfile_path + .to_str() + .expect("Encoderfile path name failed to convert to string"), + ); + + let _inspect_output = inspect_encoderfile(&ef_path_str)?; + + let output = run_inspect_encoderfile( + ef_binary_path + .to_str() + .expect("Failed to create encoderfile binary path"), + &ef_path_str, + )?; + + let stdout = String::from_utf8(output.stdout)?; + let stderr = String::from_utf8(output.stderr)?; + + println!("STDOUT: {}", stdout); + println!("STDERR: {}", stderr); + + let inspect_output_json = serde_json::from_str::(&stdout) + .context("Failed to parse inspect output as JSON")?; + inspect_output_json + .get("encoderfile_config") + .and_then(|efc| efc.get("name")) + .and_then(|name| name.as_str()) + .filter(|name_str| *name_str == model_name.as_str()) + .ok_or_else(|| anyhow::anyhow!("Model name in inspect output does not match expected"))?; + + Ok(()) +} + +fn copy_dir_all(src: impl AsRef, dst: impl AsRef) -> anyhow::Result<()> { + let src = src.as_ref(); + let dst = dst.as_ref(); + + fs::create_dir_all(dst).context(format!("Failed to create directory {:?}", &dst))?; + + for entry in fs::read_dir(src)? { + let entry = entry?; + let ty = entry.file_type()?; + let dest_path = dst.join(entry.file_name()); + + if ty.is_dir() { + copy_dir_all(entry.path(), dest_path.as_path()).context(format!( + "Failed to copy {:?} to {:?}", + entry.path(), + dest_path.as_path() + ))?; + } else { + fs::copy(entry.path(), dest_path.as_path()).context(format!( + "Failed to copy {:?} to {:?}", + entry.path(), + dest_path.as_path() + ))?; + } + } + + Ok(()) +} + +fn run_inspect_encoderfile(path: &str, ef_path: &str) -> Result { + let mut cmd = Command::new(path); + cmd.arg("inspect").arg(ef_path); + println!("{:?}", cmd); + cmd.output().context("Failed inspect command") +}