Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/backend-uzu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ rstest.workspace = true
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = { workspace = true, default-features = true }
proptest = { workspace = true, default-features = true }
tempfile.workspace = true

[build-dependencies]
anyhow.workspace = true
Expand Down
4 changes: 2 additions & 2 deletions crates/backend-uzu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ pub use audio::{NanoCodecFsqRuntime, NanoCodecFsqRuntimeConfig};
pub use backends::common::{AllocationAccessError, allocation_copy_from_slice, allocation_to_vec};
pub use data_type::{ArrayElement, DataType};
pub use language_model::gumbel::{gumbel_float, revidx};
pub use parameters::{ParameterLoader, read_safetensors_metadata};
pub use parameters::{ParameterLoader, SafeTensorData, read_safetensors_metadata, write_safetensors};
pub use utils::{TOOLCHAIN_VERSION, VERSION};

#[cfg(feature = "tracing")]
pub mod _private {
pub use crate::{
classifier::Classifier,
config::{ModelConfig, ModelMetadata, ModelType},
config::{ModelConfig, ModelMetadata, ModelType, TransformerLayerConfig},
encodable_block::{DecoderDecodeInput, Sampling},
forward_pass::{
cache_layers::CacheLayers, kv_cache_layer::KVCacheLayer, token_inputs::TokenInputs, traces::ActivationTrace,
Expand Down
4 changes: 3 additions & 1 deletion crates/backend-uzu/src/parameters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ mod safetensors_metadata;
// Re-export the safetensors header reader so other modules (e.g. decoder
// runner) can estimate parameter memory before creating a Context.
pub use loader::{ParameterLeaf, ParameterLoader, ParameterLoaderError, ParameterTree};
pub use safetensors_metadata::{HeaderLoadingError, read_metadata as read_safetensors_metadata};
pub use safetensors_metadata::{
HeaderLoadingError, SafeTensorData, read_metadata as read_safetensors_metadata, write_safetensors,
};
92 changes: 81 additions & 11 deletions crates/backend-uzu/src/parameters/safetensors_metadata.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// This code is based on the safetensors implementation: https://docs.rs/safetensors/latest/src/safetensors/tensor.rs.html

use std::{collections::HashMap, fs::File};
use std::{
collections::HashMap,
fs::File,
io::{self, Write},
};

use serde::{Deserialize, Serialize};
use thiserror::Error;
Expand Down Expand Up @@ -95,17 +99,21 @@ impl From<Dtype> for DataType {
}
}

impl From<DataType> for Dtype {
fn from(dtype: DataType) -> Self {
impl TryFrom<DataType> for Dtype {
type Error = DataType;

fn try_from(dtype: DataType) -> Result<Self, Self::Error> {
match dtype {
DataType::F16 => Dtype::F16,
DataType::BF16 => Dtype::BF16,
DataType::F32 => Dtype::F32,
DataType::I8 => Dtype::I8,
DataType::I32 => Dtype::I32,
DataType::I64 => Dtype::I64,
DataType::U64 => Dtype::U64,
_ => panic!("Unsupported dtype: {:?}", dtype),
DataType::F16 => Ok(Dtype::F16),
DataType::BF16 => Ok(Dtype::BF16),
DataType::F32 => Ok(Dtype::F32),
DataType::I8 => Ok(Dtype::I8),
DataType::U8 => Ok(Dtype::U8),
DataType::I32 => Ok(Dtype::I32),
DataType::U32 => Ok(Dtype::U32),
DataType::I64 => Ok(Dtype::I64),
DataType::U64 => Ok(Dtype::U64),
DataType::F64 | DataType::I4 | DataType::U4 | DataType::I16 | DataType::U16 => Err(dtype),
}
}
}
Expand All @@ -129,3 +137,65 @@ pub fn read_metadata(file: &File) -> Result<(usize, HashMetadata), HeaderLoading
serde_json::from_str(string).map_err(|_| HeaderLoadingError::InvalidHeaderDeserialization)?;
Ok((stop, metadata))
}

/// A single named tensor, held as raw bytes, ready to be passed to [`write_safetensors`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SafeTensorData {
/// Key under which the tensor is recorded in the safetensors JSON header.
pub name: String,
/// Tensor dimensions. `write_safetensors` rejects the tensor unless the product of the
/// dimensions times the element size of `data_type` equals `data.len()`.
pub shape: Box<[usize]>,
/// Element type. `write_safetensors` rejects types that have no `Dtype` counterpart.
pub data_type: DataType,
/// Raw tensor payload; `write_safetensors` writes these bytes unchanged after the header.
pub data: Box<[u8]>,
}

/// Serializes `tensors` to `writer` in the safetensors container layout.
///
/// The output consists of an 8-byte little-endian header length, a JSON header that maps
/// every tensor name to its dtype, shape and byte range, and finally the raw tensor
/// payloads. Payloads are ordered by descending element size and then by name. The header
/// is padded with spaces so that its length is a multiple of the payload alignment
/// (the largest element size, but at least 8 bytes).
///
/// Every tensor is validated before the first byte is written, so a rejected input does
/// not produce partial output.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidInput`] when:
/// - `tensors` is empty,
/// - two tensors share the same name (the header can hold only one entry per name),
/// - a tensor's data type has no safetensors `Dtype` counterpart,
/// - a tensor's shape does not match the length of its data,
/// - a tensor's byte size or the running payload offset overflows `usize`.
///
/// Header serialization failures and errors reported by `writer` are propagated.
pub fn write_safetensors<W: Write>(
    writer: &mut W,
    tensors: &[SafeTensorData],
) -> Result<(), io::Error> {
    let mut sorted_tensors: Vec<&SafeTensorData> = tensors.iter().collect();
    sorted_tensors.sort_by(|left, right| {
        right.data_type.size_in_bytes().cmp(&left.data_type.size_in_bytes()).then_with(|| left.name.cmp(&right.name))
    });
    let Some(first_tensor) = sorted_tensors.first() else {
        return Err(io::Error::new(io::ErrorKind::InvalidInput, "safetensors requires at least one tensor"));
    };
    // After sorting, the first tensor has the widest element type.
    let data_alignment = first_tensor.data_type.size_in_bytes().max(8);

    let mut header = serde_json::Map::new();
    let mut offset = 0usize;
    for tensor in sorted_tensors.iter() {
        let dtype = Dtype::try_from(tensor.data_type)
            .map_err(|dtype| io::Error::new(io::ErrorKind::InvalidInput, format!("unsupported dtype: {dtype:?}")))?;
        let expected_len = tensor
            .shape
            .iter()
            .try_fold(tensor.data_type.size_in_bytes(), |size, dim| size.checked_mul(*dim))
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "safetensors tensor byte size overflows usize")
            })?;
        if expected_len != tensor.data.len() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "safetensors tensor shape does not match data"));
        }
        let end = offset.checked_add(tensor.data.len()).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "safetensors payload size overflows usize")
        })?;
        let info = serde_json::to_value(TensorInfo {
            dtype,
            shape: tensor.shape.to_vec(),
            data_offsets: (offset, end),
        })?;
        // A repeated name would replace the earlier header entry while both payloads are
        // still written, leaving unreferenced bytes in the file. Reject it up front.
        if header.insert(tensor.name.clone(), info).is_some() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("duplicate safetensors tensor name: {}", tensor.name),
            ));
        }
        offset = end;
    }

    let mut header_bytes = serde_json::to_vec(&header)?;
    // Trailing spaces keep the JSON valid while rounding the header up to the alignment.
    let padding = (data_alignment - header_bytes.len() % data_alignment) % data_alignment;
    header_bytes.resize(header_bytes.len() + padding, b' ');

    let header_len = u64::try_from(header_bytes.len())
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "safetensors header length does not fit into u64"))?;
    writer.write_all(&header_len.to_le_bytes())?;
    writer.write_all(&header_bytes)?;
    for tensor in sorted_tensors {
        writer.write_all(&tensor.data)?;
    }
    Ok(())
}
2 changes: 2 additions & 0 deletions crates/backend-uzu/src/session/types/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ pub enum Error {
InvalidTtsRunConfig(#[from] TtsRunConfigError),
#[error("Unable to load model weights")]
UnableToLoadWeights,
#[error("Unable to write trace")]
UnableToWriteTrace,
#[error("Unable to load tokenizer")]
UnableToLoadTokenizer,
#[error("Model is too large to fit into available RAM")]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,73 @@
use std::fs::File;

use backend_uzu::read_safetensors_metadata;
use backend_uzu::{
ArrayContextExt, DataType, ParameterLoader, SafeTensorData,
backends::{
common::{Backend, Context, allocation_as_bytes},
cpu::Cpu,
},
read_safetensors_metadata, write_safetensors,
};

use crate::common::path::get_test_weights_path;

/// Builds a [`SafeTensorData`] called `name` from the shape, data type and allocation
/// bytes of `array`.
fn tensor_from_array<B: Backend>(
    name: &str,
    array: &backend_uzu::Array<B>,
) -> SafeTensorData {
    let name = name.to_string();
    let shape = array.shape().into();
    let data_type = array.data_type();
    let data = allocation_as_bytes(array.allocation()).into();
    SafeTensorData {
        name,
        shape,
        data_type,
        data,
    }
}

/// Writing an empty tensor list must fail with `InvalidInput`.
#[test]
fn test_safetensors_metadata_writer_rejects_empty_tensors() {
    let mut sink: Vec<u8> = Vec::new();
    let no_tensors: &[SafeTensorData] = &[];

    let outcome = write_safetensors(&mut sink, no_tensors);

    let failure = outcome.expect_err("empty safetensors should fail");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidInput);
}

/// A tensor whose byte length disagrees with its shape must fail with `InvalidInput`.
#[test]
fn test_safetensors_metadata_writer_rejects_shape_data_mismatch() {
    // Two f32 elements need 8 bytes, but only 4 are supplied.
    let mismatched = SafeTensorData {
        name: "bad".to_string(),
        shape: [2].into(),
        data_type: DataType::F32,
        data: vec![0u8; 4].into_boxed_slice(),
    };
    let mut sink: Vec<u8> = Vec::new();

    let outcome = write_safetensors(&mut sink, &[mismatched]);

    let failure = outcome.expect_err("shape mismatch should fail");
    assert_eq!(failure.kind(), std::io::ErrorKind::InvalidInput);
}

/// Reading the header of the test weights file must yield at least one tensor entry.
#[test]
fn test_metadata_loading() {
    let weights_path = get_test_weights_path();
    let weights_file = File::open(&weights_path).expect("weights not found");

    let (_data_start, header) = read_safetensors_metadata(&weights_file).expect("read metadata");

    let tensor_count = header.tensors.len();
    assert!(tensor_count > 0);
}

/// Arrays written with `write_safetensors` must load back through `ParameterLoader` with
/// identical contents.
#[test]
fn test_safetensors_metadata_writer_roundtrips_arrays() {
    let context = <Cpu as Backend>::Context::new().expect("create CPU context");
    let floats = context.create_array_from(&[2, 2], &[1.0f32, 2.0, 3.0, 4.0]);
    let ints = context.create_array_from(&[2], &[7i32, 9]);
    let mut file = tempfile::NamedTempFile::new().expect("create safetensors file");

    // Write both arrays into the temporary file.
    let tensors = [tensor_from_array("floats", &floats), tensor_from_array("ints", &ints)];
    let written = write_safetensors(file.as_file_mut(), &tensors);
    written.expect("write safetensors file");

    // Load them back through a fresh handle and compare element by element.
    let loader_file = file.reopen().expect("open safetensors file");
    let loader = ParameterLoader::new(&loader_file, context.as_ref()).expect("create loader");
    assert_eq!(loader.tree().leaf_array("floats").unwrap().as_slice::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
    assert_eq!(loader.tree().leaf_array("ints").unwrap().as_slice::<i32>(), &[7, 9]);
}
Loading