diff --git a/backend/backend.go b/backend/backend.go index 1d0f834..7fcd7bd 100644 --- a/backend/backend.go +++ b/backend/backend.go @@ -25,3 +25,11 @@ type Backend interface { type ConfigKey struct{} type UsersKey struct{} + +// LimitEnforcer interface for traffic limit enforcement +type LimitEnforcer interface { + Start(ctx context.Context, refreshInterval interface{}) + Stop() + ResetUserTraffic(userID int) + ResetAllTraffic() +} diff --git a/common/service.pb.go b/common/service.pb.go index ec396e1..777087f 100644 --- a/common/service.pb.go +++ b/common/service.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.10 -// protoc v3.21.12 +// protoc-gen-go v1.36.11 +// protoc v6.33.4 // source: common/service.proto package common @@ -226,8 +226,13 @@ type Backend struct { Users []*User `protobuf:"bytes,3,rep,name=users,proto3" json:"users,omitempty"` KeepAlive uint64 `protobuf:"varint,4,opt,name=keep_alive,json=keepAlive,proto3" json:"keep_alive,omitempty"` ExcludeInbounds []string `protobuf:"bytes,5,rep,name=exclude_inbounds,json=excludeInbounds,proto3" json:"exclude_inbounds,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Limit enforcer configuration (sent from panel) + NodeId int32 `protobuf:"varint,6,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + PanelApiUrl string `protobuf:"bytes,7,opt,name=panel_api_url,json=panelApiUrl,proto3" json:"panel_api_url,omitempty"` + LimitCheckInterval int32 `protobuf:"varint,8,opt,name=limit_check_interval,json=limitCheckInterval,proto3" json:"limit_check_interval,omitempty"` // seconds, default 30 + LimitRefreshInterval int32 `protobuf:"varint,9,opt,name=limit_refresh_interval,json=limitRefreshInterval,proto3" json:"limit_refresh_interval,omitempty"` // seconds, default 60 + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Backend) Reset() { @@ -295,6 +300,34 @@ func (x *Backend) GetExcludeInbounds() []string { return nil } +func (x *Backend) GetNodeId() int32 { + if x != nil { + return x.NodeId + } + return 0 +} + +func (x *Backend) GetPanelApiUrl() string { + if x != nil { + return x.PanelApiUrl + } + return "" +} + +func (x *Backend) GetLimitCheckInterval() int32 { + if x != nil { + return x.LimitCheckInterval + } + return 0 +} + +func (x *Backend) GetLimitRefreshInterval() int32 { + if x != nil { + return x.LimitRefreshInterval + } + return 0 +} + // log type Log struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1251,14 +1284,18 @@ const file_common_service_proto_rawDesc = "" + "\x10BaseInfoResponse\x12\x18\n" + "\astarted\x18\x01 \x01(\bR\astarted\x12!\n" + "\fcore_version\x18\x02 \x01(\tR\vcoreVersion\x12!\n" + - "\fnode_version\x18\x03 \x01(\tR\vnodeVersion\"\xba\x01\n" + + "\fnode_version\x18\x03 \x01(\tR\vnodeVersion\"\xdf\x02\n" + "\aBackend\x12(\n" + "\x04type\x18\x01 \x01(\x0e2\x14.service.BackendTypeR\x04type\x12\x16\n" + "\x06config\x18\x02 \x01(\tR\x06config\x12#\n" + "\x05users\x18\x03 \x03(\v2\r.service.UserR\x05users\x12\x1d\n" + "\n" + "keep_alive\x18\x04 \x01(\x04R\tkeepAlive\x12)\n" + - "\x10exclude_inbounds\x18\x05 \x03(\tR\x0fexcludeInbounds\"\x1d\n" + + "\x10exclude_inbounds\x18\x05 \x03(\tR\x0fexcludeInbounds\x12\x17\n" + + "\anode_id\x18\x06 \x01(\x05R\x06nodeId\x12\"\n" + + "\rpanel_api_url\x18\a \x01(\tR\vpanelApiUrl\x120\n" + + "\x14limit_check_interval\x18\b \x01(\x05R\x12limitCheckInterval\x124\n" + + "\x16limit_refresh_interval\x18\t 
\x01(\x05R\x14limitRefreshInterval\"\x1d\n" + "\x03Log\x12\x16\n" + "\x06detail\x18\x01 \x01(\tR\x06detail\"X\n" + "\x04Stat\x12\x12\n" + diff --git a/common/service.proto b/common/service.proto index 05ef71a..140b091 100644 --- a/common/service.proto +++ b/common/service.proto @@ -23,6 +23,12 @@ message Backend { repeated User users = 3; uint64 keep_alive = 4; repeated string exclude_inbounds = 5; + + // Limit enforcer configuration (sent from panel) + int32 node_id = 6; + string panel_api_url = 7; + int32 limit_check_interval = 8; // seconds, default 30 + int32 limit_refresh_interval = 9; // seconds, default 60 } // log diff --git a/common/service_grpc.pb.go b/common/service_grpc.pb.go index e4525e5..cdc2973 100644 --- a/common/service_grpc.pb.go +++ b/common/service_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v3.21.12 +// - protoc-gen-go-grpc v1.6.0 +// - protoc v6.33.4 // source: common/service.proto package common @@ -225,40 +225,40 @@ type NodeServiceServer interface { type UnimplementedNodeServiceServer struct{} func (UnimplementedNodeServiceServer) Start(context.Context, *Backend) (*BaseInfoResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Start not implemented") + return nil, status.Error(codes.Unimplemented, "method Start not implemented") } func (UnimplementedNodeServiceServer) Stop(context.Context, *Empty) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method Stop not implemented") + return nil, status.Error(codes.Unimplemented, "method Stop not implemented") } func (UnimplementedNodeServiceServer) GetBaseInfo(context.Context, *Empty) (*BaseInfoResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetBaseInfo not implemented") + return nil, status.Error(codes.Unimplemented, "method GetBaseInfo not implemented") } func (UnimplementedNodeServiceServer) GetLogs(*Empty, grpc.ServerStreamingServer[Log]) error { - return status.Errorf(codes.Unimplemented, "method GetLogs not implemented") + return status.Error(codes.Unimplemented, "method GetLogs not implemented") } func (UnimplementedNodeServiceServer) GetSystemStats(context.Context, *Empty) (*SystemStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetSystemStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetSystemStats not implemented") } func (UnimplementedNodeServiceServer) GetBackendStats(context.Context, *Empty) (*BackendStatsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetBackendStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetBackendStats not implemented") } func (UnimplementedNodeServiceServer) GetStats(context.Context, *StatRequest) (*StatResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetStats not implemented") } func (UnimplementedNodeServiceServer) GetUserOnlineStats(context.Context, *StatRequest) (*OnlineStatResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetUserOnlineStats not implemented") + return nil, status.Error(codes.Unimplemented, "method GetUserOnlineStats not implemented") } func (UnimplementedNodeServiceServer) GetUserOnlineIpListStats(context.Context, *StatRequest) (*StatsOnlineIpListResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetUserOnlineIpListStats not 
implemented") + return nil, status.Error(codes.Unimplemented, "method GetUserOnlineIpListStats not implemented") } func (UnimplementedNodeServiceServer) SyncUser(grpc.ClientStreamingServer[User, Empty]) error { - return status.Errorf(codes.Unimplemented, "method SyncUser not implemented") + return status.Error(codes.Unimplemented, "method SyncUser not implemented") } func (UnimplementedNodeServiceServer) SyncUsers(context.Context, *Users) (*Empty, error) { - return nil, status.Errorf(codes.Unimplemented, "method SyncUsers not implemented") + return nil, status.Error(codes.Unimplemented, "method SyncUsers not implemented") } func (UnimplementedNodeServiceServer) SyncUsersChunked(grpc.ClientStreamingServer[UsersChunk, Empty]) error { - return status.Errorf(codes.Unimplemented, "method SyncUsersChunked not implemented") + return status.Error(codes.Unimplemented, "method SyncUsersChunked not implemented") } func (UnimplementedNodeServiceServer) mustEmbedUnimplementedNodeServiceServer() {} func (UnimplementedNodeServiceServer) testEmbeddedByValue() {} @@ -271,7 +271,7 @@ type UnsafeNodeServiceServer interface { } func RegisterNodeServiceServer(s grpc.ServiceRegistrar, srv NodeServiceServer) { - // If the following call pancis, it indicates UnimplementedNodeServiceServer was + // If the following call panics, it indicates UnimplementedNodeServiceServer was // embedded by pointer and is nil. This will cause panics if an // unimplemented method is ever invoked, so we test this at initialization // time to prevent it from happening at runtime later due to I/O. diff --git a/common/traffic_limits.go b/common/traffic_limits.go new file mode 100644 index 0000000..433de0f --- /dev/null +++ b/common/traffic_limits.go @@ -0,0 +1,390 @@ +package common + +import ( + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "math" + "math/rand" + "net/http" + "sync" + "sync/atomic" + "time" +) + +// NodeUserLimit represents a per-user per-node traffic limit +type NodeUserLimit struct { + ID int `json:"id"` + UserID int `json:"user_id"` + NodeID int `json:"node_id"` + DataLimit int64 `json:"data_limit"` // in bytes, 0 = unlimited +} + +// NodeUserLimitsResponse from panel API +type NodeUserLimitsResponse struct { + Limits []NodeUserLimit `json:"limits"` + Total int `json:"total"` +} + +// userLimitsData holds the limits map for atomic swap +type userLimitsData struct { + limits map[int]int64 +} + +// TrafficLimitsCache caches per-user per-node limits with optimized performance +type TrafficLimitsCache struct { + nodeID int + panelAPIURL string + apiKey string + httpClient *http.Client + stopChan chan struct{} + logger Logger + + // Atomic pointer for lock-free reads + limitsPtr atomic.Pointer[userLimitsData] + + // Mutex only for refresh operations (writes) + refreshMu sync.Mutex + + // ETag for conditional requests + lastETag string + lastUpdate time.Time + lastError error + + // Exponential backoff state + backoff struct { + mu sync.Mutex + currentInterval time.Duration + baseInterval time.Duration + maxInterval time.Duration + lastErrorTime time.Time + } +} + +// Logger interface for flexible logging +type Logger interface { + Printf(format string, v ...interface{}) + Errorf(format string, v ...interface{}) +} + +// defaultLogger uses fmt.Printf if no logger provided +type defaultLogger struct{} + +func (d *defaultLogger) Printf(format string, v ...interface{}) { + fmt.Printf(format, v...) +} + +func (d *defaultLogger) Errorf(format string, v ...interface{}) { + fmt.Printf("ERROR: "+format, v...) 
+} + +// NewTrafficLimitsCache creates a new traffic limits cache with optimizations +func NewTrafficLimitsCache(nodeID int, panelAPIURL string, apiKey string) *TrafficLimitsCache { + tlc := &TrafficLimitsCache{ + nodeID: nodeID, + panelAPIURL: panelAPIURL, + apiKey: apiKey, + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + DisableCompression: false, // Enable built-in compression + DisableKeepAlives: false, + MaxIdleConnsPerHost: 5, + }, + }, + stopChan: make(chan struct{}), + logger: &defaultLogger{}, + } + + // Initialize with empty limits + tlc.limitsPtr.Store(&userLimitsData{limits: make(map[int]int64)}) + + // Initialize backoff settings + tlc.backoff.baseInterval = 1 * time.Second + tlc.backoff.currentInterval = tlc.backoff.baseInterval + tlc.backoff.maxInterval = 60 * time.Second + + return tlc +} + +// SetLogger sets custom logger +func (tlc *TrafficLimitsCache) SetLogger(logger Logger) { + tlc.logger = logger +} + +// Refresh fetches latest limits from panel API +func (tlc *TrafficLimitsCache) Refresh() error { + return tlc.RefreshWithContext(context.Background()) +} + +// RefreshWithContext fetches latest limits from panel API with context +// Supports ETag for conditional requests and gzip compression +func (tlc *TrafficLimitsCache) RefreshWithContext(ctx context.Context) error { + tlc.refreshMu.Lock() + defer tlc.refreshMu.Unlock() + + url := fmt.Sprintf("%s/api/node-user-limits/node/%d", tlc.panelAPIURL, tlc.nodeID) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + tlc.recordError(fmt.Errorf("failed to create request: %w", err)) + return tlc.lastError + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tlc.apiKey)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept-Encoding", "gzip") // Request gzip compression + + // Send ETag for conditional request + if tlc.lastETag != "" { + req.Header.Set("If-None-Match", tlc.lastETag) + } + + resp, err := tlc.httpClient.Do(req) + if err != nil { + tlc.recordError(fmt.Errorf("failed to fetch limits: %w", err)) + return tlc.lastError + } + defer resp.Body.Close() + + // Handle 304 Not Modified - data hasn't changed + if resp.StatusCode == http.StatusNotModified { + tlc.lastUpdate = time.Now() + tlc.lastError = nil + tlc.resetBackoff() + tlc.logger.Printf("Traffic limits unchanged (304), using cached data\n") + return nil + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + tlc.recordError(fmt.Errorf("panel API returned status %d: %s", resp.StatusCode, string(body))) + return tlc.lastError + } + + // Handle gzip response + var reader io.Reader = resp.Body + if resp.Header.Get("Content-Encoding") == "gzip" { + gzipReader, err := gzip.NewReader(resp.Body) + if err != nil { + tlc.recordError(fmt.Errorf("failed to create gzip reader: %w", err)) + return tlc.lastError + } + defer gzipReader.Close() + reader = gzipReader + } + + var response NodeUserLimitsResponse + if err := json.NewDecoder(reader).Decode(&response); err != nil { + tlc.recordError(fmt.Errorf("failed to decode response: %w", err)) + return tlc.lastError + } + + // Build new limits map + newLimits := make(map[int]int64, len(response.Limits)) + for _, limit := range response.Limits { + newLimits[limit.UserID] = limit.DataLimit + } + + // Atomic swap - lock-free for readers + tlc.limitsPtr.Store(&userLimitsData{limits: newLimits}) + + // Update metadata + tlc.lastUpdate = time.Now() + 
tlc.lastError = nil + + // Save ETag for next request + if etag := resp.Header.Get("ETag"); etag != "" { + tlc.lastETag = etag + } + + // Reset backoff on success + tlc.resetBackoff() + + return nil +} + +// GetLimit returns the traffic limit for a user, returns 0 if no limit set +// This is now lock-free for maximum performance +func (tlc *TrafficLimitsCache) GetLimit(userID int) int64 { + data := tlc.limitsPtr.Load() + if data == nil || data.limits == nil { + return 0 + } + + limit, exists := data.limits[userID] + if !exists { + return 0 // No limit configured + } + return limit +} + +// HasLimit checks if a user has a limit configured +// This is now lock-free for maximum performance +func (tlc *TrafficLimitsCache) HasLimit(userID int) bool { + data := tlc.limitsPtr.Load() + if data == nil || data.limits == nil { + return false + } + + _, exists := data.limits[userID] + return exists +} + +// GetStats returns cache statistics +func (tlc *TrafficLimitsCache) GetStats() (total int, lastUpdate time.Time, lastError error) { + tlc.refreshMu.Lock() + defer tlc.refreshMu.Unlock() + + data := tlc.limitsPtr.Load() + count := 0 + if data != nil && data.limits != nil { + count = len(data.limits) + } + + return count, tlc.lastUpdate, tlc.lastError +} + +// Count returns the number of cached limits (lock-free) +func (tlc *TrafficLimitsCache) Count() int { + data := tlc.limitsPtr.Load() + if data == nil || data.limits == nil { + return 0 + } + return len(data.limits) +} + +// UpdateFromPush updates the cache from a gRPC push (for future use) +// fullSync: true = replace all limits, false = incremental update +func (tlc *TrafficLimitsCache) UpdateFromPush(limits []NodeUserLimit, fullSync bool) { + tlc.refreshMu.Lock() + defer tlc.refreshMu.Unlock() + + var newLimits map[int]int64 + + if fullSync { + // Full replacement + newLimits = make(map[int]int64, len(limits)) + } else { + // Incremental: copy existing and update + data := tlc.limitsPtr.Load() + if data != nil && data.limits != nil { + newLimits = make(map[int]int64, len(data.limits)+len(limits)) + for k, v := range data.limits { + newLimits[k] = v + } + } else { + newLimits = make(map[int]int64, len(limits)) + } + } + + for _, limit := range limits { + if limit.DataLimit == 0 && !fullSync { + // Remove limit in incremental mode + delete(newLimits, limit.UserID) + } else { + newLimits[limit.UserID] = limit.DataLimit + } + } + + tlc.limitsPtr.Store(&userLimitsData{limits: newLimits}) + tlc.lastUpdate = time.Now() +} + +// recordError records an error and updates backoff state +func (tlc *TrafficLimitsCache) recordError(err error) { + tlc.lastError = err + + tlc.backoff.mu.Lock() + defer tlc.backoff.mu.Unlock() + + tlc.backoff.lastErrorTime = time.Now() + + // Exponential backoff: double the interval + tlc.backoff.currentInterval = time.Duration( + math.Min( + float64(tlc.backoff.currentInterval*2), + float64(tlc.backoff.maxInterval), + ), + ) +} + +// resetBackoff resets the backoff interval to base +func (tlc *TrafficLimitsCache) resetBackoff() { + tlc.backoff.mu.Lock() + defer tlc.backoff.mu.Unlock() + + tlc.backoff.currentInterval = tlc.backoff.baseInterval +} + +// getBackoffInterval returns the current backoff interval with jitter +func (tlc *TrafficLimitsCache) getBackoffInterval() time.Duration { + tlc.backoff.mu.Lock() + defer tlc.backoff.mu.Unlock() + + // Add jitter: ±20% to prevent thundering herd + jitter := float64(tlc.backoff.currentInterval) * 0.2 * (rand.Float64()*2 - 1) + return tlc.backoff.currentInterval + time.Duration(jitter) 
+} + +// shouldUseBackoff checks if we should use backoff interval instead of normal interval +func (tlc *TrafficLimitsCache) shouldUseBackoff() bool { + tlc.backoff.mu.Lock() + defer tlc.backoff.mu.Unlock() + + // Use backoff if we had an error recently (within last 5 minutes) + return time.Since(tlc.backoff.lastErrorTime) < 5*time.Minute && + tlc.backoff.currentInterval > tlc.backoff.baseInterval +} + +// StartAutoRefresh starts automatic refresh of limits with graceful shutdown support +// Now includes exponential backoff on errors +func (tlc *TrafficLimitsCache) StartAutoRefresh(interval time.Duration) { + go func() { + // Initial refresh + if err := tlc.Refresh(); err != nil { + tlc.logger.Errorf("Failed initial traffic limits refresh: %v\n", err) + } else { + total, _, _ := tlc.GetStats() + tlc.logger.Printf("Initial traffic limits loaded: %d limits\n", total) + } + + for { + // Determine next refresh interval + var nextInterval time.Duration + if tlc.shouldUseBackoff() { + nextInterval = tlc.getBackoffInterval() + tlc.logger.Printf("Using backoff interval: %v\n", nextInterval) + } else { + nextInterval = interval + } + + timer := time.NewTimer(nextInterval) + + select { + case <-timer.C: + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + if err := tlc.RefreshWithContext(ctx); err != nil { + tlc.logger.Errorf("Failed to refresh traffic limits: %v\n", err) + } else { + total, lastUpdate, _ := tlc.GetStats() + tlc.logger.Printf("Traffic limits refreshed: %d limits at %s\n", total, lastUpdate.Format(time.RFC3339)) + } + cancel() + case <-tlc.stopChan: + timer.Stop() + tlc.logger.Printf("Stopping traffic limits auto-refresh\n") + return + } + } + }() +} + +// Stop gracefully stops the auto-refresh goroutine +func (tlc *TrafficLimitsCache) Stop() { + close(tlc.stopChan) +} diff --git a/common/traffic_limits_test.go b/common/traffic_limits_test.go new file mode 100644 index 0000000..cff4b34 --- /dev/null +++ b/common/traffic_limits_test.go @@ -0,0 +1,340 @@ +package common + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestTrafficLimitsCache_GetLimit_LockFree(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + // Set some limits via push + limits := []NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + {UserID: 2, DataLimit: 2000}, + {UserID: 3, DataLimit: 0}, // Unlimited + } + tlc.UpdateFromPush(limits, true) + + // Test GetLimit + tests := []struct { + userID int + expected int64 + }{ + {1, 1000}, + {2, 2000}, + {3, 0}, + {999, 0}, // Non-existent + } + + for _, tt := range tests { + got := tlc.GetLimit(tt.userID) + if got != tt.expected { + t.Errorf("GetLimit(%d) = %d, want %d", tt.userID, got, tt.expected) + } + } +} + +func TestTrafficLimitsCache_HasLimit(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + limits := []NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + {UserID: 2, DataLimit: 0}, // Explicit 0 is still a limit + } + tlc.UpdateFromPush(limits, true) + + tests := []struct { + userID int + expected bool + }{ + {1, true}, + {2, true}, // 0 is still a configured limit + {999, false}, // Not configured + } + + for _, tt := range tests { + got := tlc.HasLimit(tt.userID) + if got != tt.expected { + t.Errorf("HasLimit(%d) = %v, want %v", tt.userID, got, tt.expected) + } + } +} + +func TestTrafficLimitsCache_UpdateFromPush_FullSync(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", 
"test-key") + + // Initial limits + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + {UserID: 2, DataLimit: 2000}, + }, true) + + // Full sync - should replace all + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 3, DataLimit: 3000}, + }, true) + + if tlc.HasLimit(1) { + t.Error("User 1 should not have limit after full sync") + } + if tlc.HasLimit(2) { + t.Error("User 2 should not have limit after full sync") + } + if !tlc.HasLimit(3) { + t.Error("User 3 should have limit after full sync") + } + if tlc.GetLimit(3) != 3000 { + t.Errorf("User 3 limit = %d, want 3000", tlc.GetLimit(3)) + } +} + +func TestTrafficLimitsCache_UpdateFromPush_Incremental(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + // Initial limits + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + {UserID: 2, DataLimit: 2000}, + }, true) + + // Incremental update - should update user 1 and add user 3 + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: 1500}, + {UserID: 3, DataLimit: 3000}, + }, false) + + if tlc.GetLimit(1) != 1500 { + t.Errorf("User 1 limit = %d, want 1500", tlc.GetLimit(1)) + } + if tlc.GetLimit(2) != 2000 { + t.Errorf("User 2 limit = %d, want 2000 (unchanged)", tlc.GetLimit(2)) + } + if tlc.GetLimit(3) != 3000 { + t.Errorf("User 3 limit = %d, want 3000", tlc.GetLimit(3)) + } + + // Incremental update with 0 - should remove the limit + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 2, DataLimit: 0}, + }, false) + + if tlc.HasLimit(2) { + t.Error("User 2 should not have limit after removal") + } +} + +func TestTrafficLimitsCache_ConcurrentAccess(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + // Set initial limits + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + }, true) + + var wg sync.WaitGroup + var reads int64 = 0 + numReaders := 100 + numWriters := 10 + iterations := 1000 + + // Start readers + for i := 0; i < numReaders; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + _ = tlc.GetLimit(1) + _ = tlc.HasLimit(1) + atomic.AddInt64(&reads, 1) + } + }() + } + + // Start writers + for i := 0; i < numWriters; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < iterations/10; j++ { + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: int64(j * id)}, + }, false) + } + }(i) + } + + wg.Wait() + t.Logf("Completed %d reads with concurrent writes", atomic.LoadInt64(&reads)) +} + +func TestTrafficLimitsCache_ETagHandling(t *testing.T) { + // Create a test server that supports ETag + etag := `"test-etag-12345"` + limits := NodeUserLimitsResponse{ + Limits: []NodeUserLimit{ + {ID: 1, UserID: 100, NodeID: 1, DataLimit: 5000}, + }, + Total: 1, + } + + requestCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + + // Check If-None-Match header + if r.Header.Get("If-None-Match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + + // Set ETag and return data + w.Header().Set("ETag", etag) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(limits) + })) + defer server.Close() + + tlc := NewTrafficLimitsCache(1, server.URL, "test-key") + + // First request - should get data + err := tlc.Refresh() + if err != nil { + t.Fatalf("First refresh failed: %v", err) + } + if tlc.GetLimit(100) != 5000 { + t.Errorf("After first refresh: limit = %d, want 5000", tlc.GetLimit(100)) + } + if requestCount != 
1 { + t.Errorf("Request count = %d, want 1", requestCount) + } + + // Second request - should get 304 + err = tlc.Refresh() + if err != nil { + t.Fatalf("Second refresh failed: %v", err) + } + if tlc.GetLimit(100) != 5000 { + t.Errorf("After second refresh: limit = %d, want 5000 (unchanged)", tlc.GetLimit(100)) + } + if requestCount != 2 { + t.Errorf("Request count = %d, want 2", requestCount) + } +} + +func TestTrafficLimitsCache_ExponentialBackoff(t *testing.T) { + // Create a failing server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("server error")) + })) + defer server.Close() + + tlc := NewTrafficLimitsCache(1, server.URL, "test-key") + + // First error - should start backoff + _ = tlc.Refresh() + if !tlc.shouldUseBackoff() { + t.Error("Should use backoff after first error") + } + + initial := tlc.backoff.currentInterval + + // Second error - should double + _ = tlc.Refresh() + if tlc.backoff.currentInterval <= initial { + t.Error("Backoff interval should increase after second error") + } + + // Successful request resets backoff + successServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(NodeUserLimitsResponse{Limits: []NodeUserLimit{}, Total: 0}) + })) + defer successServer.Close() + + tlc.panelAPIURL = successServer.URL + _ = tlc.Refresh() + + if tlc.backoff.currentInterval != tlc.backoff.baseInterval { + t.Error("Backoff should reset after successful request") + } +} + +func TestTrafficLimitsCache_GetStats(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + // Initially empty + total, lastUpdate, lastError := tlc.GetStats() + if total != 0 { + t.Errorf("Initial total = %d, want 0", total) + } + if !lastUpdate.IsZero() { + t.Error("Initial lastUpdate should be zero") + } + if lastError != nil { + t.Error("Initial lastError should be nil") + } + + // Add some limits + tlc.UpdateFromPush([]NodeUserLimit{ + {UserID: 1, DataLimit: 1000}, + {UserID: 2, DataLimit: 2000}, + }, true) + + total, lastUpdate, lastError = tlc.GetStats() + if total != 2 { + t.Errorf("After push: total = %d, want 2", total) + } + if lastUpdate.IsZero() { + t.Error("After push: lastUpdate should not be zero") + } +} + +func TestTrafficLimitsCache_Stop(t *testing.T) { + tlc := NewTrafficLimitsCache(1, "http://localhost", "test-key") + + // Start auto-refresh + tlc.StartAutoRefresh(100 * time.Millisecond) + + // Give it time to do at least one refresh + time.Sleep(50 * time.Millisecond) + + // Stop should not panic + tlc.Stop() + + // Give it time to stop + time.Sleep(50 * time.Millisecond) +} + +func TestTrafficLimitsCache_RefreshWithContext_Cancellation(t *testing.T) { + // Create a slow server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + tlc := NewTrafficLimitsCache(1, server.URL, "test-key") + + // Create a context that cancels quickly + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + err := tlc.RefreshWithContext(ctx) + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected error due to context cancellation") + } + if elapsed > 1*time.Second { + t.Errorf("Request should have been cancelled quickly, took %v", 
elapsed) + } +} diff --git a/controller/controller.go b/controller/controller.go index 1d4d2ba..b482924 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -23,14 +23,15 @@ type Service interface { } type Controller struct { - backend backend.Backend - cfg *config.Config - apiPort int - clientIP string - lastRequest time.Time - stats *common.SystemStatsResponse - cancelFunc context.CancelFunc - mu sync.RWMutex + backend backend.Backend + cfg *config.Config + apiPort int + clientIP string + lastRequest time.Time + stats *common.SystemStatsResponse + cancelFunc context.CancelFunc + limitEnforcer *LimitEnforcer + mu sync.RWMutex } func New(cfg *config.Config) *Controller { @@ -62,13 +63,68 @@ func (c *Controller) Connect(ip string, keepAlive uint64) { } } +// LimitEnforcerParams contains parameters for starting the limit enforcer +// These are passed from the panel via the Backend message +type LimitEnforcerParams struct { + NodeID int32 + PanelAPIURL string + LimitCheckInterval int32 // seconds + LimitRefreshInterval int32 // seconds +} + +// StartLimitEnforcer starts the limit enforcer with config from panel +// Should be called after Connect() with params from Backend message +func (c *Controller) StartLimitEnforcer(params LimitEnforcerParams) { + if params.PanelAPIURL == "" || params.NodeID <= 0 { + return // Limit enforcer not configured by panel + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Use defaults if not specified + checkInterval := time.Duration(params.LimitCheckInterval) * time.Second + if checkInterval == 0 { + checkInterval = 30 * time.Second + } + refreshInterval := time.Duration(params.LimitRefreshInterval) * time.Second + if refreshInterval == 0 { + refreshInterval = 60 * time.Second + } + + // Create context that will be cancelled when controller disconnects + ctx := context.Background() + + c.limitEnforcer = NewLimitEnforcer(c, LimitEnforcerConfig{ + NodeID: int(params.NodeID), + PanelAPIURL: params.PanelAPIURL, + APIKey: c.cfg.ApiKey.String(), + CheckInterval: checkInterval, + }) + c.limitEnforcer.Start(ctx, refreshInterval) + log.Printf("Limit enforcer started for node %d (panel: %s)", params.NodeID, params.PanelAPIURL) +} + +// GetLimitEnforcer returns the limit enforcer (nil if not enabled) +func (c *Controller) GetLimitEnforcer() *LimitEnforcer { + c.mu.RLock() + defer c.mu.RUnlock() + return c.limitEnforcer +} + func (c *Controller) Disconnect() { c.cancelFunc() c.mu.Lock() backend := c.backend + limitEnforcer := c.limitEnforcer c.mu.Unlock() + // Stop limit enforcer first (it uses backend) + if limitEnforcer != nil { + limitEnforcer.Stop() + } + // Shutdown backend outside of lock to avoid deadlock // Shutdown() will wait for process termination to complete if backend != nil { @@ -79,6 +135,7 @@ func (c *Controller) Disconnect() { defer c.mu.Unlock() c.backend = nil + c.limitEnforcer = nil c.apiPort = tools.FindFreePort() c.clientIP = "" } diff --git a/controller/limit_enforcer.go b/controller/limit_enforcer.go new file mode 100644 index 0000000..968ffb8 --- /dev/null +++ b/controller/limit_enforcer.go @@ -0,0 +1,340 @@ +package controller + +import ( + "context" + "log" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pasarguard/node/common" +) + +// LimitEnforcerMetrics contains metrics for monitoring the limit enforcer +type LimitEnforcerMetrics struct { + // Counters + UsersRemoved atomic.Int64 // Total users removed due to limit exceeded + ChecksPerformed atomic.Int64 // Total limit checks performed + CheckErrors 
atomic.Int64 // Total errors during checks + RefreshCount atomic.Int64 // Total cache refreshes + + // Gauges (current values) + TrackedUsers atomic.Int64 // Current number of users being tracked + CachedLimits atomic.Int64 // Current number of limits in cache + LastCheckDuration atomic.Int64 // Last check duration in milliseconds + LastCheckTime atomic.Int64 // Unix timestamp of last check +} + +// userTrafficEntry tracks traffic data with last seen time for cleanup +type userTrafficEntry struct { + traffic int64 + lastSeen time.Time +} + +// LimitEnforcer monitors user traffic and removes users who exceed their node-specific limits +type LimitEnforcer struct { + controller *Controller + limitsCache *common.TrafficLimitsCache + checkInterval time.Duration + stopChan chan struct{} + + // Track cumulative traffic per user (since xray stats are reset on read) + userTraffic map[int]*userTrafficEntry + trafficMu sync.RWMutex + + // Configuration for cleanup + cleanupInterval time.Duration // How often to clean inactive users (default: 1h) + inactiveTimeout time.Duration // Remove users not seen for this duration (default: 24h) + + // Metrics for monitoring + Metrics LimitEnforcerMetrics +} + +// LimitEnforcerConfig contains configuration for the limit enforcer +type LimitEnforcerConfig struct { + NodeID int + PanelAPIURL string + APIKey string + CheckInterval time.Duration // How often to check stats (default: 30s) + RefreshInterval time.Duration // How often to refresh limits from panel (default: 60s) + CleanupInterval time.Duration // How often to clean inactive users (default: 1h) + InactiveTimeout time.Duration // Remove users not seen for this duration (default: 24h) +} + +// NewLimitEnforcer creates a new limit enforcer +func NewLimitEnforcer(controller *Controller, cfg LimitEnforcerConfig) *LimitEnforcer { + if cfg.CheckInterval == 0 { + cfg.CheckInterval = 30 * time.Second + } + if cfg.RefreshInterval == 0 { + cfg.RefreshInterval = 60 * time.Second + } + if cfg.CleanupInterval == 0 { + cfg.CleanupInterval = 1 * time.Hour + } + if cfg.InactiveTimeout == 0 { + cfg.InactiveTimeout = 24 * time.Hour + } + + limitsCache := common.NewTrafficLimitsCache(cfg.NodeID, cfg.PanelAPIURL, cfg.APIKey) + + return &LimitEnforcer{ + controller: controller, + limitsCache: limitsCache, + checkInterval: cfg.CheckInterval, + stopChan: make(chan struct{}), + userTraffic: make(map[int]*userTrafficEntry), + cleanupInterval: cfg.CleanupInterval, + inactiveTimeout: cfg.InactiveTimeout, + } +} + +// Start begins monitoring traffic and enforcing limits +func (le *LimitEnforcer) Start(ctx context.Context, refreshInterval time.Duration) { + // Start limits cache auto-refresh + le.limitsCache.StartAutoRefresh(refreshInterval) + + // Start traffic monitoring + go le.monitorTraffic(ctx) + + // Start periodic cleanup of inactive users + go le.periodicCleanup(ctx) + + log.Println("Limit enforcer started") +} + +// Stop gracefully stops the limit enforcer +func (le *LimitEnforcer) Stop() { + close(le.stopChan) + le.limitsCache.Stop() + log.Println("Limit enforcer stopped") +} + +// GetLimitsCache returns the underlying traffic limits cache +func (le *LimitEnforcer) GetLimitsCache() *common.TrafficLimitsCache { + return le.limitsCache +} + +// GetMetrics returns current metrics snapshot +func (le *LimitEnforcer) GetMetrics() map[string]int64 { + le.trafficMu.RLock() + trackedUsers := int64(len(le.userTraffic)) + le.trafficMu.RUnlock() + + le.Metrics.TrackedUsers.Store(trackedUsers) + 
le.Metrics.CachedLimits.Store(int64(le.limitsCache.Count())) + + return map[string]int64{ + "users_removed": le.Metrics.UsersRemoved.Load(), + "checks_performed": le.Metrics.ChecksPerformed.Load(), + "check_errors": le.Metrics.CheckErrors.Load(), + "refresh_count": le.Metrics.RefreshCount.Load(), + "tracked_users": le.Metrics.TrackedUsers.Load(), + "cached_limits": le.Metrics.CachedLimits.Load(), + "last_check_duration_ms": le.Metrics.LastCheckDuration.Load(), + "last_check_time": le.Metrics.LastCheckTime.Load(), + } +} + +// monitorTraffic periodically checks user traffic against limits +func (le *LimitEnforcer) monitorTraffic(ctx context.Context) { + ticker := time.NewTicker(le.checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-le.stopChan: + return + case <-ticker.C: + le.checkAndEnforceLimits(ctx) + } + } +} + +// periodicCleanup removes inactive users from tracking map +func (le *LimitEnforcer) periodicCleanup(ctx context.Context) { + ticker := time.NewTicker(le.cleanupInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-le.stopChan: + return + case <-ticker.C: + le.cleanupInactiveUsers() + } + } +} + +// cleanupInactiveUsers removes users not seen for inactiveTimeout duration +func (le *LimitEnforcer) cleanupInactiveUsers() { + le.trafficMu.Lock() + defer le.trafficMu.Unlock() + + now := time.Now() + removed := 0 + + for userID, entry := range le.userTraffic { + if now.Sub(entry.lastSeen) > le.inactiveTimeout { + delete(le.userTraffic, userID) + removed++ + } + } + + if removed > 0 { + log.Printf("Limit enforcer cleanup: removed %d inactive users from tracking", removed) + } +} + +// checkAndEnforceLimits fetches user stats and removes users over their limits +func (le *LimitEnforcer) checkAndEnforceLimits(ctx context.Context) { + startTime := time.Now() + le.Metrics.ChecksPerformed.Add(1) + le.Metrics.LastCheckTime.Store(startTime.Unix()) + + defer func() { + le.Metrics.LastCheckDuration.Store(time.Since(startTime).Milliseconds()) + }() + + b := le.controller.Backend() + if b == nil { + return + } + + // Get user stats from xray (with reset=true to get delta) + statsCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + stats, err := b.GetStats(statsCtx, &common.StatRequest{ + Type: common.StatType_UsersStat, + Reset_: true, // Get delta since last check and reset counters + }) + if err != nil { + le.Metrics.CheckErrors.Add(1) + log.Printf("Limit enforcer: failed to get user stats: %v", err) + return + } + + now := time.Now() + + // Track users to remove + usersToRemove := make(map[string]int64) // email -> traffic + + le.trafficMu.Lock() + for _, stat := range stats.GetStats() { + // Parse user email from stat name (format: "user>>>email>>>traffic>>>uplink/downlink") + userID, email := parseUserStatName(stat.GetName()) + if userID == 0 { + continue + } + + // Get or create traffic entry + entry, exists := le.userTraffic[userID] + if !exists { + entry = &userTrafficEntry{} + le.userTraffic[userID] = entry + } + + // Accumulate traffic and update last seen + entry.traffic += stat.GetValue() + entry.lastSeen = now + totalTraffic := entry.traffic + + // Check if user has a limit configured + limit := le.limitsCache.GetLimit(userID) + if limit <= 0 { + continue // No limit configured, skip + } + + // Check if over limit + if totalTraffic >= limit { + usersToRemove[email] = totalTraffic + log.Printf("Limit enforcer: user %s (ID: %d) exceeded limit: %d >= %d bytes", + email, userID, 
totalTraffic, limit) + } + } + le.trafficMu.Unlock() + + // Remove users who exceeded their limits + if len(usersToRemove) > 0 { + le.removeOverLimitUsers(ctx, usersToRemove) + } +} + +// parseUserStatName parses user ID and email from xray stat name +// Format: "user>>>1.username>>>traffic>>>uplink" or "user>>>1.username>>>traffic>>>downlink" +func parseUserStatName(name string) (int, string) { + parts := strings.Split(name, ">>>") + if len(parts) < 2 || parts[0] != "user" { + return 0, "" + } + + email := parts[1] + + // Extract user ID from email (format: "id.username") + emailParts := strings.SplitN(email, ".", 2) + if len(emailParts) < 2 { + return 0, email + } + + userID, err := strconv.Atoi(emailParts[0]) + if err != nil { + return 0, email + } + + return userID, email +} + +// removeOverLimitUsers removes users from all xray inbounds +func (le *LimitEnforcer) removeOverLimitUsers(ctx context.Context, users map[string]int64) { + b := le.controller.Backend() + if b == nil { + return + } + + // Create empty user entries to trigger removal + for email, traffic := range users { + // Create a user with empty inbounds to remove from all inbounds + emptyUser := &common.User{ + Email: email, + Inbounds: []string{}, // Empty inbounds = remove from all + } + + if err := b.SyncUser(ctx, emptyUser); err != nil { + log.Printf("Limit enforcer: failed to remove user %s: %v", email, err) + } else { + le.Metrics.UsersRemoved.Add(1) + log.Printf("Limit enforcer: removed user %s (exceeded limit by %d bytes)", email, traffic) + } + } +} + +// ResetUserTraffic resets accumulated traffic for a user (called when limits are reset) +func (le *LimitEnforcer) ResetUserTraffic(userID int) { + le.trafficMu.Lock() + delete(le.userTraffic, userID) + le.trafficMu.Unlock() +} + +// ResetAllTraffic resets all accumulated traffic (called on full sync from panel) +func (le *LimitEnforcer) ResetAllTraffic() { + le.trafficMu.Lock() + le.userTraffic = make(map[int]*userTrafficEntry) + le.trafficMu.Unlock() +} + +// UpdateLimitsFromPush updates limits from a gRPC push +func (le *LimitEnforcer) UpdateLimitsFromPush(limits []common.NodeUserLimit, fullSync bool) { + le.limitsCache.UpdateFromPush(limits, fullSync) + + if fullSync { + // On full sync, reset traffic tracking + le.ResetAllTraffic() + } +} diff --git a/controller/limit_enforcer_test.go b/controller/limit_enforcer_test.go new file mode 100644 index 0000000..a921b7c --- /dev/null +++ b/controller/limit_enforcer_test.go @@ -0,0 +1,85 @@ +package controller + +import ( + "testing" +) + +func TestParseUserStatName(t *testing.T) { + tests := []struct { + name string + input string + expectedID int + expectedMail string + }{ + { + name: "valid uplink stat", + input: "user>>>1.username>>>traffic>>>uplink", + expectedID: 1, + expectedMail: "1.username", + }, + { + name: "valid downlink stat", + input: "user>>>42.testuser>>>traffic>>>downlink", + expectedID: 42, + expectedMail: "42.testuser", + }, + { + name: "user with dots in name", + input: "user>>>123.user.with.dots>>>traffic>>>uplink", + expectedID: 123, + expectedMail: "123.user.with.dots", + }, + { + name: "invalid - not user stat", + input: "inbound>>>vmess>>>traffic>>>uplink", + expectedID: 0, + expectedMail: "", + }, + { + name: "invalid - no parts", + input: "invalid", + expectedID: 0, + expectedMail: "", + }, + { + name: "invalid - no user id", + input: "user>>>noIdUsername>>>traffic>>>uplink", + expectedID: 0, + expectedMail: "noIdUsername", + }, + { + name: "invalid - non-numeric id", + input: 
"user>>>abc.username>>>traffic>>>uplink", + expectedID: 0, + expectedMail: "abc.username", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, email := parseUserStatName(tt.input) + if id != tt.expectedID { + t.Errorf("parseUserStatName(%q) id = %d, want %d", tt.input, id, tt.expectedID) + } + if email != tt.expectedMail { + t.Errorf("parseUserStatName(%q) email = %q, want %q", tt.input, email, tt.expectedMail) + } + }) + } +} + +func TestLimitEnforcerConfig_Defaults(t *testing.T) { + // Test default values are applied + cfg := LimitEnforcerConfig{ + NodeID: 1, + PanelAPIURL: "http://localhost", + APIKey: "test-key", + } + + if cfg.CheckInterval != 0 { + t.Error("CheckInterval should be 0 before NewLimitEnforcer") + } + if cfg.RefreshInterval != 0 { + t.Error("RefreshInterval should be 0 before NewLimitEnforcer") + } +} diff --git a/controller/rest/base.go b/controller/rest/base.go index 533a86f..dc9df7f 100644 --- a/controller/rest/base.go +++ b/controller/rest/base.go @@ -10,6 +10,7 @@ import ( "github.com/pasarguard/node/backend" "github.com/pasarguard/node/backend/xray" "github.com/pasarguard/node/common" + "github.com/pasarguard/node/controller" ) func (s *Service) Base(w http.ResponseWriter, _ *http.Request) { @@ -17,7 +18,7 @@ func (s *Service) Base(w http.ResponseWriter, _ *http.Request) { } func (s *Service) Start(w http.ResponseWriter, r *http.Request) { - ctx, backendType, keepAlive, err := s.detectBackend(r) + ctx, backendType, keepAlive, limitParams, err := s.detectBackend(r) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -41,6 +42,9 @@ func (s *Service) Start(w http.ResponseWriter, r *http.Request) { return } + // Start limit enforcer if panel provided configuration + s.StartLimitEnforcer(limitParams) + common.SendProtoResponse(w, s.BaseInfoResponse()) } @@ -50,25 +54,32 @@ func (s *Service) Stop(w http.ResponseWriter, _ *http.Request) { common.SendProtoResponse(w, &common.Empty{}) } -func (s *Service) detectBackend(r *http.Request) (context.Context, common.BackendType, uint64, error) { +func (s *Service) detectBackend(r *http.Request) (context.Context, common.BackendType, uint64, controller.LimitEnforcerParams, error) { var data common.Backend var ctx context.Context if err := common.ReadProtoBody(r.Body, &data); err != nil { - return nil, 0, 0, err + return nil, 0, 0, controller.LimitEnforcerParams{}, err } if data.Type == common.BackendType_XRAY { config, err := xray.NewXRayConfig(data.GetConfig(), data.GetExcludeInbounds()) if err != nil { - return nil, 0, 0, err + return nil, 0, 0, controller.LimitEnforcerParams{}, err } ctx = context.WithValue(r.Context(), backend.ConfigKey{}, config) } else { - return ctx, data.GetType(), data.GetKeepAlive(), errors.New("invalid backend type") + return ctx, data.GetType(), data.GetKeepAlive(), controller.LimitEnforcerParams{}, errors.New("invalid backend type") } ctx = context.WithValue(ctx, backend.UsersKey{}, data.GetUsers()) - return ctx, data.GetType(), data.GetKeepAlive(), nil + limitParams := controller.LimitEnforcerParams{ + NodeID: data.GetNodeId(), + PanelAPIURL: data.GetPanelApiUrl(), + LimitCheckInterval: data.GetLimitCheckInterval(), + LimitRefreshInterval: data.GetLimitRefreshInterval(), + } + + return ctx, data.GetType(), data.GetKeepAlive(), limitParams, nil } diff --git a/controller/rest/stats.go b/controller/rest/stats.go index f7b9529..394e718 100644 --- a/controller/rest/stats.go +++ b/controller/rest/stats.go @@ -1,9 +1,11 @@ package rest import ( 
- "google.golang.org/grpc/status" + "encoding/json" "net/http" + "google.golang.org/grpc/status" + "github.com/pasarguard/node/common" ) @@ -77,3 +79,22 @@ func (s *Service) GetBackendStats(w http.ResponseWriter, r *http.Request) { func (s *Service) GetSystemStats(w http.ResponseWriter, _ *http.Request) { common.SendProtoResponse(w, s.SystemStats()) } + +// GetLimitEnforcerMetrics returns limit enforcer metrics as JSON +func (s *Service) GetLimitEnforcerMetrics(w http.ResponseWriter, _ *http.Request) { + enforcer := s.Controller.GetLimitEnforcer() + if enforcer == nil { + http.Error(w, "limit enforcer not enabled", http.StatusNotFound) + return + } + + metrics := enforcer.GetMetrics() + + w.Header().Set("Content-Type", "application/json") + data, err := json.Marshal(metrics) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Write(data) +} diff --git a/controller/rpc/base.go b/controller/rpc/base.go index 507945b..c2adb1f 100644 --- a/controller/rpc/base.go +++ b/controller/rpc/base.go @@ -9,6 +9,7 @@ import ( "github.com/pasarguard/node/backend" "github.com/pasarguard/node/backend/xray" "github.com/pasarguard/node/common" + "github.com/pasarguard/node/controller" "google.golang.org/grpc/peer" ) @@ -46,6 +47,14 @@ func (s *Service) Start(ctx context.Context, detail *common.Backend) (*common.Ba s.Connect(clientIP, detail.GetKeepAlive()) + // Start limit enforcer if panel provided configuration + s.StartLimitEnforcer(controller.LimitEnforcerParams{ + NodeID: detail.GetNodeId(), + PanelAPIURL: detail.GetPanelApiUrl(), + LimitCheckInterval: detail.GetLimitCheckInterval(), + LimitRefreshInterval: detail.GetLimitRefreshInterval(), + }) + return s.BaseInfoResponse(), nil } diff --git a/pasarguard-node b/pasarguard-node new file mode 100644 index 0000000..dcf831f Binary files /dev/null and b/pasarguard-node differ