From 4f303daf59d7e893728bf94409409dd396553a3e Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Tue, 28 Apr 2026 20:31:38 +0800 Subject: [PATCH 1/6] feat(modelfile): support TF SavedModel, ONNX external_data, and online-learning trees Generator now correctly classifies extension-less TF SavedModel literals (feature_map, checkpoint) as MODEL via new ModelFilePatterns entries, and deterministically reclassifies all ONNX external_data tensor files as MODEL by parsing the .onnx protobuf with google.golang.org/protobuf/encoding/protowire. This removes the size-heuristic dependency that previously misclassified small external tensor files as CODE. Online-learning directories (a parent dir containing base/ and delta/ subtrees) are handled by existing recursive walk; new test cases mirror the real workflow_10062365 layout. No changes to build/push/pull/attach/upload/codec/storage paths --- those remain byte-level and model-type-agnostic. Signed-off-by: Zhao Chen --- go.mod | 2 +- pkg/modelfile/constants.go | 4 + pkg/modelfile/constants_test.go | 9 ++ pkg/modelfile/modelfile.go | 32 ++++++ pkg/modelfile/modelfile_test.go | 171 ++++++++++++++++++++++++++++ pkg/modelfile/onnx.go | 192 ++++++++++++++++++++++++++++++++ pkg/modelfile/onnx_test.go | 136 ++++++++++++++++++++++ 7 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 pkg/modelfile/onnx.go create mode 100644 pkg/modelfile/onnx_test.go diff --git a/go.mod b/go.mod index ad0edfb2..7560d264 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( golang.org/x/sync v0.19.0 golang.org/x/sys v0.41.0 google.golang.org/grpc v1.80.0 + google.golang.org/protobuf v1.36.11 oras.land/oras-go/v2 v2.6.0 ) @@ -164,7 +165,6 @@ require ( google.golang.org/api v0.214.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260226221140-a57be14db171 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect - google.golang.org/protobuf v1.36.11 // indirect gopkg.in/ini.v1 
v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index efa4494e..c98f776c 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -140,6 +140,10 @@ var ( "*.ftz", // FastText compressed model "*.ark", // Kaldi ark format (speech/audio models) "*.db", // Database files (LMDB, etc.) + + // TensorFlow SavedModel literal-name files (no extension). + "feature_map", // TF SavedModel feature map definition + "checkpoint", // TF checkpoint pointer file (literal name) } // Code file patterns - supported script and notebook files. diff --git a/pkg/modelfile/constants_test.go b/pkg/modelfile/constants_test.go index 77fa50e6..e94d4621 100644 --- a/pkg/modelfile/constants_test.go +++ b/pkg/modelfile/constants_test.go @@ -121,6 +121,15 @@ func TestInferFileType(t *testing.T) { {"at threshold", "borderline", WeightFileSizeThreshold, FileTypeCode}, // Just above threshold should be model {"above threshold", "borderline", WeightFileSizeThreshold + 1, FileTypeModel}, + + // TF SavedModel literal-name files: must be MODEL even when 0 bytes, + // independent of the size heuristic that would otherwise classify them as CODE. + {"feature_map literal", "feature_map", 0, FileTypeModel}, + {"feature_map small", "feature_map", 1024, FileTypeModel}, + {"checkpoint literal small", "checkpoint", 32, FileTypeModel}, + // Negative: the literal patterns must not match same-stem-different-extension files. 
+ {"feature_map.json is config", "feature_map.json", 1024, FileTypeConfig}, + {"checkpoint.bin is model via *.bin", "checkpoint.bin", 1024, FileTypeModel}, } assert := assert.New(t) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index 8c715d02..0d37492d 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -346,6 +346,12 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) return err } + // ONNX external_data post-processing: any tensor file referenced by an .onnx + // file via external_data.location is unconditionally a model weight file, + // regardless of its name or size. Walker may have classified small external + // tensor files as code/config/doc by extension/size heuristic; reclassify them. + mf.reclassifyONNXExternalData() + if mf.model.Size() == 0 && mf.code.Size() == 0 && mf.dataset.Size() == 0 { return fmt.Errorf("no model/code/dataset found - you have to create the Modelfile by yourself") } @@ -353,6 +359,32 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) return nil } +// reclassifyONNXExternalData scans every .onnx file already in mf.model, +// extracts external_data.location paths, and moves those paths into mf.model +// from whichever bucket the walker placed them in. 
+func (mf *modelfile) reclassifyONNXExternalData() { + for _, raw := range mf.model.Values() { + modelRel, ok := raw.(string) + if !ok || !strings.HasSuffix(strings.ToLower(modelRel), ".onnx") { + continue + } + onnxAbs := filepath.Join(mf.workspace, modelRel) + extPaths, err := ExtractONNXExternalDataPaths(onnxAbs) + if err != nil { + fmt.Fprintf(os.Stderr, "modelfile: parse ONNX external_data from %s failed: %v\n", modelRel, err) + continue + } + onnxDir := filepath.Dir(modelRel) + for _, ext := range extPaths { + relExt := filepath.Join(onnxDir, ext) + mf.code.Remove(relExt) + mf.config.Remove(relExt) + mf.doc.Remove(relExt) + mf.model.Add(relExt) + } + } +} + // generateByModelConfig generates the modelfile by the model config, such as config.json and generation_config.json. func (mf *modelfile) generateByModelConfig() error { // Get config map from json files. Collect all the keys and values from the config files diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index 46c05870..e2efb007 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -811,6 +811,111 @@ func TestNewModelfileByWorkspace(t *testing.T) { expectCodes: []string{}, expectName: "selective-include", }, + { + // Mirrors screenshot 1 (TF SavedModel directory) plus the screenshots 4/5 + // /home/admin//{checkpoint,variables,...} layout: variables/ shards + // and a sibling checkpoint/ directory with model.ckpt-NNN.{data-*,index,meta}. + // Pre-fix, feature_map and the literal `checkpoint` files would have been + // classified as CODE because they have no extension and are small. 
+ name: "tensorflow saved_model directory", + setupFiles: map[string]string{ + "saved_model.pb": "", + "feature_map": "", + "tf_signature.txt": "", + "alps.meta": "", + "ais-sync.done": "", + "variables/checkpoint": "", + "variables/variables.index": "", + "variables/variables.data-00000-of-00002": "", + "variables/variables.data-00001-of-00002": "", + "checkpoint/checkpoint": "", + "checkpoint/model.ckpt-756921.index": "", + "checkpoint/model.ckpt-756921.meta": "", + "checkpoint/model.ckpt-756921.data-00000-of-00002": "", + "checkpoint/model.ckpt-756921.data-00001-of-00002": "", + }, + config: &configmodelfile.GenerateConfig{ + Name: "tf-savedmodel", + }, + expectError: false, + // *.meta hits ConfigFilePatterns first (alps.meta, model.ckpt-*.meta). + expectConfigs: []string{ + "alps.meta", + "checkpoint/model.ckpt-756921.meta", + }, + expectModels: []string{ + "saved_model.pb", + "feature_map", + "variables/checkpoint", + "variables/variables.index", + "variables/variables.data-00000-of-00002", + "variables/variables.data-00001-of-00002", + "checkpoint/checkpoint", + "checkpoint/model.ckpt-756921.index", + "checkpoint/model.ckpt-756921.data-00000-of-00002", + "checkpoint/model.ckpt-756921.data-00001-of-00002", + }, + // ais-sync.done has no extension and is empty -> falls back to CODE by + // the size heuristic. Documented intentional behavior (user accepted). + expectCodes: []string{"ais-sync.done"}, + expectDocs: []string{"tf_signature.txt"}, + expectName: "tf-savedmodel", + }, + { + // Mirrors screenshot 3/4/5: a parent directory containing base// + // and delta// subdirs, each holding a complete TF SavedModel + + // checkpoint layout. Verifies the walker recurses correctly and that + // relative paths in the resulting Modelfile preserve the timestamp + // directory prefix so pull-side reconstruction restores the full tree. 
+ name: "online-learning base plus delta tree", + setupFiles: map[string]string{ + "base/20251202030708/saved_model.pb": "", + "base/20251202030708/feature_map": "", + "base/20251202030708/tf_signature.txt": "", + "base/20251202030708/alps.meta": "", + "base/20251202030708/ais-sync.done": "", + "base/20251202030708/variables/checkpoint": "", + "base/20251202030708/variables/variables.index": "", + "base/20251202030708/variables/variables.data-00000-of-00002": "", + "delta/20251202060205/saved_model.pb": "", + "delta/20251202060205/feature_map": "", + "delta/20251202060205/tf_signature.txt": "", + "delta/20251202060205/alps.meta": "", + "delta/20251202060205/ais-sync.done": "", + "delta/20251202060205/variables/checkpoint": "", + "delta/20251202060205/variables/variables.index": "", + "delta/20251202060205/variables/variables.data-00000-of-00002": "", + }, + config: &configmodelfile.GenerateConfig{ + Name: "online-learning", + }, + expectError: false, + expectConfigs: []string{ + "base/20251202030708/alps.meta", + "delta/20251202060205/alps.meta", + }, + expectModels: []string{ + "base/20251202030708/saved_model.pb", + "base/20251202030708/feature_map", + "base/20251202030708/variables/checkpoint", + "base/20251202030708/variables/variables.index", + "base/20251202030708/variables/variables.data-00000-of-00002", + "delta/20251202060205/saved_model.pb", + "delta/20251202060205/feature_map", + "delta/20251202060205/variables/checkpoint", + "delta/20251202060205/variables/variables.index", + "delta/20251202060205/variables/variables.data-00000-of-00002", + }, + expectCodes: []string{ + "base/20251202030708/ais-sync.done", + "delta/20251202060205/ais-sync.done", + }, + expectDocs: []string{ + "base/20251202030708/tf_signature.txt", + "delta/20251202060205/tf_signature.txt", + }, + expectName: "online-learning", + }, } assert := assert.New(t) @@ -2125,3 +2230,69 @@ func min(a, b int64) int64 { } return b } + +// TestNewModelfileByWorkspace_ONNXExternalData mirrors 
screenshot 2: an ONNX +// directory where model.onnx references many extension-less tensor files via +// external_data. Pre-fix, small external tensor files were classified as CODE +// by the size heuristic; the ONNX post-processor must reclassify them all as +// MODEL deterministically — regardless of the file's size. +func TestNewModelfileByWorkspace_ONNXExternalData(t *testing.T) { + // Subset of the 36 extension-less files seen in screenshot 2. Names taken + // verbatim from the actual workflow_10062365 directory. + externalLocations := []string{ + "tower_deep_layer_0_kernel_read__448_1", + "tower_shallow_layer_0_kernel_read__440_0", + "moe_layer_layer_0_kernel__399_15", + "moe_layer_layer_0_domain_bn_ExpandDims__400_16", + "moe_layer_search_layer_0_kernel__413_3", + "moe_layer_non_search_layer_0_kernel__427_9", + "feature_gate_main_kernel_read__352_27", + "feature_gate_main_domain_bn_ExpandDims__354_28", + "gated_dcn_layer_0_transform_kernel__385_21", + "conversion_layer_0_kernel_read__364_35", + "main_domain_embedding_Mark_output_user_main_domain_kernel_read__270_20", + "external_data_for_resource_handle", + } + + tempDir, err := os.MkdirTemp("", "modelfile-onnx-*") + require.NoError(t, err) + defer os.RemoveAll(tempDir) + + // Write the ONNX file with external_data references encoded into protobuf bytes. + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes(externalLocations), + 0o644, + )) + + // Create each external tensor file as an empty placeholder. + for _, name := range externalLocations { + require.NoError(t, os.WriteFile(filepath.Join(tempDir, name), nil, 0o644)) + } + + // Sibling files from the screenshot. 
+ require.NoError(t, os.WriteFile(filepath.Join(tempDir, "tf_signature.txt"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "variable_sequence_def_0.onnx"), buildMinimalONNXBytes(nil), 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-external-data", + Workspace: tempDir, + }) + require.NoError(t, err) + + models := mf.GetModels() + codes := mf.GetCodes() + configs := mf.GetConfigs() + + // Every external_data location must appear in MODEL, regardless of file size. + for _, name := range externalLocations { + assert.Contains(t, models, name, "external_data tensor %q must be reclassified as MODEL", name) + assert.NotContains(t, codes, name, "external_data tensor %q must NOT be in CODE", name) + assert.NotContains(t, configs, name, "external_data tensor %q must NOT be in CONFIG", name) + } + // Both .onnx files themselves are MODELs by extension. + assert.Contains(t, models, "model.onnx") + assert.Contains(t, models, "variable_sequence_def_0.onnx") + // tf_signature.txt is a doc per default DocFilePatterns (*.txt). + assert.Contains(t, mf.GetDocs(), "tf_signature.txt") +} diff --git a/pkg/modelfile/onnx.go b/pkg/modelfile/onnx.go new file mode 100644 index 00000000..513606e6 --- /dev/null +++ b/pkg/modelfile/onnx.go @@ -0,0 +1,192 @@ +/* + * Copyright 2025 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package modelfile + +import ( + "fmt" + "os" + + "google.golang.org/protobuf/encoding/protowire" +) + +// ONNX protobuf field numbers (subset relevant to external_data discovery). +// Schema reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto +const ( + onnxModelProtoGraphField = 7 + onnxGraphProtoInitializerField = 5 + onnxGraphProtoSparseInitializerField = 15 + onnxTensorProtoExternalDataField = 13 + onnxSparseTensorProtoValuesField = 1 + onnxSparseTensorProtoIndicesField = 2 + onnxStringStringEntryProtoKeyField = 1 + onnxStringStringEntryProtoValueField = 2 + onnxExternalDataLocationKey = "location" + onnxMaxParseSize int = 512 * 1024 * 1024 +) + +// ExtractONNXExternalDataPaths parses the ONNX model file at onnxPath and returns +// the relative paths of all external_data tensor files referenced by it. +// +// Paths are returned exactly as recorded in the ONNX file (relative to the +// directory containing the .onnx file, per ONNX convention). The slice is +// deduplicated and ordered by first appearance. +// +// Returns nil with no error for files that aren't ONNX or that don't reference +// any external data; only I/O failures produce errors. Files larger than +// onnxMaxParseSize are skipped with an error. 
+func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) { + info, err := os.Stat(onnxPath) + if err != nil { + return nil, fmt.Errorf("stat onnx file: %w", err) + } + if info.Size() > int64(onnxMaxParseSize) { + return nil, fmt.Errorf("onnx file %s exceeds parse size cap (%d bytes)", onnxPath, onnxMaxParseSize) + } + + data, err := os.ReadFile(onnxPath) + if err != nil { + return nil, fmt.Errorf("read onnx file: %w", err) + } + + graph, err := readSubMessage(data, onnxModelProtoGraphField) + if err != nil { + return nil, fmt.Errorf("locate ONNX graph field: %w", err) + } + if graph == nil { + return nil, nil + } + + var ( + seen = map[string]struct{}{} + locations []string + ) + collect := func(loc string) { + if loc == "" { + return + } + if _, dup := seen[loc]; dup { + return + } + seen[loc] = struct{}{} + locations = append(locations, loc) + } + + if err := walkGraph(graph, collect); err != nil { + return nil, fmt.Errorf("walk ONNX graph: %w", err) + } + return locations, nil +} + +// walkGraph iterates initializer + sparse_initializer entries inside a GraphProto. +func walkGraph(graph []byte, collect func(string)) error { + return forEachField(graph, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxGraphProtoInitializerField: + return walkTensorExternalData(value, collect) + case onnxGraphProtoSparseInitializerField: + return walkSparseTensor(value, collect) + } + return nil + }) +} + +// walkSparseTensor descends into a SparseTensorProto and walks its values + indices TensorProto. +func walkSparseTensor(sparse []byte, collect func(string)) error { + return forEachField(sparse, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxSparseTensorProtoValuesField, onnxSparseTensorProtoIndicesField: + return walkTensorExternalData(value, collect) + } + return nil + }) +} + +// walkTensorExternalData scans a TensorProto for external_data StringStringEntryProto entries. 
+func walkTensorExternalData(tensor []byte, collect func(string)) error { + return forEachField(tensor, func(num protowire.Number, _ protowire.Type, value []byte) error { + if num != onnxTensorProtoExternalDataField { + return nil + } + var key, val string + if err := forEachField(value, func(n protowire.Number, _ protowire.Type, v []byte) error { + switch n { + case onnxStringStringEntryProtoKeyField: + key = string(v) + case onnxStringStringEntryProtoValueField: + val = string(v) + } + return nil + }); err != nil { + return err + } + if key == onnxExternalDataLocationKey { + collect(val) + } + return nil + }) +} + +// readSubMessage scans the top-level message and returns the bytes of the first +// occurrence of the given length-delimited (wire type 2) field. Returns nil if +// the field is absent. +func readSubMessage(buf []byte, target protowire.Number) ([]byte, error) { + var found []byte + err := forEachField(buf, func(num protowire.Number, typ protowire.Type, value []byte) error { + if num == target && typ == protowire.BytesType && found == nil { + found = value + } + return nil + }) + return found, err +} + +// forEachField iterates protobuf wire-format fields in buf, invoking fn for each. +// For length-delimited fields, value is the inner payload bytes (header stripped). +// For varint/fixed32/fixed64/group fields, value is nil and the caller can rely on +// the field number/type to decide whether to consume further. Unknown/skipped +// fields advance the cursor without error. 
+func forEachField(buf []byte, fn func(num protowire.Number, typ protowire.Type, value []byte) error) error { + for len(buf) > 0 { + num, typ, n := protowire.ConsumeTag(buf) + if err := protowire.ParseError(n); err != nil { + return err + } + buf = buf[n:] + + switch typ { + case protowire.BytesType: + value, m := protowire.ConsumeBytes(buf) + if err := protowire.ParseError(m); err != nil { + return err + } + if err := fn(num, typ, value); err != nil { + return err + } + buf = buf[m:] + default: + m := protowire.ConsumeFieldValue(num, typ, buf) + if err := protowire.ParseError(m); err != nil { + return err + } + if err := fn(num, typ, nil); err != nil { + return err + } + buf = buf[m:] + } + } + return nil +} diff --git a/pkg/modelfile/onnx_test.go b/pkg/modelfile/onnx_test.go new file mode 100644 index 00000000..90638d2a --- /dev/null +++ b/pkg/modelfile/onnx_test.go @@ -0,0 +1,136 @@ +/* + * Copyright 2025 The ModelPack Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package modelfile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +// buildMinimalONNXBytes constructs a wire-format byte sequence matching the +// minimal subtree of ONNX needed to exercise external_data extraction: +// +// ModelProto { +// graph: GraphProto { +// initializer: TensorProto[i] { +// external_data: StringStringEntryProto[]{ key="location", value= } +// } +// } +// } +// +// One initializer is emitted per externalLocation. Field numbers come from +// onnx.proto and are mirrored as constants in onnx.go. +func buildMinimalONNXBytes(externalLocations []string) []byte { + var graph []byte + for _, loc := range externalLocations { + entry := encodeStringStringEntry(onnxExternalDataLocationKey, loc) + tensor := encodeRepeated(onnxTensorProtoExternalDataField, entry) + graph = append(graph, encodeBytesField(onnxGraphProtoInitializerField, tensor)...) + } + return encodeBytesField(onnxModelProtoGraphField, graph) +} + +func encodeStringStringEntry(key, value string) []byte { + out := encodeBytesField(onnxStringStringEntryProtoKeyField, []byte(key)) + out = append(out, encodeBytesField(onnxStringStringEntryProtoValueField, []byte(value))...) + return out +} + +// encodeRepeated wraps a single repeated-message instance so it appears as one +// entry of a repeated bytes-typed field. 
+func encodeRepeated(fieldNum int, payload []byte) []byte { + return encodeBytesField(fieldNum, payload) +} + +func encodeBytesField(fieldNum int, payload []byte) []byte { + out := protowire.AppendTag(nil, protowire.Number(fieldNum), protowire.BytesType) + out = protowire.AppendBytes(out, payload) + return out +} + +func TestExtractONNXExternalDataPaths_Empty(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes(nil), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Empty(t, paths) +} + +func TestExtractONNXExternalDataPaths_SingleExternal(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes([]string{"weights.bin"}), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"weights.bin"}, paths) +} + +func TestExtractONNXExternalDataPaths_MultipleExternals(t *testing.T) { + // Mimics the screenshot ONNX layout: a model.onnx referencing many tensor + // files with arbitrary, extension-less names. 
+ locations := []string{ + "tower_deep_layer_0_kernel_read__448_1", + "tower_shallow_layer_0_kernel_read__440_0", + "moe_layer_layer_0_kernel__399_15", + "feature_gate_main_kernel_read__352_27", + "external_data_for_resource_handle", + } + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes(locations), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, locations, paths) +} + +func TestExtractONNXExternalDataPaths_DeduplicatesRepeatedLocation(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildMinimalONNXBytes([]string{"a", "b", "a", "c", "b"}), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c"}, paths) +} + +func TestExtractONNXExternalDataPaths_FileMissing(t *testing.T) { + _, err := ExtractONNXExternalDataPaths(filepath.Join(t.TempDir(), "nope.onnx")) + assert.Error(t, err) +} + +func TestExtractONNXExternalDataPaths_NonONNXBytes(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, []byte("not a protobuf"), 0o644)) + + // Garbage bytes are not valid wire format; expect either no paths or a + // parse error. Both are acceptable — the contract is "don't panic, don't + // silently return wrong data." 
+ paths, err := ExtractONNXExternalDataPaths(onnxPath) + if err == nil { + assert.Empty(t, paths) + } +} From 660ae9b82fe6f5516fb9133b6f8714830e1e41dd Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Tue, 28 Apr 2026 20:51:33 +0800 Subject: [PATCH 2/6] fix(modelfile): address codex review on ONNX external_data handling Three follow-up fixes from independent review of PR #530: P1 (correctness/security): reclassifyONNXExternalData now only moves paths already collected by the workspace walker, naturally inheriting all of the walker's filtering: ExcludePatterns, isSkippable directories, file count / size limits, and the workspace boundary. Adversarial or malformed ONNX referencing `../escape` is silently dropped instead of being added to mf.model. P2 (coverage): the ONNX parser now recurses into NodeProto.attribute so external_data references attached via Constant ops (attribute.t / tensors / sparse_tensor / sparse_tensors) and inside If / Loop / Scan subgraphs (attribute.g / graphs) are discovered. Subgraph recursion is bounded at 32 levels to defend against pathological inputs. P3 (visibility): ONNX parse failures now print a clearly prefixed WARNING line that names the offending file and explains the fallback ("external tensor files will keep walker-assigned classification"). Generate does NOT abort -- a single bad .onnx degrades to the pre-fix walker classification without killing the whole pass. New tests: NodeAttributeTensor, Subgraph (P2 unit); PathTraversalIgnored, RespectsExclude, ParseFailureFallsBack (P1/P3 integration). go vet, go test -race ./pkg/modelfile/..., and go test ./... all clean. 
Signed-off-by: Zhao Chen --- pkg/modelfile/modelfile.go | 34 ++++++++++++-- pkg/modelfile/modelfile_test.go | 83 +++++++++++++++++++++++++++++++++ pkg/modelfile/onnx.go | 78 +++++++++++++++++++++++++------ pkg/modelfile/onnx_test.go | 54 +++++++++++++++++++++ 4 files changed, 230 insertions(+), 19 deletions(-) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index 0d37492d..68c1f862 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -360,9 +360,23 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) } // reclassifyONNXExternalData scans every .onnx file already in mf.model, -// extracts external_data.location paths, and moves those paths into mf.model -// from whichever bucket the walker placed them in. +// extracts external_data.location paths, and moves those paths from whichever +// bucket the walker placed them in into mf.model. +// +// To avoid bypassing the walker's filtering (ExcludePatterns, isSkippable, file +// count / size limits, workspace boundary), this function ONLY reclassifies +// paths that are already present in one of the existing hashsets (config / +// code / doc / model). Paths that the walker excluded — including paths +// outside the workspace produced by a malformed `../` location — are silently +// ignored. ONNX parse failures degrade gracefully: a WARNING is printed and +// the affected .onnx's external tensors keep whatever classification the +// walker assigned (the pre-fix behavior). 
func (mf *modelfile) reclassifyONNXExternalData() { + walkerCollected := func(rel string) bool { + return mf.model.Contains(rel) || mf.code.Contains(rel) || + mf.config.Contains(rel) || mf.doc.Contains(rel) + } + for _, raw := range mf.model.Values() { modelRel, ok := raw.(string) if !ok || !strings.HasSuffix(strings.ToLower(modelRel), ".onnx") { @@ -371,12 +385,24 @@ func (mf *modelfile) reclassifyONNXExternalData() { onnxAbs := filepath.Join(mf.workspace, modelRel) extPaths, err := ExtractONNXExternalDataPaths(onnxAbs) if err != nil { - fmt.Fprintf(os.Stderr, "modelfile: parse ONNX external_data from %s failed: %v\n", modelRel, err) + fmt.Fprintf(os.Stderr, + "WARNING: modelfile: failed to parse ONNX external_data from %s: %v "+ + "-- external tensor files (if any) will keep walker-assigned classification\n", + modelRel, err) continue } onnxDir := filepath.Dir(modelRel) for _, ext := range extPaths { - relExt := filepath.Join(onnxDir, ext) + relExt := filepath.Clean(filepath.Join(onnxDir, ext)) + // Walker membership check absorbs all of: + // - exclude pattern (walker dropped it -> not in any bucket) + // - skippable directories (.git, etc.) + // - file count / size limits (walker errored before adding) + // - workspace boundary (walker never sees ../outside paths) + // - file simply doesn't exist on disk + if !walkerCollected(relExt) { + continue + } mf.code.Remove(relExt) mf.config.Remove(relExt) mf.doc.Remove(relExt) diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index e2efb007..31083a31 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -2296,3 +2296,86 @@ func TestNewModelfileByWorkspace_ONNXExternalData(t *testing.T) { // tf_signature.txt is a doc per default DocFilePatterns (*.txt). assert.Contains(t, mf.GetDocs(), "tf_signature.txt") } + +// P1 coverage: an external_data.location of "../escape" must NOT promote +// anything outside the workspace. 
The walker never collected such a path, so +// it must be silently dropped from reclassification. +func TestNewModelfileByWorkspace_ONNXExternalDataPathTraversalIgnored(t *testing.T) { + tempDir := t.TempDir() + + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"../escape", "weights.bin"}), + 0o644, + )) + // weights.bin is a real, walker-collected file — should reclassify normally. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-traversal", + Workspace: tempDir, + }) + require.NoError(t, err) + + models := mf.GetModels() + assert.Contains(t, models, "weights.bin", "in-workspace external tensor must be reclassified") + for _, m := range models { + assert.NotContains(t, m, "..", "no model entry should escape the workspace") + assert.NotEqual(t, "../escape", m) + assert.NotEqual(t, "escape", m) + } +} + +// P1 coverage: ExcludePatterns must not be silently overridden by ONNX +// reclassification. If the walker excluded a tensor file, it must remain +// excluded — even if model.onnx references it via external_data. 
+func TestNewModelfileByWorkspace_ONNXExternalDataRespectsExclude(t *testing.T) { + tempDir := t.TempDir() + + require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"weights_keep.bin", "weights_drop.bin"}), + 0o644, + )) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights_keep.bin"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights_drop.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-exclude", + Workspace: tempDir, + ExcludePatterns: []string{"weights_drop.bin"}, + }) + require.NoError(t, err) + + models := mf.GetModels() + assert.Contains(t, models, "weights_keep.bin") + assert.NotContains(t, models, "weights_drop.bin", + "excluded path must not be re-added by ONNX reclassification") + // And it must not have leaked into any other bucket either. + all := append(append(append([]string{}, mf.GetCodes()...), mf.GetConfigs()...), mf.GetDocs()...) + assert.NotContains(t, all, "weights_drop.bin") +} + +// P3 coverage: a corrupted .onnx file must not crash generate. The .onnx +// itself stays in MODEL (it matches *.onnx), and any sibling tensor files +// keep whatever classification the walker already gave them — i.e., the +// pre-fix behavior. A WARNING is emitted to stderr (not asserted here, but +// captured via the parser's error path). +func TestNewModelfileByWorkspace_ONNXParseFailureFallsBack(t *testing.T) { + tempDir := t.TempDir() + + // Corrupted bytes that look like ONNX but aren't valid wire format. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "model.onnx"), []byte{0xff, 0xff, 0xff, 0xff}, 0o644)) + // A small extension-less file the walker would classify as CODE — stays CODE on parse failure. 
+ require.NoError(t, os.WriteFile(filepath.Join(tempDir, "auxiliary_blob"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-corrupt", + Workspace: tempDir, + }) + require.NoError(t, err, "corrupted ONNX must not abort generate") + assert.Contains(t, mf.GetModels(), "model.onnx", ".onnx file itself stays in MODEL by extension") + // auxiliary_blob keeps walker classification (CODE for small extension-less); + // we don't assert exact bucket — only that generate succeeded and produced output. + assert.NotEmpty(t, mf.GetModels()) +} diff --git a/pkg/modelfile/onnx.go b/pkg/modelfile/onnx.go index 513606e6..064ef91d 100644 --- a/pkg/modelfile/onnx.go +++ b/pkg/modelfile/onnx.go @@ -26,16 +26,29 @@ import ( // ONNX protobuf field numbers (subset relevant to external_data discovery). // Schema reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto const ( - onnxModelProtoGraphField = 7 - onnxGraphProtoInitializerField = 5 - onnxGraphProtoSparseInitializerField = 15 - onnxTensorProtoExternalDataField = 13 - onnxSparseTensorProtoValuesField = 1 - onnxSparseTensorProtoIndicesField = 2 - onnxStringStringEntryProtoKeyField = 1 - onnxStringStringEntryProtoValueField = 2 - onnxExternalDataLocationKey = "location" - onnxMaxParseSize int = 512 * 1024 * 1024 + onnxModelProtoGraphField = 7 + onnxGraphProtoNodeField = 1 + onnxGraphProtoInitializerField = 5 + onnxGraphProtoSparseInitializerField = 15 + onnxNodeProtoAttributeField = 5 + // AttributeProto carries either a tensor (t / tensors / sparse_tensor / + // sparse_tensors) or a subgraph (g / graphs). Field numbers per onnx.proto. 
+ onnxAttributeProtoTensorField = 5 + onnxAttributeProtoGraphField = 6 + onnxAttributeProtoTensorsField = 10 + onnxAttributeProtoGraphsField = 11 + onnxAttributeProtoSparseTensorField = 22 + onnxAttributeProtoSparseTensorsField = 23 + onnxTensorProtoExternalDataField = 13 + onnxSparseTensorProtoValuesField = 1 + onnxSparseTensorProtoIndicesField = 2 + onnxStringStringEntryProtoKeyField = 1 + onnxStringStringEntryProtoValueField = 2 + onnxExternalDataLocationKey = "location" + onnxMaxParseSize int = 512 * 1024 * 1024 + // Defends against pathological / adversarial ONNX with deeply nested If/Loop + // subgraphs. Real models rarely nest beyond 2-3 levels. + onnxMaxSubgraphDepth = 32 ) // ExtractONNXExternalDataPaths parses the ONNX model file at onnxPath and returns @@ -85,27 +98,62 @@ func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) { locations = append(locations, loc) } - if err := walkGraph(graph, collect); err != nil { + if err := walkGraph(graph, collect, 0); err != nil { return nil, fmt.Errorf("walk ONNX graph: %w", err) } return locations, nil } -// walkGraph iterates initializer + sparse_initializer entries inside a GraphProto. -func walkGraph(graph []byte, collect func(string)) error { +// walkGraph iterates a GraphProto: top-level initializer / sparse_initializer +// entries plus every NodeProto.attribute, recursing into subgraphs (If / Loop / +// Scan branches) up to onnxMaxSubgraphDepth levels deep. 
+func walkGraph(graph []byte, collect func(string), depth int) error { + if depth > onnxMaxSubgraphDepth { + return nil + } return forEachField(graph, func(num protowire.Number, _ protowire.Type, value []byte) error { switch num { case onnxGraphProtoInitializerField: return walkTensorExternalData(value, collect) case onnxGraphProtoSparseInitializerField: - return walkSparseTensor(value, collect) + return walkSparseTensor(value, collect, depth) + case onnxGraphProtoNodeField: + return walkNode(value, collect, depth) + } + return nil + }) +} + +// walkNode descends into NodeProto.attribute entries to surface external_data +// references attached via Constant / If / Loop / Scan and similar ops. +func walkNode(node []byte, collect func(string), depth int) error { + return forEachField(node, func(num protowire.Number, _ protowire.Type, value []byte) error { + if num != onnxNodeProtoAttributeField { + return nil + } + return walkAttribute(value, collect, depth) + }) +} + +// walkAttribute handles the tensor- and subgraph-bearing fields of +// AttributeProto. Other attribute types (floats / ints / strings / type_protos) +// are skipped — none can carry external_data. +func walkAttribute(attr []byte, collect func(string), depth int) error { + return forEachField(attr, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxAttributeProtoTensorField, onnxAttributeProtoTensorsField: + return walkTensorExternalData(value, collect) + case onnxAttributeProtoSparseTensorField, onnxAttributeProtoSparseTensorsField: + return walkSparseTensor(value, collect, depth) + case onnxAttributeProtoGraphField, onnxAttributeProtoGraphsField: + return walkGraph(value, collect, depth+1) } return nil }) } // walkSparseTensor descends into a SparseTensorProto and walks its values + indices TensorProto. 
-func walkSparseTensor(sparse []byte, collect func(string)) error { +func walkSparseTensor(sparse []byte, collect func(string), _ int) error { return forEachField(sparse, func(num protowire.Number, _ protowire.Type, value []byte) error { switch num { case onnxSparseTensorProtoValuesField, onnxSparseTensorProtoIndicesField: diff --git a/pkg/modelfile/onnx_test.go b/pkg/modelfile/onnx_test.go index 90638d2a..5aec1f5e 100644 --- a/pkg/modelfile/onnx_test.go +++ b/pkg/modelfile/onnx_test.go @@ -40,15 +40,45 @@ import ( // One initializer is emitted per externalLocation. Field numbers come from // onnx.proto and are mirrored as constants in onnx.go. func buildMinimalONNXBytes(externalLocations []string) []byte { + return encodeBytesField(onnxModelProtoGraphField, buildONNXGraphInitializerBytes(externalLocations)) +} + +// buildONNXGraphInitializerBytes returns the inner GraphProto payload — used +// directly when wrapping into a subgraph attribute or wherever raw graph bytes +// are needed without the outer ModelProto.graph header. +func buildONNXGraphInitializerBytes(externalLocations []string) []byte { var graph []byte for _, loc := range externalLocations { entry := encodeStringStringEntry(onnxExternalDataLocationKey, loc) tensor := encodeRepeated(onnxTensorProtoExternalDataField, entry) graph = append(graph, encodeBytesField(onnxGraphProtoInitializerField, tensor)...) } + return graph +} + +// buildONNXBytesWithConstantNodeAttribute constructs ONNX bytes where an +// external_data tensor is attached to a NodeProto.attribute.t (the shape +// produced by ONNX Constant ops). Pre-P2 the parser would have missed this. 
+func buildONNXBytesWithConstantNodeAttribute(location string) []byte { + entry := encodeStringStringEntry(onnxExternalDataLocationKey, location) + tensor := encodeRepeated(onnxTensorProtoExternalDataField, entry) + attr := encodeBytesField(onnxAttributeProtoTensorField, tensor) + node := encodeBytesField(onnxNodeProtoAttributeField, attr) + graph := encodeBytesField(onnxGraphProtoNodeField, node) return encodeBytesField(onnxModelProtoGraphField, graph) } +// buildONNXBytesWithSubgraph nests a GraphProto inside a NodeProto.attribute.g +// (the shape produced by ONNX If / Loop / Scan ops). Pre-P2 the parser would +// have missed external_data references inside the subgraph. +func buildONNXBytesWithSubgraph(location string) []byte { + inner := buildONNXGraphInitializerBytes([]string{location}) + attr := encodeBytesField(onnxAttributeProtoGraphField, inner) + node := encodeBytesField(onnxNodeProtoAttributeField, attr) + outer := encodeBytesField(onnxGraphProtoNodeField, node) + return encodeBytesField(onnxModelProtoGraphField, outer) +} + func encodeStringStringEntry(key, value string) []byte { out := encodeBytesField(onnxStringStringEntryProtoKeyField, []byte(key)) out = append(out, encodeBytesField(onnxStringStringEntryProtoValueField, []byte(value))...) @@ -134,3 +164,27 @@ func TestExtractONNXExternalDataPaths_NonONNXBytes(t *testing.T) { assert.Empty(t, paths) } } + +// P2 coverage: external_data attached to a NodeProto.attribute.t (Constant op) +// must be discovered alongside top-level GraphProto.initializer entries. 
+func TestExtractONNXExternalDataPaths_NodeAttributeTensor(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildONNXBytesWithConstantNodeAttribute("const_weights.bin"), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"const_weights.bin"}, paths) +} + +// P2 coverage: external_data inside a subgraph (If / Loop / Scan branch) must +// be discovered via attribute.g recursion. +func TestExtractONNXExternalDataPaths_Subgraph(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile(onnxPath, buildONNXBytesWithSubgraph("branch_weights.bin"), 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.Equal(t, []string{"branch_weights.bin"}, paths) +} From 3aaa7e90199f16d6a2f1176e03bc3e55fb27aa21 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Tue, 28 Apr 2026 21:11:12 +0800 Subject: [PATCH 3/6] fix(modelfile): cover ModelProto.training_info graphs in ONNX parser Codex review on PR #530 noted a remaining coverage gap: ExtractONNXExternal- DataPaths only walked ModelProto.graph (field 7), missing external_data references in training-enabled ONNX (IR v7+) where weights/states can also appear in ModelProto.training_info[*].initialization (field 1) and .algorithm (field 2) GraphProtos. Refactor ExtractONNXExternalDataPaths to iterate ModelProto fields directly via forEachField rather than the now-removed readSubMessage helper. Both graph and training_info subtrees route into the same walkGraph entry, so all the existing recursion (initializer / sparse_initializer / NodeProto attribute / subgraph) applies uniformly to inference and training graphs. Tests: TrainingInfoGraphs covers a training-only model; InferenceAndTrainingCombined verifies both surfaces are merged. 
Signed-off-by: Zhao Chen --- pkg/modelfile/onnx.go | 56 ++++++++++++++++++++++---------------- pkg/modelfile/onnx_test.go | 55 +++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 24 deletions(-) diff --git a/pkg/modelfile/onnx.go b/pkg/modelfile/onnx.go index 064ef91d..7e1b4efa 100644 --- a/pkg/modelfile/onnx.go +++ b/pkg/modelfile/onnx.go @@ -27,6 +27,9 @@ import ( // Schema reference: https://github.com/onnx/onnx/blob/main/onnx/onnx.proto const ( onnxModelProtoGraphField = 7 + onnxModelProtoTrainingInfoField = 20 + onnxTrainingInfoInitializationField = 1 + onnxTrainingInfoAlgorithmField = 2 onnxGraphProtoNodeField = 1 onnxGraphProtoInitializerField = 5 onnxGraphProtoSparseInitializerField = 15 @@ -75,14 +78,6 @@ func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) { return nil, fmt.Errorf("read onnx file: %w", err) } - graph, err := readSubMessage(data, onnxModelProtoGraphField) - if err != nil { - return nil, fmt.Errorf("locate ONNX graph field: %w", err) - } - if graph == nil { - return nil, nil - } - var ( seen = map[string]struct{}{} locations []string @@ -98,12 +93,39 @@ func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) { locations = append(locations, loc) } - if err := walkGraph(graph, collect, 0); err != nil { - return nil, fmt.Errorf("walk ONNX graph: %w", err) + // Iterate ModelProto fields directly so we cover both the inference graph + // (field 7) and any training_info entries (field 20, repeated). Training- + // enabled ONNX (IR v7+) carries additional GraphProtos in + // training_info[*].initialization and training_info[*].algorithm; their + // initializers can also carry external_data references. 
+ err = forEachField(data, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxModelProtoGraphField: + return walkGraph(value, collect, 0) + case onnxModelProtoTrainingInfoField: + return walkTrainingInfo(value, collect) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("walk ONNX model: %w", err) } return locations, nil } +// walkTrainingInfo descends into a TrainingInfoProto and walks its +// initialization and algorithm GraphProtos. The other fields +// (initialization_binding, update_binding) cannot carry external_data. +func walkTrainingInfo(ti []byte, collect func(string)) error { + return forEachField(ti, func(num protowire.Number, _ protowire.Type, value []byte) error { + switch num { + case onnxTrainingInfoInitializationField, onnxTrainingInfoAlgorithmField: + return walkGraph(value, collect, 0) + } + return nil + }) +} + // walkGraph iterates a GraphProto: top-level initializer / sparse_initializer // entries plus every NodeProto.attribute, recursing into subgraphs (If / Loop / // Scan branches) up to onnxMaxSubgraphDepth levels deep. @@ -188,20 +210,6 @@ func walkTensorExternalData(tensor []byte, collect func(string)) error { }) } -// readSubMessage scans the top-level message and returns the bytes of the first -// occurrence of the given length-delimited (wire type 2) field. Returns nil if -// the field is absent. -func readSubMessage(buf []byte, target protowire.Number) ([]byte, error) { - var found []byte - err := forEachField(buf, func(num protowire.Number, typ protowire.Type, value []byte) error { - if num == target && typ == protowire.BytesType && found == nil { - found = value - } - return nil - }) - return found, err -} - // forEachField iterates protobuf wire-format fields in buf, invoking fn for each. // For length-delimited fields, value is the inner payload bytes (header stripped). 
// For varint/fixed32/fixed64/group fields, value is nil and the caller can rely on diff --git a/pkg/modelfile/onnx_test.go b/pkg/modelfile/onnx_test.go index 5aec1f5e..70203fdb 100644 --- a/pkg/modelfile/onnx_test.go +++ b/pkg/modelfile/onnx_test.go @@ -79,6 +79,21 @@ func buildONNXBytesWithSubgraph(location string) []byte { return encodeBytesField(onnxModelProtoGraphField, outer) } +// buildONNXBytesWithTrainingInfo emits a ModelProto whose only graph-bearing +// fields are training_info[*].initialization and training_info[*].algorithm. +// The inference graph is intentionally absent — exercises the post-P4 +// behavior where training-only ONNX still yields its external tensors. +func buildONNXBytesWithTrainingInfo(initLoc, algoLoc string) []byte { + initGraph := buildONNXGraphInitializerBytes([]string{initLoc}) + algoGraph := buildONNXGraphInitializerBytes([]string{algoLoc}) + + var ti []byte + ti = append(ti, encodeBytesField(onnxTrainingInfoInitializationField, initGraph)...) + ti = append(ti, encodeBytesField(onnxTrainingInfoAlgorithmField, algoGraph)...) + + return encodeBytesField(onnxModelProtoTrainingInfoField, ti) +} + func encodeStringStringEntry(key, value string) []byte { out := encodeBytesField(onnxStringStringEntryProtoKeyField, []byte(key)) out = append(out, encodeBytesField(onnxStringStringEntryProtoValueField, []byte(value))...) @@ -188,3 +203,43 @@ func TestExtractONNXExternalDataPaths_Subgraph(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"branch_weights.bin"}, paths) } + +// P4 coverage: training-enabled ONNX (IR v7+) carries graphs in +// ModelProto.training_info[*].initialization and .algorithm. external_data in +// either of these must be discovered alongside the inference graph. 
+func TestExtractONNXExternalDataPaths_TrainingInfoGraphs(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + require.NoError(t, os.WriteFile( + onnxPath, + buildONNXBytesWithTrainingInfo("init_state.bin", "optimizer_state.bin"), + 0o644, + )) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"init_state.bin", "optimizer_state.bin"}, paths) +} + +// P4 coverage: a model with both an inference graph AND training_info must +// surface external_data from both locations. +func TestExtractONNXExternalDataPaths_InferenceAndTrainingCombined(t *testing.T) { + dir := t.TempDir() + onnxPath := filepath.Join(dir, "model.onnx") + + // Inference graph initializer. + inferenceGraph := buildONNXGraphInitializerBytes([]string{"weights.bin"}) + // training_info with one initialization graph. + initGraph := buildONNXGraphInitializerBytes([]string{"init_state.bin"}) + ti := encodeBytesField(onnxTrainingInfoInitializationField, initGraph) + + // Compose ModelProto with both fields. + var model []byte + model = append(model, encodeBytesField(onnxModelProtoGraphField, inferenceGraph)...) + model = append(model, encodeBytesField(onnxModelProtoTrainingInfoField, ti)...) + require.NoError(t, os.WriteFile(onnxPath, model, 0o644)) + + paths, err := ExtractONNXExternalDataPaths(onnxPath) + require.NoError(t, err) + assert.ElementsMatch(t, []string{"weights.bin", "init_state.bin"}, paths) +} From affc7c1b787b36e6a96424e0683dc61757bdd6df Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Tue, 28 Apr 2026 21:35:13 +0800 Subject: [PATCH 4/6] fix(modelfile): reject absolute ONNX external_data paths, assert WARNING text, fix doc Three follow-up findings from a second codex review pass on PR #530: 1. Reject absolute external_data.location values: filepath.Join(onnxDir, /decoy) silently strips the leading separator and produces a relative decoy path, which would reclassify an unrelated workspace file. 
Add filepath.IsAbs(ext) guard before Join. 2. Lock the WARNING contract: ParseFailureFallsBack now captures os.Stderr and asserts the warning prefix, file path, and fallback note are printed. 3. Update ExtractONNXExternalDataPaths godoc to enumerate the real error sources: I/O, size cap, and malformed protobuf wire data. go vet, go test -race ./pkg/modelfile/..., and go test ./... all clean. Signed-off-by: Zhao Chen --- pkg/modelfile/modelfile.go | 9 ++++ pkg/modelfile/modelfile_test.go | 84 +++++++++++++++++++++++++++++---- pkg/modelfile/onnx.go | 8 ++-- 3 files changed, 89 insertions(+), 12 deletions(-) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index 68c1f862..895780d5 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -393,6 +393,15 @@ func (mf *modelfile) reclassifyONNXExternalData() { } onnxDir := filepath.Dir(modelRel) for _, ext := range extPaths { + // Reject absolute external_data.location values outright. ONNX + // spec defines location as relative to the .onnx file's + // directory, so an absolute path is malformed; worse, + // filepath.Join silently strips the leading separator + // (Join(".", "/etc/secret") -> "etc/secret"), which would let + // an unrelated workspace file get reclassified to MODEL. 
+ if filepath.IsAbs(ext) { + continue + } relExt := filepath.Clean(filepath.Join(onnxDir, ext)) // Walker membership check absorbs all of: // - exclude pattern (walker dropped it -> not in any bucket) diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index 31083a31..b2deaa89 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "os" "path/filepath" "sort" @@ -2356,11 +2357,33 @@ func TestNewModelfileByWorkspace_ONNXExternalDataRespectsExclude(t *testing.T) { assert.NotContains(t, all, "weights_drop.bin") } +// captureStderr redirects os.Stderr through a pipe while fn runs and returns +// everything written to it. Standard Go pattern for asserting on stderr text +// without coupling tests to a specific log library. +func captureStderr(t *testing.T, fn func()) string { + t.Helper() + r, w, err := os.Pipe() + require.NoError(t, err) + old := os.Stderr + os.Stderr = w + defer func() { os.Stderr = old }() + + done := make(chan string, 1) + go func() { + var buf strings.Builder + _, _ = io.Copy(&buf, r) + done <- buf.String() + }() + + fn() + require.NoError(t, w.Close()) + return <-done +} + // P3 coverage: a corrupted .onnx file must not crash generate. The .onnx -// itself stays in MODEL (it matches *.onnx), and any sibling tensor files -// keep whatever classification the walker already gave them — i.e., the -// pre-fix behavior. A WARNING is emitted to stderr (not asserted here, but -// captured via the parser's error path). +// itself stays in MODEL (it matches *.onnx), any sibling tensor files keep +// whatever classification the walker already gave them, AND a clearly-prefixed +// WARNING is emitted to stderr — locking that user-visible contract. 
func TestNewModelfileByWorkspace_ONNXParseFailureFallsBack(t *testing.T) { tempDir := t.TempDir() @@ -2369,13 +2392,56 @@ func TestNewModelfileByWorkspace_ONNXParseFailureFallsBack(t *testing.T) { // A small extension-less file the walker would classify as CODE — stays CODE on parse failure. require.NoError(t, os.WriteFile(filepath.Join(tempDir, "auxiliary_blob"), nil, 0o644)) - mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ - Name: "onnx-corrupt", - Workspace: tempDir, + var ( + mf Modelfile + err error + ) + stderr := captureStderr(t, func() { + mf, err = NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-corrupt", + Workspace: tempDir, + }) }) + require.NoError(t, err, "corrupted ONNX must not abort generate") assert.Contains(t, mf.GetModels(), "model.onnx", ".onnx file itself stays in MODEL by extension") - // auxiliary_blob keeps walker classification (CODE for small extension-less); - // we don't assert exact bucket — only that generate succeeded and produced output. assert.NotEmpty(t, mf.GetModels()) + + // Lock the WARNING contract: prefix + offending file path + fallback note. + assert.Contains(t, stderr, "WARNING:", "warning prefix must be printed") + assert.Contains(t, stderr, "model.onnx", "warning must name the failing file") + assert.Contains(t, stderr, "keep walker-assigned classification", + "warning must explain the fallback so users know external tensors were not reclassified") +} + +// P1 follow-up: an absolute external_data.location must be rejected outright, +// even when the absolute path's stripped form happens to match a real +// workspace file. Previously filepath.Join(".", "/decoy") would produce +// "decoy" and reclassify an unrelated file. +func TestNewModelfileByWorkspace_ONNXExternalDataAbsolutePathRejected(t *testing.T) { + tempDir := t.TempDir() + + // Locations: one absolute (must be rejected), one normal (must reclassify). 
+ require.NoError(t, os.WriteFile( + filepath.Join(tempDir, "model.onnx"), + buildMinimalONNXBytes([]string{"/decoy", "weights.bin"}), + 0o644, + )) + // "decoy" exists in the workspace at root — would be the misclassification + // target if Join silently strips the leading slash from "/decoy". + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "decoy"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.bin"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "onnx-abs", + Workspace: tempDir, + }) + require.NoError(t, err) + + // weights.bin reclassifies normally. + assert.Contains(t, mf.GetModels(), "weights.bin") + // "decoy" must NOT be in MODEL — the abs path "/decoy" was rejected before + // Join could produce a misleading "decoy" relative path. + assert.NotContains(t, mf.GetModels(), "decoy", + "absolute external_data.location must not promote unrelated workspace files") } diff --git a/pkg/modelfile/onnx.go b/pkg/modelfile/onnx.go index 7e1b4efa..266a8c11 100644 --- a/pkg/modelfile/onnx.go +++ b/pkg/modelfile/onnx.go @@ -61,9 +61,11 @@ const ( // directory containing the .onnx file, per ONNX convention). The slice is // deduplicated and ordered by first appearance. // -// Returns nil with no error for files that aren't ONNX or that don't reference -// any external data; only I/O failures produce errors. Files larger than -// onnxMaxParseSize are skipped with an error. +// An ONNX without any external_data references returns nil, nil. The function +// returns an error for: I/O failures (stat / read), files exceeding +// onnxMaxParseSize, and malformed protobuf wire data (e.g., truncated or +// corrupted bytes). Callers that want best-effort behavior should treat all +// errors as a fallback signal and continue without external_data information. 
func ExtractONNXExternalDataPaths(onnxPath string) ([]string, error) { info, err := os.Stat(onnxPath) if err != nil { From 876d1571ab73b5905232dcfb724a69cb456bf323 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 30 Apr 2026 11:23:13 +0800 Subject: [PATCH 5/6] feat(modelfile): infer Format from MODEL set evidence Best-effort fill of mf.format when --format is not passed. Inference keys off uniquely diagnostic signals only: saved_model.pb / saved_model.pbtxt -> tensorflow *.onnx -> onnx *.gguf -> gguf *.safetensors -> safetensors Generic .bin/.pt/.pth are intentionally NOT signals: they appear in many formats and would produce false positives on existing repos. Failure leaves mf.format empty and degrades gracefully: no recognizable signal stays silent, while a recovered panic (e.g. on a malformed hashset value) is reported via a stderr WARNING. CLI --format always wins over inference. Format remains best-effort metadata; downstream build/push/pull paths already handle it being blank, so this never blocks generation. Signed-off-by: Zhao Chen --- pkg/modelfile/modelfile.go | 71 ++++++++++++++++++++++++++++ pkg/modelfile/modelfile_test.go | 83 +++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index 895780d5..d94cbec5 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -218,6 +218,12 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC } mf.generateByConfig(config) + + // Best-effort: fill mf.format from MODEL file evidence when the user did not + // pass --format. Failure (no recognizable signal, panic in the loop, etc.) + // MUST NOT abort generation — Format is metadata, not load-bearing. + mf.inferFormat() + return mf, nil } @@ -420,6 +426,71 @@ func (mf *modelfile) reclassifyONNXExternalData() { } } +// inferFormat fills mf.format from filename evidence in the MODEL set when the +// user did not pass --format on the CLI.
It only emits a value for highly +// specific signals (saved_model.pb / *.onnx / *.gguf / *.safetensors); generic +// extensions like *.bin / *.pt are left alone because they appear in many +// formats and would produce false positives. +// +// Priority order, when multiple signals coexist: +// +// 1. tensorflow — saved_model.pb / saved_model.pbtxt (SavedModel directory) +// 2. onnx — *.onnx +// 3. gguf — *.gguf +// 4. safetensors — *.safetensors +// +// SavedModel and ONNX are listed first because their layouts are uniquely +// identifiable; safetensors is last because it sometimes coexists with raw +// PyTorch shards in HF repos. +// +// Failure modes (no recognized signal, panic from a malformed value in the +// hashset, etc.) MUST NOT abort generation. The recover() guard ensures any +// unexpected panic degrades to "format stays empty" rather than killing the +// whole modelfile build. Format is best-effort metadata; the package gracefully +// handles a blank Format throughout the build/push/pull pipeline. 
+func (mf *modelfile) inferFormat() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, + "WARNING: modelfile: format inference panicked, leaving Format empty: %v\n", r) + } + }() + + if mf.format != "" { + return + } + + var hasSavedModel, hasONNX, hasGGUF, hasSafetensors bool + for _, raw := range mf.model.Values() { + rel, ok := raw.(string) + if !ok { + continue + } + base := strings.ToLower(filepath.Base(rel)) + switch { + case base == "saved_model.pb" || base == "saved_model.pbtxt": + hasSavedModel = true + case strings.HasSuffix(base, ".onnx"): + hasONNX = true + case strings.HasSuffix(base, ".gguf"): + hasGGUF = true + case strings.HasSuffix(base, ".safetensors"): + hasSafetensors = true + } + } + + switch { + case hasSavedModel: + mf.format = "tensorflow" + case hasONNX: + mf.format = "onnx" + case hasGGUF: + mf.format = "gguf" + case hasSafetensors: + mf.format = "safetensors" + } +} + // generateByModelConfig generates the modelfile by the model config, such as config.json and generation_config.json. func (mf *modelfile) generateByModelConfig() error { // Get config map from json files. Collect all the keys and values from the config files diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index b2deaa89..4991d0cb 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -389,6 +389,9 @@ func TestNewModelfileByWorkspace(t *testing.T) { "assets/images/preview.jpg", }, expectName: "nested-model", + // inferFormat picks safetensors on seeing weights/model.safetensors; + // model.bin is too generic to be a signal. 
+ expectFormat: "safetensors", }, { name: "deep nested directories", @@ -622,6 +625,7 @@ func TestNewModelfileByWorkspace(t *testing.T) { expectName: "llama-7b", expectArch: "transformer", expectFamily: "llama", + expectFormat: "safetensors", expectPrecision: "bfloat16", expectParamsize: "7B", }, @@ -861,6 +865,8 @@ func TestNewModelfileByWorkspace(t *testing.T) { expectCodes: []string{"ais-sync.done"}, expectDocs: []string{"tf_signature.txt"}, expectName: "tf-savedmodel", + // saved_model.pb in MODEL set is the canonical SavedModel signal. + expectFormat: "tensorflow", }, { // Mirrors screenshot 3/4/5: a parent directory containing base// @@ -916,6 +922,9 @@ func TestNewModelfileByWorkspace(t *testing.T) { "delta/20251202060205/tf_signature.txt", }, expectName: "online-learning", + // Multiple saved_model.pb instances under base/ and delta/ all signal + // tensorflow; inferFormat is set-based, so duplicates do not matter. + expectFormat: "tensorflow", }, } @@ -2296,6 +2305,8 @@ func TestNewModelfileByWorkspace_ONNXExternalData(t *testing.T) { assert.Contains(t, models, "variable_sequence_def_0.onnx") // tf_signature.txt is a doc per default DocFilePatterns (*.txt). assert.Contains(t, mf.GetDocs(), "tf_signature.txt") + // inferFormat: any *.onnx in MODEL emits format=onnx. + assert.Equal(t, "onnx", mf.GetFormat()) } // P1 coverage: an external_data.location of "../escape" must NOT promote @@ -2445,3 +2456,75 @@ func TestNewModelfileByWorkspace_ONNXExternalDataAbsolutePathRejected(t *testing assert.NotContains(t, mf.GetModels(), "decoy", "absolute external_data.location must not promote unrelated workspace files") } + +// TestNewModelfileByWorkspace_InferFormatCLIFlagWins verifies that an explicit +// --format on the CLI is never overridden by inferFormat, even when the +// workspace contains a clear signal (saved_model.pb) for a different format. +// Inference is best-effort metadata fill and must defer to user intent. 
+func TestNewModelfileByWorkspace_InferFormatCLIFlagWins(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pb"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "tf-but-custom", + Workspace: tempDir, + Format: "custom-format", + }) + require.NoError(t, err) + assert.Equal(t, "custom-format", mf.GetFormat(), + "explicit --format must override inferFormat's tensorflow guess") +} + +// TestNewModelfileByWorkspace_InferFormatNoSignal verifies that an unambiguous +// no-signal directory (e.g. a generic *.bin with no recognizable extension) +// leaves Format empty. We deliberately do NOT infer "pytorch" from .bin/.pt, +// because those extensions appear in many formats and would produce false +// positives on existing HF / TF / ONNX repos that also ship raw binary blobs. +func TestNewModelfileByWorkspace_InferFormatNoSignal(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "model.bin"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.pt"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "README.md"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "no-signal", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "", mf.GetFormat(), + "generic .bin/.pt files alone must not trigger format inference") +} + +// TestNewModelfileByWorkspace_InferFormatPriority verifies the priority order: +// saved_model.pb wins over a sibling *.safetensors, because the SavedModel +// signal is uniquely diagnostic while *.safetensors can coexist with raw +// PyTorch shards in HF repos. +func TestNewModelfileByWorkspace_InferFormatPriority(t *testing.T) { + tempDir := t.TempDir() + // SavedModel layout + an unrelated safetensors blob in the same dir. 
+ require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pb"), nil, 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "extra.safetensors"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "priority-mix", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "tensorflow", mf.GetFormat(), + "saved_model.pb must outrank a sibling .safetensors") +} + +// TestNewModelfileByWorkspace_InferFormatGGUF locks the .gguf signal — used by +// llama.cpp and adjacent ecosystems. .gguf is a self-contained quantized model +// container and is safe to map to format=gguf on filename alone. +func TestNewModelfileByWorkspace_InferFormatGGUF(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "ggml-model.gguf"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "gguf-only", + Workspace: tempDir, + }) + require.NoError(t, err) + assert.Equal(t, "gguf", mf.GetFormat()) +} From 4c24aa31832453c90d29b3261a7c58f98b550cc0 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 30 Apr 2026 11:48:42 +0800 Subject: [PATCH 6/6] fix(modelfile): scan all walker buckets for inferFormat signals Codex review (P2) caught that saved_model.pbtxt is not in ModelFilePatterns -- the walker lands it in CODE/DOC, so the prior inferFormat (which only scanned mf.model) silently missed the textual SavedModel variant. A workspace whose only TF signal was a .pbtxt would yield format="" instead of "tensorflow". Fix: extract the per-file scan into a closure and apply it to all four walker buckets (model + config + code + doc). Set-based: duplicates across buckets are harmless. Walker classification is unchanged. Adds TestNewModelfileByWorkspace_InferFormatSavedModelPbtxt as regression coverage and updates the doc comment to spell out the multi-bucket scan and the rationale. 
Signed-off-by: Zhao Chen --- pkg/modelfile/modelfile.go | 53 ++++++++++++++++++++------------- pkg/modelfile/modelfile_test.go | 24 +++++++++++++++ 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index d94cbec5..1054ee18 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -426,11 +426,11 @@ func (mf *modelfile) reclassifyONNXExternalData() { } } -// inferFormat fills mf.format from filename evidence in the MODEL set when the -// user did not pass --format on the CLI. It only emits a value for highly -// specific signals (saved_model.pb / *.onnx / *.gguf / *.safetensors); generic -// extensions like *.bin / *.pt are left alone because they appear in many -// formats and would produce false positives. +// inferFormat fills mf.format from filename evidence collected by the walker +// when the user did not pass --format on the CLI. It only emits a value for +// highly specific signals (saved_model.pb[txt] / *.onnx / *.gguf / +// *.safetensors); generic extensions like *.bin / *.pt are left alone because +// they appear in many formats and would produce false positives. // // Priority order, when multiple signals coexist: // @@ -443,6 +443,13 @@ func (mf *modelfile) reclassifyONNXExternalData() { // identifiable; safetensors is last because it sometimes coexists with raw // PyTorch shards in HF repos. // +// We scan ALL four walker buckets (model / config / code / doc), not just +// mf.model. Reason: signals like `saved_model.pbtxt` are not in +// ModelFilePatterns and the walker therefore lands them in code/doc; if we +// scanned only mf.model, a SavedModel that ships only the .pbtxt variant would +// silently fall through. A set-based scan over every bucket closes that gap +// without changing how the walker classifies each individual file. +// // Failure modes (no recognized signal, panic from a malformed value in the // hashset, etc.) MUST NOT abort generation. 
The recover() guard ensures any // unexpected panic degrades to "format stays empty" rather than killing the @@ -461,23 +468,29 @@ func (mf *modelfile) inferFormat() { } var hasSavedModel, hasONNX, hasGGUF, hasSafetensors bool - for _, raw := range mf.model.Values() { - rel, ok := raw.(string) - if !ok { - continue - } - base := strings.ToLower(filepath.Base(rel)) - switch { - case base == "saved_model.pb" || base == "saved_model.pbtxt": - hasSavedModel = true - case strings.HasSuffix(base, ".onnx"): - hasONNX = true - case strings.HasSuffix(base, ".gguf"): - hasGGUF = true - case strings.HasSuffix(base, ".safetensors"): - hasSafetensors = true + scan := func(set *hashset.Set) { + for _, raw := range set.Values() { + rel, ok := raw.(string) + if !ok { + continue + } + base := strings.ToLower(filepath.Base(rel)) + switch { + case base == "saved_model.pb" || base == "saved_model.pbtxt": + hasSavedModel = true + case strings.HasSuffix(base, ".onnx"): + hasONNX = true + case strings.HasSuffix(base, ".gguf"): + hasGGUF = true + case strings.HasSuffix(base, ".safetensors"): + hasSafetensors = true + } } } + scan(mf.model) + scan(mf.config) + scan(mf.code) + scan(mf.doc) switch { case hasSavedModel: diff --git a/pkg/modelfile/modelfile_test.go b/pkg/modelfile/modelfile_test.go index 4991d0cb..b65d5a13 100644 --- a/pkg/modelfile/modelfile_test.go +++ b/pkg/modelfile/modelfile_test.go @@ -2528,3 +2528,27 @@ func TestNewModelfileByWorkspace_InferFormatGGUF(t *testing.T) { require.NoError(t, err) assert.Equal(t, "gguf", mf.GetFormat()) } + +// TestNewModelfileByWorkspace_InferFormatSavedModelPbtxt is regression coverage +// for a Codex finding: `saved_model.pbtxt` is not in ModelFilePatterns, so the +// walker lands it in CODE (or DOC). An earlier inferFormat that scanned only +// mf.model would silently miss the textual SavedModel variant. 
The fixed +// implementation scans all walker buckets, so a workspace whose only TF signal +// is a sibling .pbtxt still resolves to format=tensorflow. +func TestNewModelfileByWorkspace_InferFormatSavedModelPbtxt(t *testing.T) { + tempDir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "saved_model.pbtxt"), nil, 0o644)) + // A real model file so generateByWorkspace doesn't error on empty MODEL set. + require.NoError(t, os.WriteFile(filepath.Join(tempDir, "weights.safetensors"), nil, 0o644)) + + mf, err := NewModelfileByWorkspace(tempDir, &configmodelfile.GenerateConfig{ + Name: "tf-textual", + Workspace: tempDir, + }) + require.NoError(t, err) + + // SavedModel signal must outrank the sibling .safetensors per the + // documented priority order, even though the .pbtxt is not in MODEL. + assert.Equal(t, "tensorflow", mf.GetFormat(), + "saved_model.pbtxt in any walker bucket must trigger format=tensorflow") +}