Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions pkg/skills/skillsvc/skillsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
Expand All @@ -29,10 +28,6 @@ import (
"github.com/stacklok/toolhive/pkg/storage"
)

// ociTagRegexp matches valid OCI tag strings per the distribution spec.
// Reference: https://github.com/opencontainers/distribution-spec/blob/v1.1.1/spec.md#pulling-manifests
var ociTagRegexp = regexp.MustCompile(`^[\w][\w.-]{0,127}$`)

// Option configures the skill service.
type Option func(*service)

Expand Down Expand Up @@ -845,11 +840,14 @@ func validateLocalPath(path string) error {
return nil
}

// validateOCITag checks that a tag conforms to the OCI distribution spec format.
// validateOCITag checks that a tag conforms to the OCI distribution spec format
// using go-containerregistry's reference parser rather than a hand-rolled regex.
func validateOCITag(tag string) error {
if !ociTagRegexp.MatchString(tag) {
// Construct a minimal valid reference with the tag so the library
// validates the tag portion per the distribution spec.
if _, err := nameref.NewTag("dummy.invalid/repo:"+tag, nameref.StrictValidation); err != nil {
return httperr.WithCode(
fmt.Errorf("invalid OCI tag %q: must match %s", tag, ociTagRegexp.String()),
fmt.Errorf("invalid OCI tag %q: %w", tag, err),
http.StatusBadRequest,
)
}
Expand Down
50 changes: 50 additions & 0 deletions pkg/skills/skillsvc/skillsvc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1940,3 +1941,52 @@ func TestUninstallRemovesSkillFromGroups(t *testing.T) {
})
}
}

func TestValidateOCITag(t *testing.T) {
t.Parallel()

tests := []struct {
name string
tag string
wantErr bool
}{
// Valid tags
{name: "simple version", tag: "v1.0.0", wantErr: false},
{name: "latest", tag: "latest", wantErr: false},
{name: "numeric", tag: "123", wantErr: false},
{name: "with dots", tag: "1.2.3", wantErr: false},
{name: "with hyphens", tag: "my-skill", wantErr: false},
{name: "with underscores", tag: "my_skill", wantErr: false},
{name: "mixed alphanumeric", tag: "v1.0.0-rc.1", wantErr: false},
{name: "uppercase", tag: "MyTag", wantErr: false},
{name: "single char", tag: "a", wantErr: false},
{name: "max length 128 chars", tag: strings.Repeat("a", 128), wantErr: false},
{name: "exceeds max length 129 chars", tag: strings.Repeat("a", 129), wantErr: true},

// Invalid tags
{name: "empty string", tag: "", wantErr: true},
{name: "contains space", tag: "invalid tag", wantErr: true},
{name: "contains exclamation", tag: "invalid!", wantErr: true},
{name: "contains at", tag: "invalid@tag", wantErr: true},
{name: "contains hash", tag: "invalid#tag", wantErr: true},
{name: "contains slash", tag: "invalid/tag", wantErr: true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

err := validateOCITag(tt.tag)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid OCI tag")
// Verify it returns a proper HTTP status code.
var coded *httperr.CodedError
require.ErrorAs(t, err, &coded)
assert.Equal(t, http.StatusBadRequest, coded.HTTPCode())
} else {
require.NoError(t, err)
}
})
}
}
Loading