diff --git a/go.mod b/go.mod index ad0edfb2..7560d264 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( golang.org/x/sync v0.19.0 golang.org/x/sys v0.41.0 google.golang.org/grpc v1.80.0 + google.golang.org/protobuf v1.36.11 oras.land/oras-go/v2 v2.6.0 ) @@ -164,7 +165,6 @@ require ( google.golang.org/api v0.214.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect - google.golang.org/protobuf v1.36.11 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index efa4494e..c98f776c 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -140,6 +140,10 @@ var ( "*.ftz", // FastText compressed model "*.ark", // Kaldi ark format (speech/audio models) "*.db", // Database files (LMDB, etc.) + + // TensorFlow SavedModel literal-name files (no extension). + "feature_map", // TF SavedModel feature map definition + "checkpoint", // TF checkpoint pointer file (literal name) } // Code file patterns - supported script and notebook files. diff --git a/pkg/modelfile/constants_test.go b/pkg/modelfile/constants_test.go index 77fa50e6..e94d4621 100644 --- a/pkg/modelfile/constants_test.go +++ b/pkg/modelfile/constants_test.go @@ -121,6 +121,15 @@ func TestInferFileType(t *testing.T) { {"at threshold", "borderline", WeightFileSizeThreshold, FileTypeCode}, // Just above threshold should be model {"above threshold", "borderline", WeightFileSizeThreshold + 1, FileTypeModel}, + + // TF SavedModel literal-name files: must be MODEL even when 0 bytes, + // independent of the size heuristic that would otherwise classify them as CODE. 
+ {"feature_map literal", "feature_map", 0, FileTypeModel}, + {"feature_map small", "feature_map", 1024, FileTypeModel}, + {"checkpoint literal small", "checkpoint", 32, FileTypeModel}, + // Negative: the literal patterns must not match same-stem-different-extension files. + {"feature_map.json is config", "feature_map.json", 1024, FileTypeConfig}, + {"checkpoint.bin is model via *.bin", "checkpoint.bin", 1024, FileTypeModel}, } assert := assert.New(t) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index 8c715d02..1054ee18 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -218,6 +218,12 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC } mf.generateByConfig(config) + + // Best-effort: fill mf.format from MODEL file evidence when the user did not + // pass --format. Failure (no recognizable signal, panic in the loop, etc.) + // MUST NOT abort generation — Format is metadata, not load-bearing. + mf.inferFormat() + return mf, nil } @@ -346,6 +352,12 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) return err } + // ONNX external_data post-processing: any tensor file referenced by an .onnx + // file via external_data.location is unconditionally a model weight file, + // regardless of its name or size. Walker may have classified small external + // tensor files as code/config/doc by extension/size heuristic; reclassify them. + mf.reclassifyONNXExternalData() + if mf.model.Size() == 0 && mf.code.Size() == 0 && mf.dataset.Size() == 0 { return fmt.Errorf("no model/code/dataset found - you have to create the Modelfile by yourself") } @@ -353,6 +365,145 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) return nil } +// reclassifyONNXExternalData scans every .onnx file already in mf.model, +// extracts external_data.location paths, and moves those paths from whichever +// bucket the walker placed them in into mf.model. 
+// +// To avoid bypassing the walker's filtering (ExcludePatterns, isSkippable, file +// count / size limits, workspace boundary), this function ONLY reclassifies +// paths that are already present in one of the existing hashsets (config / +// code / doc / model). Paths that the walker excluded — including paths +// outside the workspace produced by a malformed `../` location — are silently +// ignored. ONNX parse failures degrade gracefully: a WARNING is printed and +// the affected .onnx's external tensors keep whatever classification the +// walker assigned (the pre-fix behavior). +func (mf *modelfile) reclassifyONNXExternalData() { + walkerCollected := func(rel string) bool { + return mf.model.Contains(rel) || mf.code.Contains(rel) || + mf.config.Contains(rel) || mf.doc.Contains(rel) + } + + for _, raw := range mf.model.Values() { + modelRel, ok := raw.(string) + if !ok || !strings.HasSuffix(strings.ToLower(modelRel), ".onnx") { + continue + } + onnxAbs := filepath.Join(mf.workspace, modelRel) + extPaths, err := ExtractONNXExternalDataPaths(onnxAbs) + if err != nil { + fmt.Fprintf(os.Stderr, + "WARNING: modelfile: failed to parse ONNX external_data from %s: %v "+ + "-- external tensor files (if any) will keep walker-assigned classification\n", + modelRel, err) + continue + } + onnxDir := filepath.Dir(modelRel) + for _, ext := range extPaths { + // Reject absolute external_data.location values outright. ONNX + // spec defines location as relative to the .onnx file's + // directory, so an absolute path is malformed; worse, + // filepath.Join silently strips the leading separator + // (Join(".", "/etc/secret") -> "etc/secret"), which would let + // an unrelated workspace file get reclassified to MODEL. + if filepath.IsAbs(ext) { + continue + } + relExt := filepath.Clean(filepath.Join(onnxDir, ext)) + // Walker membership check absorbs all of: + // - exclude pattern (walker dropped it -> not in any bucket) + // - skippable directories (.git, etc.) 
+ // - file count / size limits (walker errored before adding) + // - workspace boundary (walker never sees ../outside paths) + // - file simply doesn't exist on disk + if !walkerCollected(relExt) { + continue + } + mf.code.Remove(relExt) + mf.config.Remove(relExt) + mf.doc.Remove(relExt) + mf.model.Add(relExt) + } + } +} + +// inferFormat fills mf.format from filename evidence collected by the walker +// when the user did not pass --format on the CLI. It only emits a value for +// highly specific signals (saved_model.pb[txt] / *.onnx / *.gguf / +// *.safetensors); generic extensions like *.bin / *.pt are left alone because +// they appear in many formats and would produce false positives. +// +// Priority order, when multiple signals coexist: +// +// 1. tensorflow — saved_model.pb / saved_model.pbtxt (SavedModel directory) +// 2. onnx — *.onnx +// 3. gguf — *.gguf +// 4. safetensors — *.safetensors +// +// SavedModel and ONNX are listed first because their layouts are uniquely +// identifiable; safetensors is last because it sometimes coexists with raw +// PyTorch shards in HF repos. +// +// We scan ALL four walker buckets (model / config / code / doc), not just +// mf.model. Reason: signals like `saved_model.pbtxt` are not in +// ModelFilePatterns and the walker therefore lands them in code/doc; if we +// scanned only mf.model, a SavedModel that ships only the .pbtxt variant would +// silently fall through. A set-based scan over every bucket closes that gap +// without changing how the walker classifies each individual file. +// +// Failure modes (no recognized signal, panic from a malformed value in the +// hashset, etc.) MUST NOT abort generation. The recover() guard ensures any +// unexpected panic degrades to "format stays empty" rather than killing the +// whole modelfile build. Format is best-effort metadata; the package gracefully +// handles a blank Format throughout the build/push/pull pipeline. 
+func (mf *modelfile) inferFormat() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, + "WARNING: modelfile: format inference panicked, leaving Format empty: %v\n", r) + } + }() + + if mf.format != "" { + return + } + + var hasSavedModel, hasONNX, hasGGUF, hasSafetensors bool + scan := func(set *hashset.Set) { + for _, raw := range set.Values() { + rel, ok := raw.(string) + if !ok { + continue + } + base := strings.ToLower(filepath.Base(rel)) + switch { + case base == "saved_model.pb" || base == "saved_model.pbtxt": + hasSavedModel = true + case strings.HasSuffix(base, ".onnx"): + hasONNX = true + case strings.HasSuffix(base, ".gguf"): + hasGGUF = true + case strings.HasSuffix(base, ".safetensors"): + hasSafetensors = true + } + } + } + scan(mf.model) + scan(mf.config) + scan(mf.code) + scan(mf.doc) + + switch { + case hasSavedModel: + mf.format = "tensorflow" + case hasONNX: + mf.format = "onnx" + case hasGGUF: + mf.format = "gguf" + case hasSafetensors: + mf.format = "safetensors" + } +} + // generateByModelConfig generates the modelfile by the model config, such as config.json and generation_config.json. func (mf *modelfile) generateByModelConfig() error { // Get config map from json files. Collect all the keys and values from the config files diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index 46c05870..b65d5a13 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" "sort" @@ -388,6 +389,9 @@ func TestNewModelfileByWorkspace(t *testing.T) { "assets/images/preview.jpg", }, expectName: "nested-model", + // inferFormat picks safetensors on seeing weights/model.safetensors; + // model.bin is too generic to be a signal. 
+ expectFormat: "safetensors", }, { name: "deep nested directories", @@ -621,6 +625,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { expectName: "llama-7b", expectArch: "transformer", expectFamily: "llama", + expectFormat: "safetensors", expectPrecision: "bfloat16", expectParamsize: "7B", }, @@ -811,6 +816,116 @@ func TestNewModelfileByWorkspace(t *testing.T) { expectCodes: []string{}, expectName: "selective-include", }, + { + // Mirrors screenshot 1 (TF SavedModel directory) plus the screenshots 4/5 + // /home/admin//{checkpoint,variables,...} layout: variables/ shards + // and a sibling checkpoint/ directory with model.ckpt-NNN.{data-*,index,meta}. + // Pre-fix, feature_map and the literal `checkpoint` files would have been + // classified as CODE because they have no extension and are small. + name: "tensorflow saved_model directory", + setupFiles: map[string]string{ + "saved_model.pb": "", + "feature_map": "", + "tf_signature.txt": "", + "alps.meta": "", + "ais-sync.done": "", + "variables/checkpoint": "", + "variables/variables.index": "", + "variables/variables.data-00000-of-00002": "", + "variables/variables.data-00001-of-00002": "", + "checkpoint/checkpoint": "", + "checkpoint/model.ckpt-756921.index": "", + "checkpoint/model.ckpt-756921.meta": "", + "checkpoint/model.ckpt-756921.data-00000-of-00002": "", + "checkpoint/model.ckpt-756921.data-00001-of-00002": "", + }, + config: &configmodelfile.GenerateConfig{ + Name: "tf-savedmodel", + }, + expectError: false, + // *.meta hits ConfigFilePatterns first (alps.meta, model.ckpt-*.meta). 
+ expectConfigs: []string{ + "alps.meta", + "checkpoint/model.ckpt-756921.meta", + }, + expectModels: []string{ + "saved_model.pb", + "feature_map", + "variables/checkpoint", + "variables/variables.index", + "variables/variables.data-00000-of-00002", + "variables/variables.data-00001-of-00002", + "checkpoint/checkpoint", + "checkpoint/model.ckpt-756921.index", + "checkpoint/model.ckpt-756921.data-00000-of-00002", + "checkpoint/model.ckpt-756921.data-00001-of-00002", + }, + // ais-sync.done has no extension and is empty -> falls back to CODE by + // the size heuristic. Documented intentional behavior (user accepted). + expectCodes: []string{"ais-sync.done"}, + expectDocs: []string{"tf_signature.txt"}, + expectName: "tf-savedmodel", + // saved_model.pb in MODEL set is the canonical SavedModel signal. + expectFormat: "tensorflow", + }, + { + // Mirrors screenshot 3/4/5: a parent directory containing base// + // and delta// subdirs, each holding a complete TF SavedModel + + // checkpoint layout. Verifies the walker recurses correctly and that + // relative paths in the resulting Modelfile preserve the timestamp + // directory prefix so pull-side reconstruction restores the full tree. 
+ name: "online-learning base plus delta tree", + setupFiles: map[string]string{ + "base/20251202030708/saved_model.pb": "", + "base/20251202030708/feature_map": "", + "base/20251202030708/tf_signature.txt": "", + "base/20251202030708/alps.meta": "", + "base/20251202030708/ais-sync.done": "", + "base/20251202030708/variables/checkpoint": "", + "base/20251202030708/variables/variables.index": "", + "base/20251202030708/variables/variables.data-00000-of-00002": "", + "delta/20251202060205/saved_model.pb": "", + "delta/20251202060205/feature_map": "", + "delta/20251202060205/tf_signature.txt": "", + "delta/20251202060205/alps.meta": "", + "delta/20251202060205/ais-sync.done": "", + "delta/20251202060205/variables/checkpoint": "", + "delta/20251202060205/variables/variables.index": "", + "delta/20251202060205/variables/variables.data-00000-of-00002": "", + }, + config: &configmodelfile.GenerateConfig{ + Name: "online-learning", + }, + expectError: false, + expectConfigs: []string{ + "base/20251202030708/alps.meta", + "delta/20251202060205/alps.meta", + }, + expectModels: []string{ + "base/20251202030708/saved_model.pb", + "base/20251202030708/feature_map", + "base/20251202030708/variables/checkpoint", + "base/20251202030708/variables/variables.index", + "base/20251202030708/variables/variables.data-00000-of-00002", + "delta/20251202060205/saved_model.pb", + "delta/20251202060205/feature_map", + "delta/20251202060205/variables/checkpoint", + "delta/20251202060205/variables/variables.index", + "delta/20251202060205/variables/variables.data-00000-of-00002", + }, + expectCodes: []string{ + "base/20251202030708/ais-sync.done", + "delta/20251202060205/ais-sync.done", + }, + expectDocs: []string{ + "base/20251202030708/tf_signature.txt", + "delta/20251202060205/tf_signature.txt", + }, + expectName: "online-learning", + // Multiple saved_model.pb instances under base/ and delta/ all signal + // tensorflow; inferFormat is set-based, so duplicates do not matter. 
+ expectFormat: "tensorflow", + }, } assert := assert.New(t) @@ -2125,3 +2240,315 @@ func min(a, b int64) int64 { } return b } + +// TestNewModelfileByWorkspace_ONNXExternalData mirrors screenshot 2: an ONNX +// directory where model.onnx references many extension-less tensor files via +// external_data. Pre-fix, small external tensor files were classified as CODE +// by the size heuristic; the ONNX post-processor must reclassify them all as +// MODEL deterministically — regardless of the file's size. +func TestNewModelfileByWorkspace_ONNXExternalData(t *testing.T) { + // Subset of the 36 extension-less files seen in screenshot 2. Names taken + // verbatim from the actual workflow_10062365 directory. + externalLocations := []string{ + "tower_deep_layer_0_kernel_read__448_1", + "tower_shallow_layer_0_kernel_read__440_0", + "moe_layer_layer_0_kernel__399_15", + "moe_layer_layer_0_domain_bn_ExpandDims__400_16", + "moe_layer_search_layer_0_kernel__413_3", + "moe_layer_non_search_layer_0_kernel__427_9", + "feature_gate_main_kernel_read__352_27", + "feature_gate_main_domain_bn_ExpandDims__354_28", + "gated_dcn_layer_0_transform_kernel__385_21", + "conversion_layer_0_kernel_read__364_35", + "main_domain_embedding_Mark_output_user_main_domain_kernel_read__270_20", + "external_data_for_resource_handle", + } + + tempDir, err := os.MkdirTemp("", "modelfile-onnx-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Write the ONNX file with external_data references encoded into protobuf bytes. + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes(externalLocations), + 0o644, + )) + + // Create each external tensor file as an empty placeholder. + for _, name := range externalLocations { + require.NoError(t, os.WriteFile(filepath.Join(tempDir, name), nil, 0o644)) + } + + // Sibling files from the screenshot. 
+ require.NoError(t, os.WriteFile(filepath.Join(tempDir, "tf_signature.txt"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "variable_sequence_def_0.onnx"), buildMinimalONNXBytes(nil), 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-external-data", + Workspace: tempDir, + }) + require.NoError(t, err) + + models := mf.GetModels() + codes := mf.GetCodes() + configs := mf.GetConfigs() + + // Every external_data location must appear in MODEL, regardless of file size. + for _, name := range externalLocations { + assert.Contains(t, models, name, "external_data tensor %q must be reclassified as MODEL", name) + assert.NotContains(t, codes, name, "external_data tensor %q must NOT be in CODE", name) + assert.NotContains(t, configs, name, "external_data tensor %q must NOT be in CONFIG", name) + } + // Both .onnx files themselves are MODELs by extension. + assert.Contains(t, models, "model.onnx") + assert.Contains(t, models, "variable_sequence_def_0.onnx") + // tf_signature.txt is a doc per default DocFilePatterns (*.txt). + assert.Contains(t, mf.GetDocs(), "tf_signature.txt") + // inferFormat: any *.onnx in MODEL emits format=onnx. + assert.Equal(t, "onnx", mf.GetFormat()) +} + +// P1 coverage: an external_data.location of "../escape" must NOT promote +// anything outside the workspace. The walker never collected such a path, so +// it must be silently dropped from reclassification. +func TestNewModelfileByWorkspace_ONNXExternalDataPathTraversalIgnored(t *testing.T) { + tempDir := t.TempDir() + + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"../escape", "weights.bin"}), + 0o644, + )) + // weights.bin is a real, walker-collected file — should reclassify normally. 
+ require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-traversal", + Workspace: tempDir, + }) + require.NoError(t, err) + + models := mf.GetModels() + assert.Contains(t, models, "weights.bin", "in-workspace external tensor must be reclassified") + for _, m := range models { + assert.NotContains(t, m, "..", "no model entry should escape the workspace") + assert.NotEqual(t, "../escape", m) + assert.NotEqual(t, "escape", m) + } +} + +// P1 coverage: ExcludePatterns must not be silently overridden by ONNX +// reclassification. If the walker excluded a tensor file, it must remain +// excluded — even if model.onnx references it via external_data. +func TestNewModelfileByWorkspace_ONNXExternalDataRespectsExclude(t *testing.T) { + tempDir := t.TempDir() + + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"weights_keep.bin", "weights_drop.bin"}), + 0o644, + )) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights_keep.bin"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights_drop.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-exclude", + Workspace: tempDir, + ExcludePatterns: []string{"weights_drop.bin"}, + }) + require.NoError(t, err) + + models := mf.GetModels() + assert.Contains(t, models, "weights_keep.bin") + assert.NotContains(t, models, "weights_drop.bin", + "excluded path must not be re-added by ONNX reclassification") + // And it must not have leaked into any other bucket either. + all := append(append(append([]string{}, mf.GetCodes()...), mf.GetConfigs()...), mf.GetDocs()...) + assert.NotContains(t, all, "weights_drop.bin") +} + +// captureStderr redirects os.Stderr through a pipe while fn runs and returns +// everything written to it. 
Standard Go pattern for asserting on stderr text +// without coupling tests to a specific log library. +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + r, w, err := os.Pipe() + require.NoError(t, err) + old := os.Stderr + os.Stderr = w + defer func() { os.Stderr = old }() + + done := make(chan string, 1) + go func() { + var buf strings.Builder + _, _ = io.Copy(&buf, r) + done <- buf.String() + }() + + fn() + require.NoError(t, w.Close()) + return <-done +} + +// P3 coverage: a corrupted .onnx file must not crash generate. The .onnx +// itself stays in MODEL (it matches *.onnx), any sibling tensor files keep +// whatever classification the walker already gave them, AND a clearly-prefixed +// WARNING is emitted to stderr — locking that user-visible contract. +func TestNewModelfileByWorkspace_ONNXParseFailureFallsBack(t *testing.T) { + tempDir := t.TempDir() + + // Corrupted bytes that look like ONNX but aren't valid wire format. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "model.onnx"), []byte{0xff, 0xff, 0xff, 0xff}, 0o644)) + // A small extension-less file the walker would classify as CODE — stays CODE on parse failure. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "auxiliary_blob"), nil, 0o644)) + + var ( + mf Modelfile + err error + ) + stderr := captureStderr(t, func() { + mf, err = NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-corrupt", + Workspace: tempDir, + }) + }) + + require.NoError(t, err, "corrupted ONNX must not abort generate") + assert.Contains(t, mf.GetModels(), "model.onnx", ".onnx file itself stays in MODEL by extension") + assert.NotEmpty(t, mf.GetModels()) + + // Lock the WARNING contract: prefix + offending file path + fallback note. 
+ assert.Contains(t, stderr, "WARNING:", "warning prefix must be printed") + assert.Contains(t, stderr, "model.onnx", "warning must name the failing file") + assert.Contains(t, stderr, "keep walker-assigned classification", + "warning must explain the fallback so users know external tensors were not reclassified") +} + +// P1 follow-up: an absolute external_data.location must be rejected outright, +// even when the absolute path's stripped form happens to match a real +// workspace file. Previously filepath.Join(".", "/decoy") would produce +// "decoy" and reclassify an unrelated file. +func TestNewModelfileByWorkspace_ONNXExternalDataAbsolutePathRejected(t *testing.T) { + tempDir := t.TempDir() + + // Locations: one absolute (must be rejected), one normal (must reclassify). + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"/decoy", "weights.bin"}), + 0o644, + )) + // "decoy" exists in the workspace at root — would be the misclassification + // target if Join silently strips the leading slash from "/decoy". + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "decoy"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-abs", + Workspace: tempDir, + }) + require.NoError(t, err) + + // weights.bin reclassifies normally. + assert.Contains(t, mf.GetModels(), "weights.bin") + // "decoy" must NOT be in MODEL — the abs path "/decoy" was rejected before + // Join could produce a misleading "decoy" relative path. 
+ assert.NotContains(t, mf.GetModels(), "decoy", + "absolute external_data.location must not promote unrelated workspace files") +} + +// TestNewModelfileByWorkspace_InferFormatCLIFlagWins verifies that an explicit +// --format on the CLI is never overridden by inferFormat, even when the +// workspace contains a clear signal (saved_model.pb) for a different format. +// Inference is best-effort metadata fill and must defer to user intent. +func TestNewModelfileByWorkspace_InferFormatCLIFlagWins(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pb"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "tf-but-custom", + Workspace: tempDir, + Format: "custom-format", + }) + require.NoError(t, err) + assert.Equal(t, "custom-format", mf.GetFormat(), + "explicit --format must override inferFormat's tensorflow guess") +} + +// TestNewModelfileByWorkspace_InferFormatNoSignal verifies that an unambiguous +// no-signal directory (e.g. a generic *.bin with no recognizable extension) +// leaves Format empty. We deliberately do NOT infer "pytorch" from .bin/.pt, +// because those extensions appear in many formats and would produce false +// positives on existing HF / TF / ONNX repos that also ship raw binary blobs. 
+func TestNewModelfileByWorkspace_InferFormatNoSignal(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "model.bin"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.pt"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "README.md"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "no-signal", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "", mf.GetFormat(), + "generic .bin/.pt files alone must not trigger format inference") +} + +// TestNewModelfileByWorkspace_InferFormatPriority verifies the priority order: +// saved_model.pb wins over a sibling *.safetensors, because the SavedModel +// signal is uniquely diagnostic while *.safetensors can coexist with raw +// PyTorch shards in HF repos. +func TestNewModelfileByWorkspace_InferFormatPriority(t *testing.T) { + tempDir := t.TempDir() + // SavedModel layout + an unrelated safetensors blob in the same dir. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pb"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "extra.safetensors"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "priority-mix", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "tensorflow", mf.GetFormat(), + "saved_model.pb must outrank a sibling .safetensors") +} + +// TestNewModelfileByWorkspace_InferFormatGGUF locks the .gguf signal — used by +// llama.cpp and adjacent ecosystems. .gguf is a self-contained quantized model +// container and is safe to map to format=gguf on filename alone. 
+func TestNewModelfileByWorkspace_InferFormatGGUF(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "ggml-model.gguf"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "gguf-only", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "gguf", mf.GetFormat()) +} + +// TestNewModelfileByWorkspace_InferFormatSavedModelPbtxt is regression coverage +// for a Codex finding: `saved_model.pbtxt` is not in ModelFilePatterns, so the +// walker lands it in CODE (or DOC). An earlier inferFormat that scanned only +// mf.model would silently miss the textual SavedModel variant. The fixed +// implementation scans all walker buckets, so a workspace whose only TF signal +// is a sibling .pbtxt still resolves to format=tensorflow. +func TestNewModelfileByWorkspace_InferFormatSavedModelPbtxt(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pbtxt"), nil, 0o644)) + // A real model file so generateByWorkspace doesn't error on empty MODEL set. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.safetensors"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "tf-textual", + Workspace: tempDir, + }) + require.NoError(t, err) + + // SavedModel signal must outrank the sibling .safetensors per the + // documented priority order, even though the .pbtxt is not in MODEL. + assert.Equal(t, "tensorflow", mf.GetFormat(), + "saved_model.pbtxt in any walker bucket must trigger format=tensorflow") +} diff --git a/pkg/modelfile/onnx.go b/pkg/modelfile/onnx.go new file mode 100644 index 00000000..266a8c11 --- /dev/null +++ b/pkg/modelfile/onnx.go @@ -0,0 +1,250 @@ +/* + * Copyright 2025 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package modelfile + +import ( + "fmt" + "os" + + "google.golang.org/protobuf/encoding/protowire" +) + +// ONNX protobuf field numbers (subset relevant to external_data discovery). +// Schema reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto +const ( + onnxModelProtoGraphField = 7 + onnxModelProtoTrainingInfoField = 20 + onnxTrainingInfoInitializationField = 1 + onnxTrainingInfoAlgorithmField = 2 + onnxGraphProtoNodeField = 1 + onnxGraphProtoInitializerField = 5 + onnxGraphProtoSparseInitializerField = 15 + onnxNodeProtoAttributeField = 5 + // AttributeProto carries either a tensor (t / tensors / sparse_tensor / + // sparse_tensors) or a subgraph (g / graphs). Field numbers per onnx.proto. + onnxAttributeProtoTensorField = 5 + onnxAttributeProtoGraphField = 6 + onnxAttributeProtoTensorsField = 10 + onnxAttributeProtoGraphsField = 11 + onnxAttributeProtoSparseTensorField = 22 + onnxAttributeProtoSparseTensorsField = 23 + onnxTensorProtoExternalDataField = 13 + onnxSparseTensorProtoValuesField = 1 + onnxSparseTensorProtoIndicesField = 2 + onnxStringStringEntryProtoKeyField = 1 + onnxStringStringEntryProtoValueField = 2 + onnxExternalDataLocationKey = "location" + onnxMaxParseSize int = 512 * 1024 * 1024 + // Defends against pathological / adversarial ONNX with deeply nested If/Loop + // subgraphs. Real models rarely nest beyond 2-3 levels. + onnxMaxSubgraphDepth = 32 +) + +// ExtractONNXExternalDataPaths parses the ONNX model file at onnxPath and returns +// the relative paths of all external_data tensor files referenced by it. 
//
// Paths are returned exactly as recorded in the ONNX file (relative to the
// directory containing the .onnx file, per ONNX convention). The slice is
// deduplicated and ordered by first appearance.
//
// An ONNX without any external_data references returns nil, nil. The function
// returns an error for: I/O failures (stat / read), files exceeding
// onnxMaxParseSize, and malformed protobuf wire data (e.g., truncated or
// corrupted bytes). Callers that want best-effort behavior should treat all
// errors as a fallback signal and continue without external_data information.
//
// NOTE(review): the "location" value is collected regardless of the owning
// TensorProto's data_location field (which this walker never reads); in
// practice external_data only appears on EXTERNAL tensors — confirm against
// producer tooling if stricter filtering is ever needed.
func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) {
	// Size gate before the full read below, so oversized files are rejected
	// without allocating their content. (Benign TOCTOU: the file could grow
	// between Stat and ReadFile; the cap is a guardrail, not a hard limit.)
	info, err := os.Stat(onnxPath)
	if err != nil {
		return nil, fmt.Errorf("stat onnx file: %w", err)
	}
	if info.Size() > int64(onnxMaxParseSize) {
		return nil, fmt.Errorf("onnx file %s exceeds parse size cap (%d bytes)", onnxPath, onnxMaxParseSize)
	}

	data, err := os.ReadFile(onnxPath)
	if err != nil {
		return nil, fmt.Errorf("read onnx file: %w", err)
	}

	// seen deduplicates; locations preserves first-appearance order, which is
	// the documented return-order contract.
	var (
		seen      = map[string]struct{}{}
		locations []string
	)
	collect := func(loc string) {
		// Empty locations carry no information; skip them.
		if loc == "" {
			return
		}
		if _, dup := seen[loc]; dup {
			return
		}
		seen[loc] = struct{}{}
		locations = append(locations, loc)
	}

	// Iterate ModelProto fields directly so we cover both the inference graph
	// (field 7) and any training_info entries (field 20, repeated). Training-
	// enabled ONNX (IR v7+) carries additional GraphProtos in
	// training_info[*].initialization and training_info[*].algorithm; their
	// initializers can also carry external_data references.
	err = forEachField(data, func(num protowire.Number, _ protowire.Type, value []byte) error {
		switch num {
		case onnxModelProtoGraphField:
			return walkGraph(value, collect, 0)
		case onnxModelProtoTrainingInfoField:
			return walkTrainingInfo(value, collect)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("walk ONNX model: %w", err)
	}
	return locations, nil
}

// walkTrainingInfo descends into a TrainingInfoProto and walks its
// initialization and algorithm GraphProtos. The other fields
// (initialization_binding, update_binding) cannot carry external_data.
func walkTrainingInfo(ti []byte, collect func(string)) error {
	return forEachField(ti, func(num protowire.Number, _ protowire.Type, value []byte) error {
		switch num {
		case onnxTrainingInfoInitializationField, onnxTrainingInfoAlgorithmField:
			// Training graphs restart at depth 0: nesting is counted per
			// graph tree, not across independent top-level graphs.
			return walkGraph(value, collect, 0)
		}
		return nil
	})
}

// walkGraph iterates a GraphProto: top-level initializer / sparse_initializer
// entries plus every NodeProto.attribute, recursing into subgraphs (If / Loop /
// Scan branches) up to onnxMaxSubgraphDepth levels deep.
func walkGraph(graph []byte, collect func(string), depth int) error {
	// Depth cap reached: silently stop descending rather than erroring, so a
	// pathological file degrades to partial results instead of a failure.
	if depth > onnxMaxSubgraphDepth {
		return nil
	}
	return forEachField(graph, func(num protowire.Number, _ protowire.Type, value []byte) error {
		switch num {
		case onnxGraphProtoInitializerField:
			return walkTensorExternalData(value, collect)
		case onnxGraphProtoSparseInitializerField:
			return walkSparseTensor(value, collect, depth)
		case onnxGraphProtoNodeField:
			return walkNode(value, collect, depth)
		}
		return nil
	})
}

// walkNode descends into NodeProto.attribute entries to surface external_data
// references attached via Constant / If / Loop / Scan and similar ops.
+func walkNode(node []byte, collect func(string), depth int) error { + return forEachField(node, func(num protowire.Number, _ protowire.Type, value []byte) error { + if num != onnxNodeProtoAttributeField { + return nil + } + return walkAttribute(value, collect, depth) + }) +} + +// walkAttribute handles the tensor- and subgraph-bearing fields of +// AttributeProto. Other attribute types (floats / ints / strings / type_protos) +// are skipped — none can carry external_data. +func walkAttribute(attr []byte, collect func(string), depth int) error { + return forEachField(attr, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxAttributeProtoTensorField, onnxAttributeProtoTensorsField: + return walkTensorExternalData(value, collect) + case onnxAttributeProtoSparseTensorField, onnxAttributeProtoSparseTensorsField: + return walkSparseTensor(value, collect, depth) + case onnxAttributeProtoGraphField, onnxAttributeProtoGraphsField: + return walkGraph(value, collect, depth+1) + } + return nil + }) +} + +// walkSparseTensor descends into a SparseTensorProto and walks its values + indices TensorProto. +func walkSparseTensor(sparse []byte, collect func(string), _ int) error { + return forEachField(sparse, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxSparseTensorProtoValuesField, onnxSparseTensorProtoIndicesField: + return walkTensorExternalData(value, collect) + } + return nil + }) +} + +// walkTensorExternalData scans a TensorProto for external_data StringStringEntryProto entries. 
+func walkTensorExternalData(tensor []byte, collect func(string)) error { + return forEachField(tensor, func(num protowire.Number, _ protowire.Type, value []byte) error { + if num != onnxTensorProtoExternalDataField { + return nil + } + var key, val string + if err := forEachField(value, func(n protowire.Number, _ protowire.Type, v []byte) error { + switch n { + case onnxStringStringEntryProtoKeyField: + key = string(v) + case onnxStringStringEntryProtoValueField: + val = string(v) + } + return nil + }); err != nil { + return err + } + if key == onnxExternalDataLocationKey { + collect(val) + } + return nil + }) +} + +// forEachField iterates protobuf wire-format fields in buf, invoking fn for each. +// For length-delimited fields, value is the inner payload bytes (header stripped). +// For varint/fixed32/fixed64/group fields, value is nil and the caller can rely on +// the field number/type to decide whether to consume further. Unknown/skipped +// fields advance the cursor without error. 
+func forEachField(buf []byte, fn func(num protowire.Number, typ protowire.Type, value []byte) error) error { + for len(buf) > 0 { + num, typ, n := protowire.ConsumeTag(buf) + if err := protowire.ParseError(n); err != nil { + return err + } + buf = buf[n:] + + switch typ { + case protowire.BytesType: + value, m := protowire.ConsumeBytes(buf) + if err := protowire.ParseError(m); err != nil { + return err + } + if err := fn(num, typ, value); err != nil { + return err + } + buf = buf[m:] + default: + m := protowire.ConsumeFieldValue(num, typ, buf) + if err := protowire.ParseError(m); err != nil { + return err + } + if err := fn(num, typ, nil); err != nil { + return err + } + buf = buf[m:] + } + } + return nil +} diff --git a/pkg/modelfile/onnx_test.go b/pkg/modelfile/onnx_test.go new file mode 100644 index 00000000..70203fdb --- /dev/null +++ b/pkg/modelfile/onnx_test.go @@ -0,0 +1,245 @@ +/* + * Copyright 2025 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */

package modelfile

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/protobuf/encoding/protowire"
)

// buildMinimalONNXBytes constructs a wire-format byte sequence matching the
// minimal subtree of ONNX needed to exercise external_data extraction:
//
//	ModelProto {
//	  graph: GraphProto {
//	    initializer: TensorProto[i] {
//	      external_data: StringStringEntryProto[]{ key="location", value=<loc> }
//	    }
//	  }
//	}
//
// One initializer is emitted per externalLocation. Field numbers come from
// onnx.proto and are mirrored as constants in onnx.go.
func buildMinimalONNXBytes(externalLocations []string) []byte {
	return encodeBytesField(onnxModelProtoGraphField, buildONNXGraphInitializerBytes(externalLocations))
}

// buildONNXGraphInitializerBytes returns the inner GraphProto payload — used
// directly when wrapping into a subgraph attribute or wherever raw graph bytes
// are needed without the outer ModelProto.graph header.
func buildONNXGraphInitializerBytes(externalLocations []string) []byte {
	var graph []byte
	for _, loc := range externalLocations {
		// One initializer TensorProto per location, each carrying a single
		// external_data entry {key: "location", value: loc}.
		entry := encodeStringStringEntry(onnxExternalDataLocationKey, loc)
		tensor := encodeRepeated(onnxTensorProtoExternalDataField, entry)
		graph = append(graph, encodeBytesField(onnxGraphProtoInitializerField, tensor)...)
	}
	return graph
}

// buildONNXBytesWithConstantNodeAttribute constructs ONNX bytes where an
// external_data tensor is attached to a NodeProto.attribute.t (the shape
// produced by ONNX Constant ops). Pre-P2 the parser would have missed this.
func buildONNXBytesWithConstantNodeAttribute(location string) []byte {
	// Nesting: ModelProto.graph -> GraphProto.node -> NodeProto.attribute ->
	// AttributeProto.t -> TensorProto.external_data.
	entry := encodeStringStringEntry(onnxExternalDataLocationKey, location)
	tensor := encodeRepeated(onnxTensorProtoExternalDataField, entry)
	attr := encodeBytesField(onnxAttributeProtoTensorField, tensor)
	node := encodeBytesField(onnxNodeProtoAttributeField, attr)
	graph := encodeBytesField(onnxGraphProtoNodeField, node)
	return encodeBytesField(onnxModelProtoGraphField, graph)
}

// buildONNXBytesWithSubgraph nests a GraphProto inside a NodeProto.attribute.g
// (the shape produced by ONNX If / Loop / Scan ops). Pre-P2 the parser would
// have missed external_data references inside the subgraph.
func buildONNXBytesWithSubgraph(location string) []byte {
	inner := buildONNXGraphInitializerBytes([]string{location})
	attr := encodeBytesField(onnxAttributeProtoGraphField, inner)
	node := encodeBytesField(onnxNodeProtoAttributeField, attr)
	outer := encodeBytesField(onnxGraphProtoNodeField, node)
	return encodeBytesField(onnxModelProtoGraphField, outer)
}

// buildONNXBytesWithTrainingInfo emits a ModelProto whose only graph-bearing
// fields are training_info[*].initialization and training_info[*].algorithm.
// The inference graph is intentionally absent — exercises the post-P4
// behavior where training-only ONNX still yields its external tensors.
func buildONNXBytesWithTrainingInfo(initLoc, algoLoc string) []byte {
	initGraph := buildONNXGraphInitializerBytes([]string{initLoc})
	algoGraph := buildONNXGraphInitializerBytes([]string{algoLoc})

	var ti []byte
	ti = append(ti, encodeBytesField(onnxTrainingInfoInitializationField, initGraph)...)
	ti = append(ti, encodeBytesField(onnxTrainingInfoAlgorithmField, algoGraph)...)

	return encodeBytesField(onnxModelProtoTrainingInfoField, ti)
}

// encodeStringStringEntry serializes one StringStringEntryProto (key, value).
func encodeStringStringEntry(key, value string) []byte {
	out := encodeBytesField(onnxStringStringEntryProtoKeyField, []byte(key))
	out = append(out, encodeBytesField(onnxStringStringEntryProtoValueField, []byte(value))...)
	return out
}

// encodeRepeated wraps a single repeated-message instance so it appears as one
// entry of a repeated bytes-typed field.
func encodeRepeated(fieldNum int, payload []byte) []byte {
	return encodeBytesField(fieldNum, payload)
}

// encodeBytesField emits tag + length-prefixed payload for one
// length-delimited field.
func encodeBytesField(fieldNum int, payload []byte) []byte {
	out := protowire.AppendTag(nil, protowire.Number(fieldNum), protowire.BytesType)
	out = protowire.AppendBytes(out, payload)
	return out
}

func TestExtractONNXExternalDataPaths_Empty(t *testing.T) {
	dir := t.TempDir()
	onnxPath := filepath.Join(dir, "model.onnx")
	require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes(nil), 0o644))

	paths, err := ExtractONNXExternalDataPaths(onnxPath)
	require.NoError(t, err)
	assert.Empty(t, paths)
}

