diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 87384a802..74edf47f5 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -3,6 +3,7 @@ package amp import ( + "sort" "strings" "sync" @@ -50,10 +51,39 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { defer m.mu.RUnlock() // Normalize the requested model for lookup + // Replace underscores with dashes for consistent lookup (e.g., claude-sonnet-4_5 -> claude-sonnet-4-5) normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel)) + normalizedRequest = strings.ReplaceAll(normalizedRequest, "_", "-") - // Check for direct mapping + // Check for direct mapping first targetModel, exists := m.mappings[normalizedRequest] + + // If no direct match, try prefix/wildcard matching with deterministic order + // This allows mappings like "claude-haiku-*" to match "claude-haiku-4-5-20251001" + if !exists { + patterns := make([]string, 0, len(m.mappings)) + for pattern := range m.mappings { + if strings.HasSuffix(pattern, "*") { + patterns = append(patterns, pattern) + } + } + sort.Slice(patterns, func(i, j int) bool { + if len(patterns[i]) == len(patterns[j]) { + return patterns[i] < patterns[j] + } + return len(patterns[i]) > len(patterns[j]) + }) + for _, pattern := range patterns { + prefix := strings.TrimSuffix(pattern, "*") + if strings.HasPrefix(normalizedRequest, prefix) { + targetModel = m.mappings[pattern] + exists = true + log.Debugf("amp model mapping: wildcard match %s -> %s (pattern: %s)", normalizedRequest, targetModel, pattern) + break + } + } + } + if !exists { return "" } @@ -88,7 +118,9 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) { } // Store with normalized lowercase key for case-insensitive lookup + // Also normalize underscores to dashes for consistent matching normalizedFrom := strings.ToLower(from) + normalizedFrom = strings.ReplaceAll(normalizedFrom, "_", "-") m.mappings[normalizedFrom] = to log.Debugf("amp model mapping registered: %s -> %s", from, to) diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 5ef9007f9..95efa7138 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -93,6 +93,11 @@ type ModelRegistry struct { mutex *sync.RWMutex } +// normalizeModelID normalizes model IDs by replacing underscores with dashes +func normalizeModelID(modelID string) string { + return strings.ReplaceAll(strings.TrimSpace(modelID), "_", "-") +} + // Global model registry instance var globalRegistry *ModelRegistry var registryOnce sync.Once @@ -121,20 +126,21 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ provider := strings.ToLower(clientProvider) uniqueModelIDs := make([]string, 0, len(models)) - rawModelIDs := make([]string, 0, len(models)) + normalizedModelIDs := make([]string, 0, len(models)) newModels := make(map[string]*ModelInfo, len(models)) newCounts := make(map[string]int, len(models)) for _, model := range models { if model == nil || model.ID == "" { continue } - rawModelIDs = append(rawModelIDs, model.ID) - newCounts[model.ID]++ - if _, exists := newModels[model.ID]; exists { + normalizedID := normalizeModelID(model.ID) + normalizedModelIDs = append(normalizedModelIDs, normalizedID) + newCounts[normalizedID]++ + if _, exists := newModels[normalizedID]; exists { continue } - newModels[model.ID] = model - uniqueModelIDs = append(uniqueModelIDs, model.ID) + newModels[normalizedID] = model + uniqueModelIDs = append(uniqueModelIDs, normalizedID) } if len(uniqueModelIDs) == 0 { @@ -153,24 +159,25 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ providerChanged := oldProvider != provider if !hadExisting { // Pure addition path. - for _, modelID := range rawModelIDs { + for _, modelID := range normalizedModelIDs { model := newModels[modelID] r.addModelRegistration(modelID, provider, model, now) } - r.clientModels[clientID] = append([]string(nil), rawModelIDs...) + r.clientModels[clientID] = append([]string(nil), normalizedModelIDs...) if provider != "" { r.clientProviders[clientID] = provider } else { delete(r.clientProviders, clientID) } - log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs)) + log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(normalizedModelIDs)) misc.LogCredentialSeparator() return } oldCounts := make(map[string]int, len(oldModels)) for _, id := range oldModels { - oldCounts[id]++ + normID := normalizeModelID(id) + oldCounts[normID]++ } added := make([]string, 0) @@ -281,8 +288,8 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ } // Update client bookkeeping. - if len(rawModelIDs) > 0 { - r.clientModels[clientID] = append([]string(nil), rawModelIDs...) + if len(normalizedModelIDs) > 0 { + r.clientModels[clientID] = append([]string(nil), normalizedModelIDs...) } if provider != "" { r.clientProviders[clientID] = provider @@ -532,6 +539,9 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { return false } + // Normalize model ID: replace underscores with dashes for consistent lookup + normalizedModelID := strings.ReplaceAll(modelID, "_", "-") + r.mutex.RLock() defer r.mutex.RUnlock() @@ -541,7 +551,8 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { } for _, id := range models { - if strings.EqualFold(strings.TrimSpace(id), modelID) { + trimmedID := strings.TrimSpace(id) + if strings.EqualFold(trimmedID, modelID) || strings.EqualFold(trimmedID, normalizedModelID) { return true } } @@ -614,7 +625,14 @@ func (r *ModelRegistry) GetModelCount(modelID string) int { r.mutex.RLock() defer r.mutex.RUnlock() - if registration, exists := r.models[modelID]; exists { + // Normalize model ID: replace underscores with dashes for consistent lookup + normalizedID := strings.ReplaceAll(modelID, "_", "-") + registration, exists := r.models[normalizedID] + if !exists { + registration, exists = r.models[modelID] + } + + if exists { now := time.Now() quotaExpiredDuration := 5 * time.Minute @@ -648,7 +666,15 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { r.mutex.RLock() defer r.mutex.RUnlock() - registration, exists := r.models[modelID] + // Normalize model ID: replace underscores with dashes for consistent lookup + // This handles cases like "claude-sonnet-4_5-20250929" -> "claude-sonnet-4-5-20250929" + normalizedID := strings.ReplaceAll(modelID, "_", "-") + + registration, exists := r.models[normalizedID] + // Fall back to original ID if normalized version not found + if !exists { + registration, exists = r.models[modelID] + } if !exists || registration == nil || len(registration.Providers) == 0 { return nil } @@ -700,6 +726,12 @@ func (r *ModelRegistry) GetModelProviders(modelID string) []string { func (r *ModelRegistry) GetModelInfo(modelID string) *ModelInfo { r.mutex.RLock() defer r.mutex.RUnlock() + // Normalize model ID: replace underscores with dashes for consistent lookup + normalizedID := strings.ReplaceAll(modelID, "_", "-") + if reg, ok := r.models[normalizedID]; ok && reg != nil { + return reg.Info + } + // Fall back to original ID if normalized version not found if reg, ok := r.models[modelID]; ok && reg != nil { return reg.Info }