Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion internal/api/modules/amp/model_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package amp

import (
"sort"
"strings"
"sync"

Expand Down Expand Up @@ -50,10 +51,39 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
defer m.mu.RUnlock()

// Normalize the requested model for lookup
// Replace underscores with dashes for consistent lookup (e.g., claude-sonnet-4_5 -> claude-sonnet-4-5)
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
normalizedRequest = strings.ReplaceAll(normalizedRequest, "_", "-")

// Check for direct mapping
// Check for direct mapping first
targetModel, exists := m.mappings[normalizedRequest]

// If no direct match, try prefix/wildcard matching with deterministic order
// This allows mappings like "claude-haiku-*" to match "claude-haiku-4-5-20251001"
if !exists {
patterns := make([]string, 0, len(m.mappings))
for pattern := range m.mappings {
if strings.HasSuffix(pattern, "*") {
patterns = append(patterns, pattern)
}
}
sort.Slice(patterns, func(i, j int) bool {
if len(patterns[i]) == len(patterns[j]) {
return patterns[i] < patterns[j]
}
return len(patterns[i]) > len(patterns[j])
})
for _, pattern := range patterns {
prefix := strings.TrimSuffix(pattern, "*")
if strings.HasPrefix(normalizedRequest, prefix) {
targetModel = m.mappings[pattern]
exists = true
log.Debugf("amp model mapping: wildcard match %s -> %s (pattern: %s)", normalizedRequest, targetModel, pattern)
break
}
}
}

if !exists {
return ""
}
Expand Down Expand Up @@ -88,7 +118,9 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
}

// Store with normalized lowercase key for case-insensitive lookup
// Also normalize underscores to dashes for consistent matching
normalizedFrom := strings.ToLower(from)
normalizedFrom = strings.ReplaceAll(normalizedFrom, "_", "-")
m.mappings[normalizedFrom] = to

log.Debugf("amp model mapping registered: %s -> %s", from, to)
Expand Down
62 changes: 47 additions & 15 deletions internal/registry/model_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ type ModelRegistry struct {
mutex *sync.RWMutex
}

// normalizeModelID normalizes model IDs by replacing underscores with dashes
func normalizeModelID(modelID string) string {
return strings.ReplaceAll(strings.TrimSpace(modelID), "_", "-")
}

// Global model registry instance
var globalRegistry *ModelRegistry
var registryOnce sync.Once
Expand Down Expand Up @@ -121,20 +126,21 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [

provider := strings.ToLower(clientProvider)
uniqueModelIDs := make([]string, 0, len(models))
rawModelIDs := make([]string, 0, len(models))
normalizedModelIDs := make([]string, 0, len(models))
newModels := make(map[string]*ModelInfo, len(models))
newCounts := make(map[string]int, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
rawModelIDs = append(rawModelIDs, model.ID)
newCounts[model.ID]++
if _, exists := newModels[model.ID]; exists {
normalizedID := normalizeModelID(model.ID)
normalizedModelIDs = append(normalizedModelIDs, normalizedID)
newCounts[normalizedID]++
if _, exists := newModels[normalizedID]; exists {
continue
}
newModels[model.ID] = model
uniqueModelIDs = append(uniqueModelIDs, model.ID)
newModels[normalizedID] = model
uniqueModelIDs = append(uniqueModelIDs, normalizedID)
}

if len(uniqueModelIDs) == 0 {
Expand All @@ -153,24 +159,25 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
providerChanged := oldProvider != provider
if !hadExisting {
// Pure addition path.
for _, modelID := range rawModelIDs {
for _, modelID := range normalizedModelIDs {
model := newModels[modelID]
r.addModelRegistration(modelID, provider, model, now)
}
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
r.clientModels[clientID] = append([]string(nil), normalizedModelIDs...)
if provider != "" {
r.clientProviders[clientID] = provider
} else {
delete(r.clientProviders, clientID)
}
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(normalizedModelIDs))
misc.LogCredentialSeparator()
return
}

oldCounts := make(map[string]int, len(oldModels))
for _, id := range oldModels {
oldCounts[id]++
normID := normalizeModelID(id)
oldCounts[normID]++
}

added := make([]string, 0)
Expand Down Expand Up @@ -281,8 +288,8 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
}

// Update client bookkeeping.
if len(rawModelIDs) > 0 {
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
if len(normalizedModelIDs) > 0 {
r.clientModels[clientID] = append([]string(nil), normalizedModelIDs...)
}
if provider != "" {
r.clientProviders[clientID] = provider
Expand Down Expand Up @@ -532,6 +539,9 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
return false
}

// Normalize model ID: replace underscores with dashes for consistent lookup
normalizedModelID := strings.ReplaceAll(modelID, "_", "-")

r.mutex.RLock()
defer r.mutex.RUnlock()

Expand All @@ -541,7 +551,8 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
}

for _, id := range models {
if strings.EqualFold(strings.TrimSpace(id), modelID) {
trimmedID := strings.TrimSpace(id)
if strings.EqualFold(trimmedID, modelID) || strings.EqualFold(trimmedID, normalizedModelID) {
return true
}
}
Expand Down Expand Up @@ -614,7 +625,14 @@ func (r *ModelRegistry) GetModelCount(modelID string) int {
r.mutex.RLock()
defer r.mutex.RUnlock()

if registration, exists := r.models[modelID]; exists {
// Normalize model ID: replace underscores with dashes for consistent lookup
normalizedID := strings.ReplaceAll(modelID, "_", "-")
registration, exists := r.models[normalizedID]
if !exists {
registration, exists = r.models[modelID]
}

if exists {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute

Expand Down Expand Up @@ -648,7 +666,15 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
r.mutex.RLock()
defer r.mutex.RUnlock()

registration, exists := r.models[modelID]
// Normalize model ID: replace underscores with dashes for consistent lookup
// This handles cases like "claude-sonnet-4_5-20250929" -> "claude-sonnet-4-5-20250929"
normalizedID := strings.ReplaceAll(modelID, "_", "-")

registration, exists := r.models[normalizedID]
// Fall back to original ID if normalized version not found
if !exists {
registration, exists = r.models[modelID]
}
if !exists || registration == nil || len(registration.Providers) == 0 {
return nil
}
Expand Down Expand Up @@ -700,6 +726,12 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string {
func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Normalize model ID: replace underscores with dashes for consistent lookup
normalizedID := strings.ReplaceAll(modelID, "_", "-")
if reg, ok := r.models[normalizedID]; ok && reg != nil {
return reg.Info
}
// Fall back to original ID if normalized version not found
if reg, ok := r.models[modelID]; ok && reg != nil {
return reg.Info
}
Expand Down