func TestExtractONNXExternalDataPaths_SingleExternal(t *testing.T) {
	dir := t.TempDir()
	onnxPath := filepath.Join(dir, "model.onnx")
	require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes([]string{"weights.bin"}), 0o644))

	paths, err := ExtractONNXExternalDataPaths(onnxPath)
	require.NoError(t, err)
	assert.Equal(t, []string{"weights.bin"}, paths)
}

func TestExtractONNXExternalDataPaths_MultipleExternals(t *testing.T) {
	// Mimics a real-world ONNX layout: a model.onnx referencing many tensor
	// files with arbitrary, extension-less names.
	locations := []string{
		"tower_deep_layer_0_kernel_read__448_1",
		"tower_shallow_layer_0_kernel_read__440_0",
		"moe_layer_layer_0_kernel__399_15",
		"feature_gate_main_kernel_read__352_27",
		"external_data_for_resource_handle",
	}
	dir := t.TempDir()
	onnxPath := filepath.Join(dir, "model.onnx")
	require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes(locations), 0o644))

	paths, err := ExtractONNXExternalDataPaths(onnxPath)
	require.NoError(t, err)
	// Order matters: the contract is first-appearance order, not sorted.
	assert.Equal(t, locations, paths)
}

func TestExtractONNXExternalDataPaths_DeduplicatesRepeatedLocation(t *testing.T) {
	dir := t.TempDir()
	onnxPath := filepath.Join(dir, "model.onnx")
	require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes([]string{"a", "b", "a", "c", "b"}), 0o644))

	paths, err := ExtractONNXExternalDataPaths(onnxPath)
	require.NoError(t, err)
	assert.Equal(t, []string{"a", "b", "c"}, paths)
}

