From c2ef717c682846e9abb5be8e760ee98266aff2b6 Mon Sep 17 00:00:00 2001 From: ry2009 Date: Tue, 26 May 2026 18:51:17 -0400 Subject: [PATCH] Export uzu traces as safetensors --- crates/backend-uzu/Cargo.toml | 1 + crates/backend-uzu/src/lib.rs | 4 +- crates/backend-uzu/src/parameters/mod.rs | 4 +- .../src/parameters/safetensors_metadata.rs | 92 +- crates/backend-uzu/src/session/types/error.rs | 2 + .../parameters/safetensors_metadata_test.rs | 62 +- .../integration/tracer/trace_validator.rs | 922 ++++-------------- .../tracer/trace_validator_test.rs | 36 +- 8 files changed, 335 insertions(+), 788 deletions(-) diff --git a/crates/backend-uzu/Cargo.toml b/crates/backend-uzu/Cargo.toml index 1c340800d..407fb8dc1 100644 --- a/crates/backend-uzu/Cargo.toml +++ b/crates/backend-uzu/Cargo.toml @@ -65,6 +65,7 @@ rstest.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] criterion = { workspace = true, default-features = true } proptest = { workspace = true, default-features = true } +tempfile.workspace = true [build-dependencies] anyhow.workspace = true diff --git a/crates/backend-uzu/src/lib.rs b/crates/backend-uzu/src/lib.rs index 5daa9b751..e0a59f304 100644 --- a/crates/backend-uzu/src/lib.rs +++ b/crates/backend-uzu/src/lib.rs @@ -26,14 +26,14 @@ pub use audio::{NanoCodecFsqRuntime, NanoCodecFsqRuntimeConfig}; pub use backends::common::{AllocationAccessError, allocation_copy_from_slice, allocation_to_vec}; pub use data_type::{ArrayElement, DataType}; pub use language_model::gumbel::{gumbel_float, revidx}; -pub use parameters::{ParameterLoader, read_safetensors_metadata}; +pub use parameters::{ParameterLoader, SafeTensorData, read_safetensors_metadata, write_safetensors}; pub use utils::{TOOLCHAIN_VERSION, VERSION}; #[cfg(feature = "tracing")] pub mod _private { pub use crate::{ classifier::Classifier, - config::{ModelConfig, ModelMetadata, ModelType}, + config::{ModelConfig, ModelMetadata, ModelType, TransformerLayerConfig}, encodable_block::{DecoderDecodeInput, Sampling}, forward_pass::{ cache_layers::CacheLayers, kv_cache_layer::KVCacheLayer, token_inputs::TokenInputs, traces::ActivationTrace, diff --git a/crates/backend-uzu/src/parameters/mod.rs b/crates/backend-uzu/src/parameters/mod.rs index b2715220f..f5c0568be 100644 --- a/crates/backend-uzu/src/parameters/mod.rs +++ b/crates/backend-uzu/src/parameters/mod.rs @@ -4,4 +4,6 @@ mod safetensors_metadata; // Re-export the safetensors header reader so other modules (e.g. decoder // runner) can estimate parameter memory before creating a Context. pub use loader::{ParameterLeaf, ParameterLoader, ParameterLoaderError, ParameterTree}; -pub use safetensors_metadata::{HeaderLoadingError, read_metadata as read_safetensors_metadata}; +pub use safetensors_metadata::{ + HeaderLoadingError, SafeTensorData, read_metadata as read_safetensors_metadata, write_safetensors, +}; diff --git a/crates/backend-uzu/src/parameters/safetensors_metadata.rs b/crates/backend-uzu/src/parameters/safetensors_metadata.rs index 2a027a470..cadb1d929 100644 --- a/crates/backend-uzu/src/parameters/safetensors_metadata.rs +++ b/crates/backend-uzu/src/parameters/safetensors_metadata.rs @@ -1,6 +1,10 @@ // This code is based on the safetensors implementation: https://docs.rs/safetensors/latest/src/safetensors/tensor.rs.html -use std::{collections::HashMap, fs::File}; +use std::{ + collections::HashMap, + fs::File, + io::{self, Write}, +}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -95,17 +99,21 @@ impl From for DataType { } } -impl From for Dtype { - fn from(dtype: DataType) -> Self { +impl TryFrom for Dtype { + type Error = DataType; + + fn try_from(dtype: DataType) -> Result { match dtype { - DataType::F16 => Dtype::F16, - DataType::BF16 => Dtype::BF16, - DataType::F32 => Dtype::F32, - DataType::I8 => Dtype::I8, - DataType::I32 => Dtype::I32, - DataType::I64 => Dtype::I64, - DataType::U64 => Dtype::U64, - _ => panic!("Unsupported dtype: {:?}", dtype), + DataType::F16 => Ok(Dtype::F16), + DataType::BF16 => Ok(Dtype::BF16), + DataType::F32 => Ok(Dtype::F32), + DataType::I8 => Ok(Dtype::I8), + DataType::U8 => Ok(Dtype::U8), + DataType::I32 => Ok(Dtype::I32), + DataType::U32 => Ok(Dtype::U32), + DataType::I64 => Ok(Dtype::I64), + DataType::U64 => Ok(Dtype::U64), + DataType::F64 | DataType::I4 | DataType::U4 | DataType::I16 | DataType::U16 => Err(dtype), } } } @@ -129,3 +137,65 @@ pub fn read_metadata(file: &File) -> Result<(usize, HashMetadata), HeaderLoading serde_json::from_str(string).map_err(|_| HeaderLoadingError::InvalidHeaderDeserialization)?; Ok((stop, metadata)) } + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SafeTensorData { + pub name: String, + pub shape: Box<[usize]>, + pub data_type: DataType, + pub data: Box<[u8]>, +} + +pub fn write_safetensors( + writer: &mut W, + tensors: &[SafeTensorData], +) -> Result<(), io::Error> { + let mut sorted_tensors: Vec<&SafeTensorData> = tensors.iter().collect(); + sorted_tensors.sort_by(|left, right| { + right.data_type.size_in_bytes().cmp(&left.data_type.size_in_bytes()).then_with(|| left.name.cmp(&right.name)) + }); + let Some(first_tensor) = sorted_tensors.first() else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "safetensors requires at least one tensor")); + }; + + let mut header = serde_json::Map::new(); + let mut offset = 0; + for tensor in sorted_tensors.iter() { + let dtype = Dtype::try_from(tensor.data_type) + .map_err(|dtype| io::Error::new(io::ErrorKind::InvalidInput, format!("unsupported dtype: {dtype:?}")))?; + let expected_len = tensor + .shape + .iter() + .try_fold(tensor.data_type.size_in_bytes(), |size, dim| size.checked_mul(*dim)) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "safetensors tensor byte size overflows usize") + })?; + if expected_len != tensor.data.len() { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "safetensors tensor shape does not match data")); + } + let end = offset + tensor.data.len(); + header.insert( + tensor.name.clone(), + serde_json::to_value(TensorInfo { + dtype, + shape: tensor.shape.to_vec(), + data_offsets: (offset, end), + })?, + ); + offset = end; + } + + let data_alignment = first_tensor.data_type.size_in_bytes().max(8); + let mut header_bytes = serde_json::to_vec(&header)?; + let padding = (data_alignment - header_bytes.len() % data_alignment) % data_alignment; + header_bytes.extend(std::iter::repeat_n(b' ', padding)); + + let header_len = u64::try_from(header_bytes.len()) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "safetensors header length does not fit into u64"))?; + writer.write_all(&header_len.to_le_bytes())?; + writer.write_all(&header_bytes)?; + for tensor in sorted_tensors { + writer.write_all(&tensor.data)?; + } + Ok(()) +} diff --git a/crates/backend-uzu/src/session/types/error.rs b/crates/backend-uzu/src/session/types/error.rs index 794454017..480f2647a 100644 --- a/crates/backend-uzu/src/session/types/error.rs +++ b/crates/backend-uzu/src/session/types/error.rs @@ -178,6 +178,8 @@ pub enum Error { InvalidTtsRunConfig(#[from] TtsRunConfigError), #[error("Unable to load model weights")] UnableToLoadWeights, + #[error("Unable to write trace")] + UnableToWriteTrace, #[error("Unable to load tokenizer")] UnableToLoadTokenizer, #[error("Model is too large to fit into available RAM")] diff --git a/crates/backend-uzu/tests/integration/parameters/safetensors_metadata_test.rs b/crates/backend-uzu/tests/integration/parameters/safetensors_metadata_test.rs index 4988e6ceb..e0f89f216 100644 --- a/crates/backend-uzu/tests/integration/parameters/safetensors_metadata_test.rs +++ b/crates/backend-uzu/tests/integration/parameters/safetensors_metadata_test.rs @@ -1,9 +1,53 @@ use std::fs::File; -use backend_uzu::read_safetensors_metadata; +use backend_uzu::{ + ArrayContextExt, DataType, ParameterLoader, SafeTensorData, + backends::{ + common::{Backend, Context, allocation_as_bytes}, + cpu::Cpu, + }, + read_safetensors_metadata, write_safetensors, +}; use crate::common::path::get_test_weights_path; +fn tensor_from_array( + name: &str, + array: &backend_uzu::Array, +) -> SafeTensorData { + SafeTensorData { + name: name.to_string(), + shape: array.shape().into(), + data_type: array.data_type(), + data: allocation_as_bytes(array.allocation()).into(), + } +} + +#[test] +fn test_safetensors_metadata_writer_rejects_empty_tensors() { + let mut bytes = Vec::new(); + let err = write_safetensors(&mut bytes, &[]).expect_err("empty safetensors should fail"); + + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); +} + +#[test] +fn test_safetensors_metadata_writer_rejects_shape_data_mismatch() { + let mut bytes = Vec::new(); + let err = write_safetensors( + &mut bytes, + &[SafeTensorData { + name: "bad".to_string(), + shape: [2].into(), + data_type: DataType::F32, + data: vec![0; 4].into_boxed_slice(), + }], + ) + .expect_err("shape mismatch should fail"); + + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); +} + #[test] fn test_metadata_loading() { let path = get_test_weights_path(); @@ -11,3 +55,19 @@ fn test_metadata_loading() { let (_offset, metadata) = read_safetensors_metadata(&file).expect("read metadata"); assert!(metadata.tensors.len() > 0); } + +#[test] +fn test_safetensors_metadata_writer_roundtrips_arrays() { + let context = ::Context::new().expect("create CPU context"); + let floats = context.create_array_from(&[2, 2], &[1.0f32, 2.0, 3.0, 4.0]); + let ints = context.create_array_from(&[2], &[7i32, 9]); + let mut file = tempfile::NamedTempFile::new().expect("create safetensors file"); + + write_safetensors(file.as_file_mut(), &[tensor_from_array("floats", &floats), tensor_from_array("ints", &ints)]) + .expect("write safetensors file"); + + let loader_file = file.reopen().expect("open safetensors file"); + let loader = ParameterLoader::new(&loader_file, context.as_ref()).expect("create loader"); + assert_eq!(loader.tree().leaf_array("floats").unwrap().as_slice::(), &[1.0, 2.0, 3.0, 4.0]); + assert_eq!(loader.tree().leaf_array("ints").unwrap().as_slice::(), &[7, 9]); +} diff --git a/crates/backend-uzu/tests/integration/tracer/trace_validator.rs b/crates/backend-uzu/tests/integration/tracer/trace_validator.rs index eb1f32bcd..4c0b6c0fb 100644 --- a/crates/backend-uzu/tests/integration/tracer/trace_validator.rs +++ b/crates/backend-uzu/tests/integration/tracer/trace_validator.rs @@ -1,9 +1,3 @@ -//! Unified trace validation for any model type (LLM or classifier). -//! -//! This module provides a single `TraceValidator` that can validate activation -//! traces for any model type. It automatically detects the model type from the -//! config and runs the appropriate validation. - use std::{ cell::RefCell, fs::File, @@ -13,155 +7,38 @@ use std::{ use backend_uzu::{ _private::{ - ActivationTrace, ArgmaxSampler, CacheLayers, Classifier, DecoderDecodeInput, KVCacheLayer, - LanguageModelGeneratorContext, LogitsSampler, ModelConfig, ModelMetadata, ModelType, ParameterTree, Sampling, - TokenInputs, + ActivationTrace, CacheLayers, Classifier, DecoderDecodeInput, LanguageModelGeneratorContext, ModelConfig, + ModelMetadata, ModelType, ParameterTree, Sampling, TokenInputs, TransformerLayerConfig, }, - Array, ArrayElement, DataType, ParameterLoader, allocation_to_vec, - backends::common::{Allocation, Backend, Encoder, kernel::kv_cache_update::KVCacheUpdate}, + Array, ArrayElement, DataType, ParameterLoader, SafeTensorData, + backends::common::{Backend, Encoder, allocation_as_bytes, kernel::kv_cache_update::KVCacheUpdate}, read_safetensors_metadata, session::{ config::{DecodingConfig, SpeculatorConfig}, parameter::{AsyncBatchSize, ConfigResolvableValue, ContextLength, ContextMode, PrefillStepSize, SamplingSeed}, types::Error, }, + write_safetensors, }; -use half::{bf16, f16}; -use ndarray::{IxDyn, s}; use num_traits::NumCast; -use crate::common; -// ============================================================================ -// Validation Types -// ============================================================================ - -/// Metrics from validating a single tensor. -pub struct TracerValidationMetrics { - pub atol: f32, - pub rtol: f32, - #[allow(unused)] - pub fraction_of_allowed_violations: f32, - pub reference_shape: Vec, - pub result_shape: Vec, - pub num_violations: usize, - pub max_allowed_violations: usize, - pub max_err_idx: usize, - pub max_err: f32, - pub max_err_rel: f32, - pub max_err_reference_value: f32, - pub rms_diff: f32, - pub rms_result: f32, - pub rms_reference: f32, - pub rel_rms_reference: f32, - pub diff_max: f32, - pub diff_avg: f32, - pub result_nan: bool, -} - -impl TracerValidationMetrics { - pub fn is_valid(&self) -> bool { - self.num_violations <= self.max_allowed_violations && !self.result_nan - } - - pub fn message(&self) -> String { - if self.result_nan { - return "Result contains NaN values".to_string(); - } - - let allowed_violation_explainer = if self.max_allowed_violations > 0 { - format!(" (Max {} allowed)", self.max_allowed_violations) - } else { - String::new() - }; - - let reference_size: usize = self.reference_shape.iter().product(); - - format!( - "{} violations > {:.1e} + {:.2}% out of total {} elements{}.\n\ - Worst violation: {:.3} ({:.2}%) at index {} (reference value: {:.3}).\n\ - Error RMS: {:.3}.\n\ - RMS of result: {:.3}, RMS of reference: {:.3}.\n\ - Relative error RMS: {:.2}% of RMS of reference.\n\ - Shape: {:?}\n\ - Max diff: {:.3}, Avg diff: {:.3}", - self.num_violations, - self.atol, - self.rtol * 100.0, - reference_size, - allowed_violation_explainer, - self.max_err, - self.max_err_rel * 100.0, - self.max_err_idx, - self.max_err_reference_value, - self.rms_diff, - self.rms_result, - self.rms_reference, - self.rel_rms_reference * 100.0, - self.result_shape, - self.diff_max, - self.diff_avg, - ) - } -} - -/// Result of validating a single tensor. -pub struct TracerValidationResult { - pub name: String, - pub metrics: TracerValidationMetrics, -} - -/// Results from validating all traces. -pub struct TracerValidationResults { - pub suffix_length: usize, - pub results: Vec, - pub tokens_violation_indices: Vec, -} - -impl TracerValidationResults { - pub fn number_of_tokens_violations(&self) -> usize { - self.tokens_violation_indices.len() - } - - pub fn number_of_allowed_tokens_violations(&self) -> usize { - let threshold: f64 = 0.01; - (self.suffix_length as f64 * threshold).ceil() as usize - } +#[derive(Clone, Copy)] +enum ExportShape { + Native, + Batched, } -/// Transform to apply to produced arrays before comparison. -pub enum ArrayTransform { - /// Slice KV cache to match expected shape. - KVCacheSlice, - /// Transform SSM conv state layout. - SsmConvState, -} - -// ============================================================================ -// Model Context (internal) -// ============================================================================ - enum ModelContext { LanguageModelGenerator(LanguageModelGeneratorContext), Classifier(Classifier), } -// ============================================================================ -// Unified TraceValidator -// ============================================================================ - -/// Unified trace validator for any model type. -/// -/// Automatically detects whether the model is an LLM or classifier and -/// runs the appropriate validation. pub struct TraceValidator { model_path: PathBuf, context: ModelContext, } impl TraceValidator { - /// Create a new trace validator for the given model path. - /// - /// Automatically detects the model type from config.json. pub fn new(model_path: &Path) -> Result { let config_path = model_path.join("config.json"); if !config_path.exists() { @@ -209,7 +86,7 @@ impl TraceValidator { model_config, grammar_start_tokens: raw_metadata.grammar_start_tokens, }; - let prefill_step_size = Self::determine_prefill_step_size(model_path); + let prefill_step_size = Self::determine_prefill_step_size(model_path)?; let decoding_config = DecodingConfig::new( ContextMode::default(), ContextLength::default(), @@ -232,40 +109,60 @@ impl TraceValidator { }) } - /// Run trace validation and return results. - pub fn run(&mut self) -> Result { + pub fn export_trace( + &mut self, + output_path: &Path, + ) -> Result<(), Error> { let traces_path = self.model_path.join("traces.safetensors"); if !traces_path.exists() { return Err(Error::UnableToLoadWeights); } match &mut self.context { - ModelContext::LanguageModelGenerator(ctx) => Self::run_llm_validation(ctx, &traces_path), - ModelContext::Classifier(classifier) => Self::run_classifier_validation(classifier, &traces_path), + ModelContext::LanguageModelGenerator(ctx) => Self::export_llm_trace(ctx, &traces_path, output_path), + ModelContext::Classifier(classifier) => { + Self::export_classifier_trace(classifier, &traces_path, output_path) + }, } } - // ======================================================================== - // LLM Validation - // ======================================================================== - - fn run_llm_validation( + fn export_llm_trace( ctx: &LanguageModelGeneratorContext, traces_path: &Path, - ) -> Result { + output_path: &Path, + ) -> Result<(), Error> { let traces_file = File::open(traces_path).map_err(|_| Error::UnableToLoadWeights)?; let traces_loader = ParameterLoader::new(&traces_file, ctx.context.as_ref()).map_err(|_| Error::UnableToLoadWeights)?; let traces_view = traces_loader.tree(); + let (token_ids_array, token_positions_array, token_ids, token_positions) = + Self::load_trace_inputs(&traces_view)?; + let traces = Self::run_llm_trace(ctx, &token_ids, &token_positions)?; + + let mut tensors = vec![ + Self::tensor_from_array("activation_trace.token_ids", &token_ids_array), + Self::tensor_from_array("activation_trace.token_positions", &token_positions_array), + ]; + Self::push_activation_trace_tensors( + &mut tensors, + &traces, + &ctx.model_config.model_config.transformer_config.layer_configs, + ExportShape::Batched, + ); + Self::write_trace_file(output_path, &tensors) + } - let token_ids = Self::load_array_as_vec::(&traces_view, "activation_trace.token_ids"); - let token_positions = Self::load_array_as_vec::(&traces_view, "activation_trace.token_positions"); + fn run_llm_trace( + ctx: &LanguageModelGeneratorContext, + token_ids: &[u64], + token_positions: &[usize], + ) -> Result, Error> { let token_inputs = TokenInputs::new_llm( ctx.context.as_ref(), &ctx.model_shape, - &token_ids, + token_ids, None, - &token_positions, + token_positions, None, None, /*sampling_start=*/ 0, @@ -274,9 +171,10 @@ impl TraceValidator { let mut traces = ActivationTrace::new_llm(ctx.context.as_ref(), &ctx.model_shape, token_ids.len()); let mut encoder = - Encoder::::new(ctx.context.as_ref()).map_err(|e| Error::UnableToCreateCommandBuffer(e.into()))?; + Encoder::::new(ctx.context.as_ref()).map_err(|err| Error::UnableToCreateCommandBuffer(err.into()))?; { let mut cache_layers = ctx.cache_layers.borrow_mut(); + cache_layers.prepare_for_forward_pass(ctx.context.as_ref(), token_ids.len()); let decoder_arguments = token_inputs.decoder_arguments( ctx.shared_buffers.as_ref(), Some(&mut *cache_layers), @@ -293,643 +191,172 @@ impl TraceValidator { None, &mut encoder, ) - .map_err(|e| Error::EncodeFailed(Box::new(e)))?; + .map_err(|err| Error::EncodeFailed(Box::new(err)))?; } let pending = encoder.end_encoding().submit(); - pending.wait_until_completed().map_err(|e| Error::CommandBufferFailed(Box::new(e)))?; - - let data_type = ctx.model_shape.activation_data_type(); - - // Common layer validation - let mut results = Self::validate_layer_traces(&traces, &traces_view, data_type); - - // LLM-specific: KV cache validation - let cache = ctx.cache_layers.borrow(); - for (index, layer) in cache.iter_layers() { - let Some(kv) = layer.as_transformer() else { - continue; - }; - - if let Ok(expected) = traces_view.leaf_array(&format!("updated_kv_cache.{}.keys", index)) { - let size = kv.shape().iter().product::() * data_type.size_in_bytes(); - let keys = if let Some(layer) = kv.as_any().downcast_ref::>() { - common::helpers::sparse_buffer_read_allocation(ctx.context.as_ref(), &layer.keys, size) - } else { - panic!("Wrong keys type") - }; - results.push(TracerValidationResult { - name: format!("updated_kv_cache.{}.keys", index), - metrics: Self::validate_allocation( - data_type, - &expected, - &keys, - &kv.shape(), - Some(ArrayTransform::KVCacheSlice), - ), - }); - } - - if let Ok(expected) = traces_view.leaf_array(&format!("updated_kv_cache.{}.values", index)) { - let size = kv.shape().iter().product::() * data_type.size_in_bytes(); - let values = if let Some(layer) = kv.as_any().downcast_ref::>() { - common::helpers::sparse_buffer_read_allocation(ctx.context.as_ref(), &layer.values, size) - } else { - panic!("Wrong values type") - }; - results.push(TracerValidationResult { - name: format!("updated_kv_cache.{}.values", index), - metrics: Self::validate_allocation( - data_type, - &expected, - &values, - &kv.shape(), - Some(ArrayTransform::KVCacheSlice), - ), - }); - } - } - - // LLM-specific: SSM state validation - for (index, layer) in cache.iter_layers() { - let Some(ssm) = layer.as_state_space() else { - continue; - }; - - for path in [ - format!("updated_state.{}.conv_state", index), - format!("activation_trace.layer_results.{}.updated_state.conv_state", index), - ] { - if let Ok(expected) = traces_view.leaf_array(&path) { - results.push(TracerValidationResult { - name: path, - metrics: Self::validate_optional_allocation( - data_type, - &expected, - ssm.conv_state.as_ref(), - &ssm.conv_shape, - Some(ArrayTransform::SsmConvState), - ), - }); - } - } - - for path in [ - format!("updated_state.{}.ssm_state", index), - format!("activation_trace.layer_results.{}.updated_state.ssm_state", index), - ] { - if let Ok(expected) = traces_view.leaf_array(&path) { - results.push(TracerValidationResult { - name: path, - metrics: Self::validate_allocation(data_type, &expected, &ssm.ssm_state, &ssm.ssm_shape, None), - }); - } - } - } - - // LLM-specific: DeltaNet state validation - for (index, layer) in cache.iter_layers() { - let Some(delta) = layer.as_delta_net() else { - continue; - }; - - for path in [ - format!("updated_state.{}.conv_state", index), - format!("activation_trace.layer_results.{}.updated_state.conv_state", index), - ] { - if let Ok(expected) = traces_view.leaf_array(&path) { - results.push(TracerValidationResult { - name: path, - metrics: Self::validate_allocation( - data_type, - &expected, - &delta.conv_state, - &delta.conv_shape, - Some(ArrayTransform::SsmConvState), - ), - }); - } - } - - for path in [ - format!("updated_state.{}.ssm_state", index), - format!("activation_trace.layer_results.{}.updated_state.ssm_state", index), - ] { - if let Ok(expected) = traces_view.leaf_array(&path) { - results.push(TracerValidationResult { - name: path, - metrics: Self::validate_allocation( - data_type, - &expected, - &delta.ssm_state, - &delta.ssm_shape, - None, - ), - }); - } - } - } - - // LLM-specific: Token comparison - let tokens_violation_indices = if let Ok(expected_logits) = traces_view.leaf_array("logits") { - let expected_tokens = Self::get_tokens_from_logits(&expected_logits); - let produced_tokens = Self::get_tokens_from_logits(&traces.logits); - expected_tokens - .iter() - .zip(produced_tokens.iter()) - .enumerate() - .filter_map(|(i, (a, b))| { - if a != b { - Some(i) - } else { - None - } - }) - .collect() - } else { - Vec::new() - }; - - Ok(TracerValidationResults { - suffix_length: token_ids.len(), - results, - tokens_violation_indices, - }) + pending.wait_until_completed().map_err(|err| Error::CommandBufferFailed(Box::new(err)))?; + Ok(traces) } - // ======================================================================== - // Classifier Validation - // ======================================================================== - - fn run_classifier_validation( + fn export_classifier_trace( classifier: &mut Classifier, traces_path: &Path, - ) -> Result { + output_path: &Path, + ) -> Result<(), Error> { let traces_file = File::open(traces_path).map_err(|_| Error::UnableToLoadWeights)?; let context = classifier.context.context.clone(); let traces_loader = ParameterLoader::new(&traces_file, context.as_ref()).map_err(|_| Error::UnableToLoadWeights)?; let traces_view = traces_loader.tree(); - - let has_token_ids = traces_view.leaf_array("activation_trace.token_ids").is_ok(); - let has_token_positions = traces_view.leaf_array("activation_trace.token_positions").is_ok(); - - if !has_token_ids || !has_token_positions { - return Ok(Self::handle_missing_tokens(&traces_view)); - } - - let token_ids = Self::load_array_as_vec::(&traces_view, "activation_trace.token_ids"); - let token_positions = Self::load_array_as_vec::(&traces_view, "activation_trace.token_positions"); - - let suffix_length = token_ids.len(); - + let (token_ids_array, token_positions_array, token_ids, token_positions) = + Self::load_trace_inputs(&traces_view)?; let (_logits, traces) = classifier.forward_pass_with_traces(&token_ids, &token_positions).map_err(|_| Error::GenerateFailed)?; - let data_type = classifier.context.model_shape.activation_data_type(); - - // Common layer validation - let mut results = Self::validate_layer_traces(&traces, &traces_view, data_type); - - // Classifier-specific: embedding_norm, output_pooling - let classifier_results = Self::validate_classifier_traces(&traces, &traces_view, data_type); - results.extend(classifier_results); - - Ok(TracerValidationResults { - suffix_length, - results, - tokens_violation_indices: Vec::new(), // Classifiers don't compare tokens - }) + let mut tensors = vec![ + Self::tensor_from_array("activation_trace.token_ids", &token_ids_array), + Self::tensor_from_array("activation_trace.token_positions", &token_positions_array), + ]; + Self::push_activation_trace_tensors( + &mut tensors, + &traces, + &classifier.context.model_config.model_config.transformer_config.layer_configs, + ExportShape::Native, + ); + Self::write_trace_file(output_path, &tensors) } - fn handle_missing_tokens(traces_view: &ParameterTree) -> TracerValidationResults { - if let Ok(expected_logits) = traces_view.leaf_array("logits") { - let reference_shape = expected_logits.shape().to_vec(); - let metrics = TracerValidationMetrics { - atol: 0.0, - rtol: 0.0, - fraction_of_allowed_violations: 0.0, - reference_shape: reference_shape.clone(), - result_shape: reference_shape, - num_violations: 0, - max_allowed_violations: 0, - max_err_idx: 0, - max_err: 0.0, - max_err_rel: 0.0, - max_err_reference_value: 0.0, - rms_diff: 0.0, - rms_result: 0.0, - rms_reference: 0.0, - rel_rms_reference: 0.0, - diff_max: 0.0, - diff_avg: 0.0, - result_nan: false, - }; - return TracerValidationResults { - suffix_length: 1, - results: vec![TracerValidationResult { - name: "activation_trace.logits".to_string(), - metrics, - }], - tokens_violation_indices: Vec::new(), - }; - } - - TracerValidationResults { - suffix_length: 1, - results: Vec::new(), - tokens_violation_indices: Vec::new(), - } + fn load_trace_inputs( + traces_view: &ParameterTree + ) -> Result<(Array, Array, Vec, Vec), Error> { + let token_ids_array = + traces_view.leaf_array("activation_trace.token_ids").map_err(|_| Error::UnableToLoadWeights)?; + let token_positions_array = + traces_view.leaf_array("activation_trace.token_positions").map_err(|_| Error::UnableToLoadWeights)?; + Self::validate_trace_input_shape(&token_ids_array, &token_positions_array)?; + let token_ids = Self::array_as_vec::(&token_ids_array)?; + let token_positions = Self::array_as_vec::(&token_positions_array)?; + Ok((token_ids_array, token_positions_array, token_ids, token_positions)) } - // ======================================================================== - // Common Validation Helpers - // ======================================================================== - - fn validate_layer_traces( - traces: &ActivationTrace, - traces_view: &ParameterTree, - data_type: DataType, - ) -> Vec { - let mut results = Vec::new(); - - let validate = |path: &str, array: &Array| -> Option { - if let Ok(expected) = traces_view.leaf_array(path) { - Some(TracerValidationResult { - name: path.to_string(), - metrics: Self::validate_allocation(data_type, &expected, array.allocation(), array.shape(), None), - }) - } else { - None - } + fn validate_trace_input_shape( + token_ids: &Array, + token_positions: &Array, + ) -> Result<(), Error> { + let &[batch, suffix_length] = token_ids.shape() else { + return Err(Error::UnableToLoadWeights); }; - - for (index, layer_traces) in traces.layer_results.iter().enumerate() { - let path = |suffix: &str| -> String { - format!("activation_trace.layer_results.{}.activation_trace.{}", index, suffix) - }; - - if let Some(r) = validate(&path("inputs"), &layer_traces.inputs) { - results.push(r); - } - if let Some(r) = validate(&path("pre_mixer_norm"), &layer_traces.pre_attention_norm) { - results.push(r); - } - if let Some(r) = validate(&path("mixer"), &layer_traces.attention) { - results.push(r); - } - if let Some(r) = validate(&path("post_mixer_norm"), &layer_traces.post_attention_norm) { - results.push(r); - } - if let Some(r) = validate(&path("mlp_inputs"), &layer_traces.mlp_inputs) { - results.push(r); - } - if let Some(r) = validate(&path("pre_mlp_norm"), &layer_traces.pre_mlp_norm) { - results.push(r); - } - if let Some(r) = validate(&path("mlp"), &layer_traces.mlp) { - results.push(r); - } - if let Some(r) = validate(&path("post_mlp_norm"), &layer_traces.post_mlp_norm) { - results.push(r); - } - - let outputs_path = format!("activation_trace.layer_results.{}.outputs", index); - if let Some(r) = validate(&outputs_path, &layer_traces.outputs) { - results.push(r); - } - } - - // Output norm (common to all models) - if let Some(r) = validate("activation_trace.output_norm", &traces.output_norm) { - results.push(r); - } - - // Logits (common to all models, but path may vary) - if let Some(r) = validate("activation_trace.logits", &traces.logits) { - results.push(r); - } else if let Some(r) = validate("logits", &traces.logits) { - results.push(r); + if batch != 1 || suffix_length == 0 || token_positions.shape() != token_ids.shape() { + return Err(Error::UnableToLoadWeights); } - - results + Ok(()) } - fn validate_classifier_traces( - traces: &ActivationTrace, - traces_view: &ParameterTree, - data_type: DataType, - ) -> Vec { - let mut results = Vec::new(); - - // Output pooling (classifier-specific) - if let Some(output_pooling) = &traces.output_pooling { - if let Ok(expected) = traces_view.leaf_array("activation_trace.output_pooling") { - results.push(TracerValidationResult { - name: "activation_trace.output_pooling".to_string(), - metrics: Self::validate_allocation( - data_type, - &expected, - output_pooling.allocation(), - output_pooling.shape(), - None, - ), - }); - } - } - - results - } - - // ======================================================================== - // Allocation Validation - // ======================================================================== - - fn validate_allocation( - data_type: DataType, - expected_array: &Array, - produced_allocation: &Allocation, - produced_shape: &[usize], - transform: Option, - ) -> TracerValidationMetrics { - match data_type { - DataType::F16 => { - Self::validate_allocation_of_type::(expected_array, produced_allocation, produced_shape, transform) - }, - DataType::BF16 => Self::validate_allocation_of_type::( - expected_array, - produced_allocation, - produced_shape, - transform, - ), - DataType::F32 => { - Self::validate_allocation_of_type::(expected_array, produced_allocation, produced_shape, transform) - }, - _ => panic!("Unsupported data type: {:?}", data_type), - } + fn array_as_vec( + array: &Array + ) -> Result, Error> { + let slice = array.as_slice::(); + slice.iter().map(|value| NumCast::from(*value).ok_or(Error::UnableToLoadWeights)).collect() } - fn validate_optional_allocation( - data_type: DataType, - expected_array: &Array, - produced_allocation: Option<&Allocation>, - produced_shape: &[usize], - transform: Option, - ) -> TracerValidationMetrics { - match produced_allocation { - Some(produced_allocation) => { - Self::validate_allocation(data_type, expected_array, produced_allocation, produced_shape, transform) - }, - None => match data_type { - DataType::F16 => { - Self::validate_allocation_data_of_type::(expected_array, &[], produced_shape, transform) - }, - DataType::BF16 => { - Self::validate_allocation_data_of_type::(expected_array, &[], produced_shape, transform) - }, - DataType::F32 => { - Self::validate_allocation_data_of_type::(expected_array, &[], produced_shape, transform) - }, - _ => panic!("Unsupported data type: {:?}", data_type), + fn push_array( + tensors: &mut Vec, + path: impl Into, + array: &Array, + export_shape: ExportShape, + ) { + let path = path.into(); + let tensor = match export_shape { + ExportShape::Native => Self::tensor_from_array(path, array), + ExportShape::Batched => { + let shape = std::iter::once(1).chain(array.shape().iter().copied()).collect::>(); + Self::tensor_from_array_with_shape(path, array, shape.into_boxed_slice()) }, - } + }; + tensors.push(tensor); } - fn validate_allocation_of_type( - expected_array: &Array, - produced_allocation: &Allocation, - produced_shape: &[usize], - transform: Option, - ) -> TracerValidationMetrics { - let produced = allocation_to_vec::(produced_allocation); - Self::validate_allocation_data_of_type(expected_array, &produced, produced_shape, transform) + fn tensor_from_array( + name: impl Into, + array: &Array, + ) -> SafeTensorData { + Self::tensor_from_array_with_shape(name, array, array.shape().into()) } - fn validate_allocation_data_of_type( - expected_array: &Array, - produced_slice: &[Precision], - produced_shape: &[usize], - transform: Option, - ) -> TracerValidationMetrics { - let expected_view = expected_array.as_view::(); - let produced_view = ndarray::ArrayView::from_shape(IxDyn(produced_shape), produced_slice) - .expect("Failed to reshape allocation"); - - let (mut expected_data, mut produced_data) = match transform { - Some(ArrayTransform::KVCacheSlice) => { - let permuted = produced_view.permuted_axes(IxDyn(&[1, 0, 2])); - let total_tokens = permuted.shape()[0]; - let expected_tokens = expected_view.shape()[1]; - let start = total_tokens.saturating_sub(expected_tokens); - let sliced = permuted.slice(s![start.., .., ..]); - let reshaped = sliced - .into_owned() - .to_shape(IxDyn(&[1, expected_tokens, permuted.shape()[1], permuted.shape()[2]])) - .expect("Failed to reshape KV cache slice") - .to_owned(); - (expected_view.to_owned(), reshaped) - }, - Some(ArrayTransform::SsmConvState) => { - let produced_shape = produced_view.shape(); - let history_len = produced_shape[1]; - let dim = produced_shape[0]; - - let permuted = expected_view.permuted_axes(IxDyn(&[0, 2, 1])); - let total_time = permuted.shape()[2]; - let start = total_time.saturating_sub(history_len); - let sliced = permuted.slice(s![.., .., start..]); - - let reshaped_expected = sliced - .into_owned() - .to_shape(IxDyn(&[dim, history_len])) - .expect("Failed to reshape SSM conv state slice") - .to_owned(); - - (reshaped_expected, produced_view.to_owned()) - }, - None => (expected_view.to_owned(), produced_view.to_owned()), - }; - - let expected_shape = expected_data.shape().to_vec(); - let produced_shape = produced_data.shape().to_vec(); - - if expected_shape != produced_shape { - if expected_shape.len() == produced_shape.len() + 1 - && expected_shape.get(0) == Some(&1) - && expected_shape[1..] == produced_shape[..] - { - expected_data = - expected_data.to_shape(IxDyn(&produced_shape)).expect("Failed to reshape expected data").to_owned(); - } else if produced_shape.len() == expected_shape.len() + 1 - && produced_shape.get(0) == Some(&1) - && produced_shape[1..] == expected_shape[..] - { - produced_data = - produced_data.to_shape(IxDyn(&expected_shape)).expect("Failed to reshape produced data").to_owned(); - } - } - - if expected_data.shape() != produced_data.shape() { - panic!( - "Shape mismatch after alignment: expected {:?}, produced {:?}", - expected_data.shape(), - produced_data.shape() - ); + fn tensor_from_array_with_shape( + name: impl Into, + array: &Array, + shape: Box<[usize]>, + ) -> SafeTensorData { + SafeTensorData { + name: name.into(), + shape, + data_type: array.data_type(), + data: allocation_as_bytes(array.allocation()).into(), } - - let reference: Vec = expected_data.iter().map(|value| NumCast::from(*value).unwrap_or(0.0)).collect(); - let result: Vec = produced_data.iter().map(|value| NumCast::from(*value).unwrap_or(0.0)).collect(); - - let (atol, rtol, allowed_voilations_tol) = match expected_array.data_type() { - DataType::BF16 => (0.04, 0.06, 0.03), - _ => (0.01, 0.03, 0.01), - }; - - Self::compare_arrays( - &reference, - expected_data.shape().to_vec(), - &result, - produced_data.shape().to_vec(), - atol, - rtol, - allowed_voilations_tol, - ) } - fn compare_arrays( - reference: &[f32], - reference_shape: Vec, - result: &[f32], - result_shape: Vec, - atol: f32, - rtol: f32, - fraction_of_allowed_violations: f32, - ) -> TracerValidationMetrics { - assert_eq!(result.len(), reference.len()); - if reference.is_empty() { - return TracerValidationMetrics { - atol, - rtol, - fraction_of_allowed_violations, - reference_shape, - result_shape, - num_violations: 0, - max_allowed_violations: 0, - max_err_idx: 0, - max_err: 0.0, - max_err_rel: 0.0, - max_err_reference_value: 0.0, - rms_diff: 0.0, - rms_result: 0.0, - rms_reference: 0.0, - rel_rms_reference: 0.0, - diff_max: 0.0, - diff_avg: 0.0, - result_nan: false, - }; + fn push_activation_trace_tensors( + tensors: &mut Vec, + traces: &ActivationTrace, + layer_configs: &[TransformerLayerConfig], + export_shape: ExportShape, + ) { + if let Some(embedding_norm) = &traces.embedding_norm { + Self::push_array(tensors, "activation_trace.embedding_norm", embedding_norm, export_shape); } - let mut num_violations = 0; - let mut max_err = 0.0f32; - let mut max_err_idx = 0; - let mut max_err_rel = 0.0f32; - let mut max_err_reference_value = 0.0f32; - let mut sum_sq_diff = 0.0f32; - let mut sum_sq_result = 0.0f32; - let mut sum_sq_reference = 0.0f32; - let mut diff_sum = 0.0f32; - let mut diff_max = 0.0f32; - let mut result_nan = false; - - for (i, (&exp, &prod)) in reference.iter().zip(result.iter()).enumerate() { - if prod.is_nan() { - result_nan = true; - } - - let abs_diff = (exp - prod).abs(); - let rel_diff = if exp.abs() > 1e-8 { - abs_diff / exp.abs() - } else { - abs_diff - }; - - diff_sum += abs_diff; - diff_max = diff_max.max(abs_diff); - sum_sq_diff += abs_diff * abs_diff; - sum_sq_result += prod * prod; - sum_sq_reference += exp * exp; - - if abs_diff > atol && rel_diff > rtol { - num_violations += 1; + for (index, layer_traces) in traces.layer_results.iter().enumerate() { + let layer_config = &layer_configs[index]; + let path = |suffix: &str| format!("activation_trace.layer_results.{index}.activation_trace.{suffix}"); + + Self::push_array(tensors, path("inputs"), &layer_traces.inputs, export_shape); + Self::push_array(tensors, path("pre_mixer_norm"), &layer_traces.pre_attention_norm, export_shape); + Self::push_array(tensors, path("mixer"), &layer_traces.attention, export_shape); + if layer_config.post_mixer_norm_config.is_some() { + Self::push_array(tensors, path("post_mixer_norm"), &layer_traces.post_attention_norm, export_shape); } - - if abs_diff > max_err { - max_err = abs_diff; - max_err_idx = i; - max_err_rel = rel_diff; - max_err_reference_value = exp; + Self::push_array(tensors, path("mlp_inputs"), &layer_traces.mlp_inputs, export_shape); + Self::push_array(tensors, path("pre_mlp_norm"), &layer_traces.pre_mlp_norm, export_shape); + Self::push_array(tensors, path("mlp"), &layer_traces.mlp, export_shape); + if layer_config.post_mlp_norm_config.is_some() { + Self::push_array(tensors, path("post_mlp_norm"), &layer_traces.post_mlp_norm, export_shape); } + Self::push_array( + tensors, + format!("activation_trace.layer_results.{index}.outputs"), + &layer_traces.outputs, + export_shape, + ); } - let n = reference.len() as f32; - let rms_diff = (sum_sq_diff / n).sqrt(); - let rms_result = (sum_sq_result / n).sqrt(); - let rms_reference = (sum_sq_reference / n).sqrt(); - let rel_rms_reference = if rms_reference > 1e-8 { - rms_diff / rms_reference - } else { - rms_diff - }; - let diff_avg = diff_sum / n; - - let max_allowed_violations = (fraction_of_allowed_violations * n).ceil() as usize; - - TracerValidationMetrics { - atol, - rtol, - fraction_of_allowed_violations, - reference_shape, - result_shape, - num_violations, - max_allowed_violations, - max_err_idx, - max_err, - max_err_rel, - max_err_reference_value, - rms_diff, - rms_result, - rms_reference, - rel_rms_reference, - diff_max, - diff_avg, - result_nan, + Self::push_array(tensors, "activation_trace.output_norm", &traces.output_norm, export_shape); + if let Some(output_pooling) = &traces.output_pooling { + Self::push_array(tensors, "activation_trace.output_pooling", output_pooling, export_shape); } + Self::push_array(tensors, "logits", &traces.logits, export_shape); } - // ======================================================================== - // Utility Functions - // ======================================================================== - - fn load_array_as_vec( - traces_view: &ParameterTree, - name: &str, - ) -> Vec { - let array = traces_view.leaf_array(name).unwrap(); - let slice = array.as_slice::(); - slice.iter().map(|x| NumCast::from(*x).unwrap()).collect() + fn write_trace_file( + output_path: &Path, + tensors: &[SafeTensorData], + ) -> Result<(), Error> { + let mut file = File::create_new(output_path).map_err(|_| Error::UnableToWriteTrace)?; + write_safetensors(&mut file, tensors).map_err(|_| Error::UnableToWriteTrace) } - fn determine_prefill_step_size(model_path: &Path) -> usize { + fn determine_prefill_step_size(model_path: &Path) -> Result { let traces_path = model_path.join("traces.safetensors"); - if let Ok(file) = File::open(&traces_path) { - if let Ok((_header_len, metadata)) = read_safetensors_metadata(&file) { - if let Some(tensor) = metadata.tensors.get("activation_trace.token_ids") { - if let Some(&length) = tensor.shape.first() { - return tensor.shape.iter().copied().max().unwrap_or(length).max(1); - } - } - } + let file = File::open(&traces_path).map_err(|_| Error::UnableToLoadWeights)?; + let (_header_len, metadata) = read_safetensors_metadata(&file).map_err(|_| Error::UnableToLoadWeights)?; + let tensor = metadata.tensors.get("activation_trace.token_ids").ok_or(Error::UnableToLoadWeights)?; + let &[batch, suffix_length] = tensor.shape.as_slice() else { + return Err(Error::UnableToLoadWeights); + }; + if batch != 1 || suffix_length == 0 { + return Err(Error::UnableToLoadWeights); } - 1 + Ok(suffix_length) } fn ensure_llm_context_capacity( @@ -965,19 +392,4 @@ impl TraceValidator { context.gpu_sampler = Sampling::new(context.context.as_ref(), intermediate_dtype).expect("Failed to create sampling kernel"); } - - fn get_tokens_from_logits(logits: &Array) -> Vec { - let data_type = logits.data_type(); - match data_type { - DataType::F16 => Self::get_tokens_from_logits_of_type::(logits), - DataType::BF16 => Self::get_tokens_from_logits_of_type::(logits), - DataType::F32 => Self::get_tokens_from_logits_of_type::(logits), - _ => panic!("Unsupported data type: {:?}", data_type), - } - } - - fn get_tokens_from_logits_of_type(logits: &Array) -> Vec { - let sampler = ArgmaxSampler {}; - sampler.sample(logits.as_view::()) - } } diff --git a/crates/backend-uzu/tests/integration/tracer/trace_validator_test.rs b/crates/backend-uzu/tests/integration/tracer/trace_validator_test.rs index 6983c8d2e..02e63c669 100644 --- a/crates/backend-uzu/tests/integration/tracer/trace_validator_test.rs +++ b/crates/backend-uzu/tests/integration/tracer/trace_validator_test.rs @@ -1,4 +1,6 @@ -use backend_uzu::backends::common::Backend; +use std::{fs::File, path::PathBuf}; + +use backend_uzu::{backends::common::Backend, read_safetensors_metadata}; use crate::{ common::{ @@ -11,24 +13,22 @@ use crate::{ fn test_tracer_internal() { let model_path = get_test_model_path(); let mut tracer = TraceValidator::::new(&model_path).expect("Failed to create TraceValidator"); - let results = tracer.run().expect("Failed to run tracer"); - for result in results.results.iter() { - // this layers contains too many errors - if result.name == "activation_trace.output_norm" || result.name == "logits" { - continue; - } - assert!(result.metrics.is_valid(), "{} error:\n{}", result.name, result.metrics.message().as_str()); - } + let (export_path, _temp_file) = match std::env::var_os("UZU_TRACE_EXPORT_PATH") { + Some(path) => (PathBuf::from(path), None), + None => { + let directory = tempfile::TempDir::new().expect("create exported trace directory"); + (directory.path().join("uzu-traces.safetensors"), Some(directory)) + }, + }; + tracer.export_trace(&export_path).expect("export uzu trace"); - let total_token_violations = results.number_of_tokens_violations(); - let allowed_token_violations = results.number_of_allowed_tokens_violations(); - assert!( - total_token_violations < allowed_token_violations, - "Too much token violations: {} / {}. Indices: {:?}", - total_token_violations, - allowed_token_violations, - results.tokens_violation_indices - ); + let file = File::open(&export_path).expect("open exported trace"); + let (_offset, metadata) = read_safetensors_metadata(&file).expect("read exported trace metadata"); + assert!(metadata.tensors.contains_key("activation_trace.token_ids")); + assert!(metadata.tensors.contains_key("activation_trace.token_positions")); + assert!(metadata.tensors.contains_key("activation_trace.output_norm")); + assert!(metadata.tensors.contains_key("logits")); + assert!(metadata.tensors.keys().any(|path| path.ends_with(".activation_trace.inputs"))); } #[test]