|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bufio" |
| 5 | + "crypto/rand" |
| 6 | + "encoding/hex" |
| 7 | + "flag" |
| 8 | + "fmt" |
| 9 | + "io" |
| 10 | + "os" |
| 11 | + "strings" |
| 12 | +) |
| 13 | + |
| 14 | +// runInit is the entry point for the `amp-proxy init` subcommand. It prompts |
| 15 | +// the operator for the handful of values that cannot be defaulted (custom |
| 16 | +// provider URL, custom provider Bearer token, optional ampcode.com upstream |
| 17 | +// key, Gemini route mode), generates a random local API key, and writes a |
| 18 | +// ready-to-run config.yaml to the requested path. Host/port, the full Amp |
| 19 | +// CLI model mapping table, and sensible defaults for everything else are |
| 20 | +// baked in so the produced file runs unmodified against a standard Amp CLI |
| 21 | +// setup. |
| 22 | +// |
| 23 | +// Anything that would need non-trivial customisation (multiple providers, |
| 24 | +// per-client upstream keys, access manager, body capture) is intentionally |
| 25 | +// left out of the generated file; operators who need those can hand-edit |
| 26 | +// afterwards or copy from config.example.yaml. |
| 27 | +func runInit(args []string) error { |
| 28 | + fs := flag.NewFlagSet("init", flag.ContinueOnError) |
| 29 | + configPath := fs.String("config", "config.yaml", "path to write the generated config file") |
| 30 | + force := fs.Bool("force", false, "overwrite the target file if it already exists") |
| 31 | + if err := fs.Parse(args); err != nil { |
| 32 | + return err |
| 33 | + } |
| 34 | + |
| 35 | + if _, err := os.Stat(*configPath); err == nil && !*force { |
| 36 | + return fmt.Errorf("refusing to overwrite existing %s — delete it, pass -force, or use -config <other-path>", *configPath) |
| 37 | + } |
| 38 | + |
| 39 | + fmt.Println("amp-proxy init — answer a few questions and a ready-to-run config will be written.") |
| 40 | + fmt.Println("Values are echoed to the terminal; clear your shell history if the API key is sensitive.") |
| 41 | + fmt.Println() |
| 42 | + |
| 43 | + reader := bufio.NewReader(os.Stdin) |
| 44 | + |
| 45 | + gatewayURL, err := promptRequired(reader, "Custom provider URL (OpenAI-compatible, e.g. http://host:port/v1)", "") |
| 46 | + if err != nil { |
| 47 | + return err |
| 48 | + } |
| 49 | + gatewayKey, err := promptRequired(reader, "Custom provider API key (Bearer token)", "") |
| 50 | + if err != nil { |
| 51 | + return err |
| 52 | + } |
| 53 | + geminiMode, err := promptChoice(reader, "Gemini route mode", []string{"translate", "ampcode"}, "translate") |
| 54 | + if err != nil { |
| 55 | + return err |
| 56 | + } |
| 57 | + ampUpstream, err := promptOptional(reader, "Amp upstream API key (for ampcode.com fallback, press Enter to skip)", "") |
| 58 | + if err != nil { |
| 59 | + return err |
| 60 | + } |
| 61 | + |
| 62 | + localKey, err := generateLocalAPIKey() |
| 63 | + if err != nil { |
| 64 | + return fmt.Errorf("generate local API key: %w", err) |
| 65 | + } |
| 66 | + |
| 67 | + content := renderInitConfig(gatewayURL, gatewayKey, ampUpstream, geminiMode, localKey) |
| 68 | + if err := os.WriteFile(*configPath, []byte(content), 0o600); err != nil { |
| 69 | + return fmt.Errorf("write %s: %w", *configPath, err) |
| 70 | + } |
| 71 | + |
| 72 | + fmt.Println() |
| 73 | + fmt.Printf("Wrote %s (mode 600).\n", *configPath) |
| 74 | + fmt.Println() |
| 75 | + fmt.Println("Start amp-proxy:") |
| 76 | + fmt.Printf(" ./amp-proxy --config %s\n", *configPath) |
| 77 | + fmt.Println() |
| 78 | + fmt.Println("Point Amp CLI at it:") |
| 79 | + fmt.Println(" export AMP_URL=http://127.0.0.1:8317") |
| 80 | + fmt.Printf(" export AMP_API_KEY=%s\n", localKey) |
| 81 | + fmt.Println(" amp") |
| 82 | + return nil |
| 83 | +} |
| 84 | + |
| 85 | +// promptRequired reads a non-empty line from the user. Empty responses |
| 86 | +// trigger a re-prompt; EOF returns the default value if set, otherwise an |
| 87 | +// error. |
| 88 | +func promptRequired(r *bufio.Reader, label, defaultVal string) (string, error) { |
| 89 | + for { |
| 90 | + if defaultVal != "" { |
| 91 | + fmt.Printf("%s [%s]: ", label, defaultVal) |
| 92 | + } else { |
| 93 | + fmt.Printf("%s: ", label) |
| 94 | + } |
| 95 | + line, err := r.ReadString('\n') |
| 96 | + if err != nil { |
| 97 | + if err == io.EOF && defaultVal != "" { |
| 98 | + return defaultVal, nil |
| 99 | + } |
| 100 | + return "", fmt.Errorf("read %q: %w", label, err) |
| 101 | + } |
| 102 | + trimmed := strings.TrimSpace(line) |
| 103 | + if trimmed == "" { |
| 104 | + if defaultVal != "" { |
| 105 | + return defaultVal, nil |
| 106 | + } |
| 107 | + fmt.Println(" value required, please try again") |
| 108 | + continue |
| 109 | + } |
| 110 | + return trimmed, nil |
| 111 | + } |
| 112 | +} |
| 113 | + |
| 114 | +// promptOptional reads a line from the user, returning the default (or |
| 115 | +// empty string) on an empty response. Used for fields the operator may |
| 116 | +// legitimately want to leave blank, such as the ampcode.com upstream key. |
| 117 | +func promptOptional(r *bufio.Reader, label, defaultVal string) (string, error) { |
| 118 | + if defaultVal != "" { |
| 119 | + fmt.Printf("%s [%s]: ", label, defaultVal) |
| 120 | + } else { |
| 121 | + fmt.Printf("%s: ", label) |
| 122 | + } |
| 123 | + line, err := r.ReadString('\n') |
| 124 | + if err != nil { |
| 125 | + if err == io.EOF { |
| 126 | + return defaultVal, nil |
| 127 | + } |
| 128 | + return "", fmt.Errorf("read %q: %w", label, err) |
| 129 | + } |
| 130 | + trimmed := strings.TrimSpace(line) |
| 131 | + if trimmed == "" { |
| 132 | + return defaultVal, nil |
| 133 | + } |
| 134 | + return trimmed, nil |
| 135 | +} |
| 136 | + |
| 137 | +// promptChoice reads a line from the user and restricts the response to |
| 138 | +// one of the provided choices. A case-insensitive empty response returns |
| 139 | +// the default. |
| 140 | +func promptChoice(r *bufio.Reader, label string, choices []string, defaultVal string) (string, error) { |
| 141 | + lowered := make([]string, len(choices)) |
| 142 | + for i, c := range choices { |
| 143 | + lowered[i] = strings.ToLower(c) |
| 144 | + } |
| 145 | + for { |
| 146 | + fmt.Printf("%s (%s) [%s]: ", label, strings.Join(choices, "/"), defaultVal) |
| 147 | + line, err := r.ReadString('\n') |
| 148 | + if err != nil { |
| 149 | + if err == io.EOF { |
| 150 | + return defaultVal, nil |
| 151 | + } |
| 152 | + return "", fmt.Errorf("read %q: %w", label, err) |
| 153 | + } |
| 154 | + trimmed := strings.TrimSpace(strings.ToLower(line)) |
| 155 | + if trimmed == "" { |
| 156 | + return defaultVal, nil |
| 157 | + } |
| 158 | + for _, c := range lowered { |
| 159 | + if trimmed == c { |
| 160 | + return trimmed, nil |
| 161 | + } |
| 162 | + } |
| 163 | + fmt.Printf(" invalid choice %q, must be one of %v\n", trimmed, choices) |
| 164 | + } |
| 165 | +} |
| 166 | + |
| 167 | +// generateLocalAPIKey returns a URL-safe hex token that amp-proxy will |
| 168 | +// require on incoming Amp CLI requests. 16 random bytes (32 hex chars) is |
| 169 | +// large enough that a local attacker cannot brute-force it in any |
| 170 | +// meaningful time on a loopback-bound server. |
| 171 | +func generateLocalAPIKey() (string, error) { |
| 172 | + b := make([]byte, 16) |
| 173 | + if _, err := rand.Read(b); err != nil { |
| 174 | + return "", err |
| 175 | + } |
| 176 | + return "amp-" + hex.EncodeToString(b), nil |
| 177 | +} |
| 178 | + |
| 179 | +// renderInitConfig produces a complete config.yaml body from the prompted |
| 180 | +// values. Layout and comments mirror config.example.yaml so operators who |
| 181 | +// later want to cross-reference the example can find their way around. |
| 182 | +func renderInitConfig(gatewayURL, gatewayKey, ampUpstream, geminiMode, localKey string) string { |
| 183 | + var b strings.Builder |
| 184 | + b.WriteString("# Generated by `amp-proxy init`.\n") |
| 185 | + b.WriteString("# Edit freely — amp-proxy hot-reloads most fields without restart.\n") |
| 186 | + b.WriteString("\n") |
| 187 | + b.WriteString("host: \"127.0.0.1\"\n") |
| 188 | + b.WriteString("port: 8317\n") |
| 189 | + b.WriteString("\n") |
| 190 | + b.WriteString("# Local API keys Amp CLI must present (match AMP_API_KEY in your shell).\n") |
| 191 | + b.WriteString("api-keys:\n") |
| 192 | + fmt.Fprintf(&b, " - %q\n", localKey) |
| 193 | + b.WriteString("\n") |
| 194 | + b.WriteString("ampcode:\n") |
| 195 | + b.WriteString(" upstream-url: \"https://ampcode.com\"\n") |
| 196 | + fmt.Fprintf(&b, " upstream-api-key: %q\n", ampUpstream) |
| 197 | + b.WriteString(" restrict-management-to-localhost: true\n") |
| 198 | + b.WriteString("\n") |
| 199 | + b.WriteString(" # Rewrite Amp CLI model names onto the gpt-5.4 family served by\n") |
| 200 | + b.WriteString(" # custom-providers below. Adjust the right-hand side if your gateway\n") |
| 201 | + b.WriteString(" # exposes different model names.\n") |
| 202 | + b.WriteString(" model-mappings:\n") |
| 203 | + mappings := [][2]string{ |
| 204 | + {"claude-opus-4-6", "gpt-5.4(high)"}, |
| 205 | + {"claude-sonnet-4-6-thinking", "gpt-5.4-mini(high)"}, |
| 206 | + {"claude-haiku-4-5-20251001", "gpt-5.4-mini"}, |
| 207 | + {"gpt-5.4", "gpt-5.4(xhigh)"}, |
| 208 | + {"gemini-2.5-flash-lite-preview-09-2025", "gpt-5.4-mini"}, |
| 209 | + {"gemini-2.5-flash-lite", "gpt-5.4-mini"}, |
| 210 | + {"claude-sonnet-4-6", "gpt-5.4-mini(high)"}, |
| 211 | + {"gpt-5.3-codex", "gpt-5.4(high)"}, |
| 212 | + {"gemini-3-flash-preview", "gpt-5.4-mini(high)"}, |
| 213 | + } |
| 214 | + for _, m := range mappings { |
| 215 | + fmt.Fprintf(&b, " - from: %q\n", m[0]) |
| 216 | + fmt.Fprintf(&b, " to: %q\n", m[1]) |
| 217 | + } |
| 218 | + b.WriteString("\n") |
| 219 | + b.WriteString(" force-model-mappings: true\n") |
| 220 | + b.WriteString("\n") |
| 221 | + b.WriteString(" custom-providers:\n") |
| 222 | + b.WriteString(" - name: \"gateway\"\n") |
| 223 | + fmt.Fprintf(&b, " url: %q\n", gatewayURL) |
| 224 | + fmt.Fprintf(&b, " api-key: %q\n", gatewayKey) |
| 225 | + b.WriteString(" models:\n") |
| 226 | + b.WriteString(" - \"gpt-5.4\"\n") |
| 227 | + b.WriteString(" - \"gpt-5.4-mini\"\n") |
| 228 | + b.WriteString("\n") |
| 229 | + fmt.Fprintf(&b, " gemini-route-mode: %q\n", geminiMode) |
| 230 | + return b.String() |
| 231 | +} |
0 commit comments