func TestExtractONNXExternalDataPaths_FileMissing(t *testing.T) {
	_, err := ExtractONNXExternalDataPaths(filepath.Join(t.TempDir(), "nope.onnx"))
	assert.Error(t, err)
}

func TestExtractONNXExternalDataPaths_NonONNXBytes(t *testing.T) {
	dir := t.TempDir()
	onnxPath := filepath.Join(dir, "model.onnx")
	require.NoError(t, os.WriteFile(onnxPath, []byte("not a protobuf"), 0o644))

	// Garbage bytes are not valid wire format; expect either no paths or a
	// parse error. Both are acceptable — the contract is "don't panic, don't
	// silently return wrong data."
	paths, err := ExtractONNXExternalDataPaths(onnxPath)
	if err == nil {
		assert.Empty(t, paths)
	}
}

// P2 coverage: external_data attached to a NodeProto.attribute.t (Constant op)
// must be discovered alongside top-level GraphProto.initializer entries.
+func TestExtractONNXExternalDataPaths_NodeAttributeTensor(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildONNXBytesWithConstantNodeAttribute("const_weights.bin"), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"const_weights.bin"}, paths) +} + +// P2 coverage: external_data inside a subgraph (If / Loop / Scan branch) must +// be discovered via attribute.g recursion. +func TestExtractONNXExternalDataPaths_Subgraph(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildONNXBytesWithSubgraph("branch_weights.bin"), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"branch_weights.bin"}, paths) +} + +// P4 coverage: training-enabled ONNX (IR v7+) carries graphs in +// ModelProto.training_info[*].initialization and .algorithm. external_data in +// either of these must be discovered alongside the inference graph. +func TestExtractONNXExternalDataPaths_TrainingInfoGraphs(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile( + onnxPath, + buildONNXBytesWithTrainingInfo("init_state.bin", "optimizer_state.bin"), + 0o644, + )) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"init_state.bin", "optimizer_state.bin"}, paths) +} + +// P4 coverage: a model with both an inference graph AND training_info must +// surface external_data from both locations. +func TestExtractONNXExternalDataPaths_InferenceAndTrainingCombined(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + + // Inference graph initializer. 
+ inferenceGraph := buildONNXGraphInitializerBytes([]string{"weights.bin"}) + // training_info with one initialization graph. + initGraph := buildONNXGraphInitializerBytes([]string{"init_state.bin"}) + ti := encodeBytesField(onnxTrainingInfoInitializationField, initGraph) + + // Compose ModelProto with both fields. + var model []byte + model = append(model, encodeBytesField(onnxModelProtoGraphField, inferenceGraph)...) + model = append(model, encodeBytesField(onnxModelProtoTrainingInfoField, ti)...) + require.NoError(t, os.WriteFile(onnxPath, model, 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"weights.bin", "init_state.bin"}, paths) +}