Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions backend/internal/api/settings_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,14 @@ func (h *SettingsHandler) saveAppSettings(w http.ResponseWriter, r *http.Request
return
}
payload = normalizeAppSettings(payload)
if err := validateLocalProxyHost(payload.ProxyHost); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if err := validateLANShareIPWhitelist(payload.LANShareIPWhitelist); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if proxyEndpointChanged(current, payload) {
if err := h.updateEnabledProxyConfig(payload); err != nil {
http.Error(w, err.Error(), http.StatusConflict)
Expand All @@ -341,6 +349,43 @@ func (h *SettingsHandler) saveAppSettings(w http.ResponseWriter, r *http.Request
writeJSON(w, http.StatusOK, h.appSettings())
}

func validateLocalProxyHost(host string) error {
normalized := strings.TrimSpace(host)
if normalized == "" {
return fmt.Errorf("proxy_host is required")
}
if strings.EqualFold(normalized, "localhost") {
return nil
}
ip := net.ParseIP(normalized)
if ip != nil && ip.IsLoopback() {
return nil
}
return fmt.Errorf("proxy_host %q is not local-only, use 127.0.0.1 / localhost / ::1", host)
}

func validateLANShareIPWhitelist(raw string) error {
for _, entry := range strings.FieldsFunc(raw, func(r rune) bool {
return r == '\n' || r == '\r' || r == ',' || r == ';'
}) {
normalized := strings.TrimSpace(entry)
if normalized == "" {
continue
}
if strings.EqualFold(normalized, "localhost") {
continue
}
if ip := net.ParseIP(normalized); ip != nil {
continue
}
if _, _, err := net.ParseCIDR(normalized); err == nil {
continue
}
return fmt.Errorf("lan_share_ip_whitelist entry %q is invalid, use IP or CIDR", normalized)
}
return nil
}

func (h *SettingsHandler) getFailoverQueue(w http.ResponseWriter) {
if h.settings == nil {
writeJSON(w, http.StatusOK, []int64{})
Expand Down Expand Up @@ -600,6 +645,7 @@ func normalizeAppSettings(value settings.AppSettings) settings.AppSettings {
if value.ProxyPort <= 0 {
value.ProxyPort = defaults.ProxyPort
}
value.LANShareIPWhitelist = strings.TrimSpace(value.LANShareIPWhitelist)
switch value.UpstreamProxyMode {
case settings.UpstreamProxyModeSystem, settings.UpstreamProxyModeDirect, settings.UpstreamProxyModeManual:
default:
Expand Down
90 changes: 89 additions & 1 deletion backend/internal/api/settings_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ func TestSettingsHandlerGetAndPutAppSettings(t *testing.T) {
"usage_request_timeout_seconds": 18,
"proxy_host": "localhost",
"proxy_port": 15721,
"lan_share_enabled": true,
"lan_share_ip_whitelist": "192.168.1.10\n192.168.1.0/24",
"auto_failover_enabled": true,
"auto_backup_interval_hours": 12,
"backup_retention_count": 7,
Expand All @@ -66,11 +68,51 @@ func TestSettingsHandlerGetAndPutAppSettings(t *testing.T) {
if err != nil {
t.Fatalf("GetAppSettings returned error: %v", err)
}
if !stored.LaunchAtLogin || !stored.SilentStart || stored.CloseToTray || stored.ShowProxySwitchOnHome || stored.ShowHomeUpdateIndicator || stored.UsageRequestTimeoutSeconds != 18 || stored.ProxyHost != "localhost" || stored.ProxyPort != 15721 || !stored.AutoFailoverEnabled || stored.AutoBackupIntervalHours != 12 || stored.BackupRetentionCount != 7 || stored.Language != "en-US" || stored.ThemeMode != "dark" {
if !stored.LaunchAtLogin || !stored.SilentStart || stored.CloseToTray || stored.ShowProxySwitchOnHome || stored.ShowHomeUpdateIndicator || stored.UsageRequestTimeoutSeconds != 18 || stored.ProxyHost != "localhost" || stored.ProxyPort != 15721 || !stored.LANShareEnabled || stored.LANShareIPWhitelist != "192.168.1.10\n192.168.1.0/24" || !stored.AutoFailoverEnabled || stored.AutoBackupIntervalHours != 12 || stored.BackupRetentionCount != 7 || stored.Language != "en-US" || stored.ThemeMode != "dark" {
t.Fatalf("stored settings = %+v, want updated values", stored)
}
}

func TestSettingsHandlerRejectsInvalidLANShareWhitelist(t *testing.T) {
t.Setenv("HOME", t.TempDir())

handler, _ := newSettingsHandler(t)

body := strings.NewReader(`{
"launch_at_login": false,
"silent_start": false,
"close_to_tray": true,
"show_proxy_switch_on_home": true,
"show_home_update_indicator": true,
"status_refresh_interval_seconds": 60,
"usage_request_timeout_seconds": 15,
"proxy_host": "127.0.0.1",
"proxy_port": 6789,
"lan_share_enabled": true,
"lan_share_ip_whitelist": "bad-entry",
"upstream_proxy_mode": "system",
"upstream_proxy_url": "",
"upstream_proxy_username": "",
"upstream_proxy_password": "",
"auto_failover_enabled": true,
"auto_backup_interval_hours": 24,
"backup_retention_count": 10,
"language": "zh-CN",
"theme_mode": "system"
}`)
req := httptest.NewRequest(http.MethodPut, "/settings/app", body)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)

if rec.Code != http.StatusBadRequest {
t.Fatalf("PUT /settings/app status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "lan_share_ip_whitelist") {
t.Fatalf("response body = %q, want lan_share_ip_whitelist validation error", rec.Body.String())
}
}

func TestSettingsHandlerGetAndPutFailoverQueue(t *testing.T) {
handler, repo := newSettingsHandler(t)

Expand Down Expand Up @@ -609,6 +651,52 @@ func TestSettingsHandlerUpdatingProxyAddressRewritesEnabledConfig(t *testing.T)
assertFileContains(t, filepath.Join(home, ".codex", "config.toml"), `base_url = "http://localhost:15721/ai-router/api"`)
}

func TestSettingsHandlerRejectsNonLocalProxyHost(t *testing.T) {
t.Setenv("HOME", t.TempDir())

handler, repo := newSettingsHandler(t)

body := strings.NewReader(`{
"launch_at_login": false,
"silent_start": false,
"close_to_tray": true,
"show_proxy_switch_on_home": true,
"show_home_update_indicator": true,
"status_refresh_interval_seconds": 60,
"usage_request_timeout_seconds": 15,
"proxy_host": "192.168.1.24",
"proxy_port": 6789,
"upstream_proxy_mode": "system",
"upstream_proxy_url": "",
"upstream_proxy_username": "",
"upstream_proxy_password": "",
"auto_failover_enabled": true,
"auto_backup_interval_hours": 24,
"backup_retention_count": 10,
"language": "zh-CN",
"theme_mode": "system"
}`)
req := httptest.NewRequest(http.MethodPut, "/settings/app", body)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)

if rec.Code != http.StatusBadRequest {
t.Fatalf("PUT /settings/app status = %d, want %d; body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if !strings.Contains(rec.Body.String(), "proxy_host") {
t.Fatalf("response body = %q, want proxy_host validation error", rec.Body.String())
}

stored, err := repo.GetAppSettings()
if err != nil {
t.Fatalf("GetAppSettings returned error: %v", err)
}
if stored.ProxyHost != settings.DefaultAppSettings().ProxyHost {
t.Fatalf("stored proxy_host = %q, want %q", stored.ProxyHost, settings.DefaultAppSettings().ProxyHost)
}
}

func TestSettingsHandlerProxyDisableDetachesEvenWhenConfigChanged(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
Expand Down
83 changes: 82 additions & 1 deletion backend/internal/bootstrap/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"errors"
"fmt"
"log"
"net"
"net/http"
"path/filepath"
"runtime/debug"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -184,7 +186,7 @@ func NewApp(_ context.Context, cfg Config) (*App, error) {
}
http.NotFound(w, r)
})
mux.Handle("/ai-router/api/", withCORS(http.StripPrefix("/ai-router/api", apiMux)))
mux.Handle("/ai-router/api/", withCORS(withLANShareAccessControl(settingsRepo, http.StripPrefix("/ai-router/api", apiMux))))

appCtx, cancel := context.WithCancel(context.Background())
app := &App{listenAddr: cfg.ListenAddr, handler: mux, store: store, cancel: cancel}
Expand Down Expand Up @@ -315,6 +317,85 @@ func withCORS(next http.Handler) http.Handler {
})
}

func withLANShareAccessControl(repo settings.ReadRepository, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if repo == nil {
next.ServeHTTP(w, r)
return
}

appSettings, err := repo.GetAppSettings()
if err != nil || !appSettings.LANShareEnabled {
next.ServeHTTP(w, r)
return
}

remoteIP, err := remoteIPFromAddr(r.RemoteAddr)
if err != nil {
http.Error(w, "invalid remote address", http.StatusForbidden)
return
}
if remoteIP.IsLoopback() {
next.ServeHTTP(w, r)
return
}
allowed, err := ipAllowedByWhitelist(remoteIP, appSettings.LANShareIPWhitelist)
if err != nil {
log.Printf("lan share whitelist parse failed: %v", err)
http.Error(w, "lan share whitelist is invalid", http.StatusForbidden)
return
}
if !allowed {
http.Error(w, "remote address is not allowed by lan share whitelist", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

func remoteIPFromAddr(addr string) (net.IP, error) {
host, _, err := net.SplitHostPort(strings.TrimSpace(addr))
if err != nil {
return nil, err
}
ip := net.ParseIP(host)
if ip == nil {
return nil, fmt.Errorf("parse remote ip: %q", host)
}
return ip, nil
}

func ipAllowedByWhitelist(ip net.IP, raw string) (bool, error) {
if strings.TrimSpace(raw) == "" {
return true, nil
}
for _, entry := range strings.FieldsFunc(raw, func(r rune) bool {
return r == '\n' || r == '\r' || r == ',' || r == ';'
}) {
normalized := strings.TrimSpace(entry)
if normalized == "" {
continue
}
if strings.EqualFold(normalized, "localhost") {
continue
}
if allowedIP := net.ParseIP(normalized); allowedIP != nil {
if allowedIP.Equal(ip) {
return true, nil
}
continue
}
_, network, err := net.ParseCIDR(normalized)
if err != nil {
return false, err
}
if network.Contains(ip) {
return true, nil
}
}
return false, nil
}

func (a *App) ListenAddr() string {
return a.listenAddr
}
Expand Down
107 changes: 107 additions & 0 deletions backend/internal/bootstrap/lan_share_access_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package bootstrap

import (
"net/http"
"net/http/httptest"
"path/filepath"
"testing"

"github.com/gcssloop/codex-router/backend/internal/settings"
sqlitestore "github.com/gcssloop/codex-router/backend/internal/store/sqlite"
)

func TestLANShareAccessControlAllowsLoopbackEvenWithWhitelist(t *testing.T) {
t.Parallel()

store, err := sqlitestore.Open(filepath.Join(t.TempDir(), "router.sqlite"))
if err != nil {
t.Fatalf("Open returned error: %v", err)
}
t.Cleanup(func() {
_ = store.Close()
})
repo := settings.NewSQLiteRepository(store.DB())
appSettings := settings.DefaultAppSettings()
appSettings.LANShareEnabled = true
appSettings.LANShareIPWhitelist = "192.168.1.10"
if err := repo.SaveAppSettings(appSettings); err != nil {
t.Fatalf("SaveAppSettings returned error: %v", err)
}

handler := withLANShareAccessControl(repo, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
req := httptest.NewRequest(http.MethodGet, "/ai-router/api/settings/app", nil)
req.RemoteAddr = "127.0.0.1:54321"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
}

func TestLANShareAccessControlBlocksNonWhitelistedRemoteAddr(t *testing.T) {
t.Parallel()

store, err := sqlitestore.Open(filepath.Join(t.TempDir(), "router.sqlite"))
if err != nil {
t.Fatalf("Open returned error: %v", err)
}
t.Cleanup(func() {
_ = store.Close()
})
repo := settings.NewSQLiteRepository(store.DB())
appSettings := settings.DefaultAppSettings()
appSettings.LANShareEnabled = true
appSettings.LANShareIPWhitelist = "192.168.1.10"
if err := repo.SaveAppSettings(appSettings); err != nil {
t.Fatalf("SaveAppSettings returned error: %v", err)
}

handler := withLANShareAccessControl(repo, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
req := httptest.NewRequest(http.MethodGet, "/ai-router/api/settings/app", nil)
req.RemoteAddr = "192.168.1.11:54321"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusForbidden {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
}
}

func TestLANShareAccessControlAllowsAllWhenWhitelistEmpty(t *testing.T) {
t.Parallel()

store, err := sqlitestore.Open(filepath.Join(t.TempDir(), "router.sqlite"))
if err != nil {
t.Fatalf("Open returned error: %v", err)
}
t.Cleanup(func() {
_ = store.Close()
})
repo := settings.NewSQLiteRepository(store.DB())
appSettings := settings.DefaultAppSettings()
appSettings.LANShareEnabled = true
appSettings.LANShareIPWhitelist = ""
if err := repo.SaveAppSettings(appSettings); err != nil {
t.Fatalf("SaveAppSettings returned error: %v", err)
}

handler := withLANShareAccessControl(repo, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
req := httptest.NewRequest(http.MethodGet, "/ai-router/api/settings/app", nil)
req.RemoteAddr = "192.168.1.11:54321"
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)

if rec.Code != http.StatusNoContent {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
}
}
Loading
Loading