Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ require (
github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874
github.com/coocood/freecache v1.2.4
github.com/envoyproxy/go-control-plane v0.14.0
github.com/envoyproxy/go-control-plane/envoy v1.36.0
github.com/envoyproxy/go-control-plane/envoy v1.37.1-0.20260409083702-98966259b99a
github.com/envoyproxy/go-control-plane/ratelimit v0.1.1-0.20260131204543-4ca8b9cded3e
github.com/go-kit/log v0.2.1
github.com/golang/mock v1.6.0
Expand Down Expand Up @@ -50,7 +50,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.3.3 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
github.com/envoyproxy/go-control-plane/envoy v1.37.1-0.20260409083702-98966259b99a h1:qrP4J6AWJ9yd6CINhPMRL/MbFXNiV7qimRsCDTOV0a0=
github.com/envoyproxy/go-control-plane/envoy v1.37.1-0.20260409083702-98966259b99a/go.mod h1:5yRfenlmRH8sxKrhXyiFtK8BDz3syDWcFm81rkCcATM=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.1-0.20260131204543-4ca8b9cded3e h1:EHL6eLDhQduyYGEKh+QSXE7s7Yhg/hpeeHFT0ET0gBw=
github.com/envoyproxy/go-control-plane/ratelimit v0.1.1-0.20260131204543-4ca8b9cded3e/go.mod h1:buWyXJdrI6ayYbeGm3upu3Qf/qHHrdWfUHKnVrTD+vM=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
Expand Down
17 changes: 15 additions & 2 deletions src/limiter/base_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ import (
"github.com/envoyproxy/ratelimit/src/utils"
)

// DecrementScript atomically decrements a rate limit counter, floored at 0.
const DecrementScript = `
local current = tonumber(redis.call('GET', KEYS[1]) or '0') -- get current count, default 0
local new_val = math.max(0, current - tonumber(ARGV[1])) -- subtract hits, floor at 0
redis.call('SET', KEYS[1], tostring(math.floor(new_val))) -- persist new value
redis.call('EXPIRE', KEYS[1], tonumber(ARGV[2])) -- reset TTL
return new_val -- return count after decrement
`

type BaseRateLimiter struct {
timeSource utils.TimeSource
JitterRand *rand.Rand
Expand Down Expand Up @@ -44,7 +53,7 @@ func NewRateLimitInfo(limit *config.RateLimit, limitBeforeIncrease uint64, limit
// Generates cache keys for given rate limit request. Each cache key is represented by a concatenation of
// domain, descriptor and current timestamp.
func (this *BaseRateLimiter) GenerateCacheKeys(request *pb.RateLimitRequest,
limits []*config.RateLimit, hitsAddends []uint64,
limits []*config.RateLimit, hitsAddends []utils.HitsAddend,
) []CacheKey {
assert.Assert(len(request.Descriptors) == len(limits))
cacheKeys := make([]CacheKey, len(request.Descriptors))
Expand All @@ -55,7 +64,11 @@ func (this *BaseRateLimiter) GenerateCacheKeys(request *pb.RateLimitRequest,
cacheKeys[i] = this.cacheKeyGenerator.GenerateCacheKey(request.Domain, request.Descriptors[i], limits[i], now)
// Increase statistics for limits hit by their respective requests.
if limits[i] != nil {
limits[i].Stats.TotalHits.Add(hitsAddends[i])
if hitsAddends[i].IsNegative {
limits[i].Stats.TotalNegativeHits.Add(hitsAddends[i].Value)
} else {
limits[i].Stats.TotalHits.Add(hitsAddends[i].Value)
}
}
}
return cacheKeys
Expand Down
43 changes: 34 additions & 9 deletions src/memcached/cache_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ func (this *rateLimitMemcacheImpl) DoLimit(
continue
}

// Negative hits skip the over-limit check — they always proceed.
if hitsAddends[i].IsNegative {
keysToGet = append(keysToGet, cacheKey.Key)
continue
}

// Check if key is over the limit in local cache.
if this.baseRateLimiter.IsOverLimitWithLocalCache(cacheKey.Key) {
isOverLimitWithLocalCache[i] = true
Expand Down Expand Up @@ -129,15 +135,25 @@ func (this *rateLimitMemcacheImpl) DoLimit(
} else {
limitBeforeIncrease = uint64(decoded)
}

}

limitAfterIncrease := limitBeforeIncrease + hitsAddends[i]
if hitsAddends[i].IsNegative {
// Predict the post-decrement value (guard against uint64 underflow).
var limitAfterDecrease uint64
if limitBeforeIncrease > hitsAddends[i].Value {
limitAfterDecrease = limitBeforeIncrease - hitsAddends[i].Value
}
responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limiter.NewRateLimitInfo(limits[i], limitAfterDecrease, limitAfterDecrease, 0, 0),
false, 0)
} else {
limitAfterIncrease := limitBeforeIncrease + hitsAddends[i].Value

limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0)
limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0)

responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limitInfo, isOverLimitWithLocalCache[i], hitsAddends[i])
responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limitInfo, isOverLimitWithLocalCache[i], hitsAddends[i].Value)
}
}

this.waitGroup.Add(1)
Expand All @@ -150,15 +166,24 @@ func (this *rateLimitMemcacheImpl) DoLimit(
}

func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, isOverLimitWithLocalCache []bool,
limits []*config.RateLimit, hitsAddends []uint64,
limits []*config.RateLimit, hitsAddends []utils.HitsAddend,
) {
defer this.waitGroup.Done()
for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" || isOverLimitWithLocalCache[i] {
continue
}

_, err := this.client.Increment(cacheKey.Key, hitsAddends[i])
if hitsAddends[i].IsNegative {
// Memcached Decrement natively floors at 0.
_, err := this.client.Decrement(cacheKey.Key, hitsAddends[i].Value)
if err != nil && err != memcache.ErrCacheMiss {
logger.Errorf("Failed to decrement key %s: %s", cacheKey.Key, err)
}
continue
}

_, err := this.client.Increment(cacheKey.Key, hitsAddends[i].Value)
if err == memcache.ErrCacheMiss {
expirationSeconds := utils.UnitToDivider(limits[i].Limit.Unit)
if this.expirationJitterMaxSeconds > 0 {
Expand All @@ -168,13 +193,13 @@ func (this *rateLimitMemcacheImpl) increaseAsync(cacheKeys []limiter.CacheKey, i
// Need to add instead of increment.
err = this.client.Add(&memcache.Item{
Key: cacheKey.Key,
Value: []byte(strconv.FormatUint(hitsAddends[i], 10)),
Value: []byte(strconv.FormatUint(hitsAddends[i].Value, 10)),
Expiration: int32(expirationSeconds),
})
if err == memcache.ErrNotStored {
// There was a race condition to do this add. We should be able to increment
// now instead.
_, err := this.client.Increment(cacheKey.Key, hitsAddends[i])
_, err := this.client.Increment(cacheKey.Key, hitsAddends[i].Value)
if err != nil {
logger.Errorf("Failed to increment key %s after failing to add: %s", cacheKey.Key, err)
continue
Expand Down
1 change: 1 addition & 0 deletions src/memcached/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ var _ Client = (*memcache.Client)(nil)
type Client interface {
GetMulti(keys []string) (map[string]*memcache.Item, error)
Increment(key string, delta uint64) (newValue uint64, err error)
Decrement(key string, delta uint64) (newValue uint64, err error)
Add(item *memcache.Item) error
}
19 changes: 19 additions & 0 deletions src/memcached/stats_collecting_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ type statsCollectingClient struct {
incrementSuccess stats.Counter
incrementMiss stats.Counter
incrementError stats.Counter
decrementSuccess stats.Counter
decrementMiss stats.Counter
decrementError stats.Counter
addSuccess stats.Counter
addError stats.Counter
addNotStored stats.Counter
Expand All @@ -28,6 +31,9 @@ func CollectStats(c Client, scope stats.Scope) Client {
incrementSuccess: scope.NewCounterWithTags("increment", map[string]string{"code": "success"}),
incrementMiss: scope.NewCounterWithTags("increment", map[string]string{"code": "miss"}),
incrementError: scope.NewCounterWithTags("increment", map[string]string{"code": "error"}),
decrementSuccess: scope.NewCounterWithTags("decrement", map[string]string{"code": "success"}),
decrementMiss: scope.NewCounterWithTags("decrement", map[string]string{"code": "miss"}),
decrementError: scope.NewCounterWithTags("decrement", map[string]string{"code": "error"}),
addSuccess: scope.NewCounterWithTags("add", map[string]string{"code": "success"}),
addError: scope.NewCounterWithTags("add", map[string]string{"code": "error"}),
addNotStored: scope.NewCounterWithTags("add", map[string]string{"code": "not_stored"}),
Expand Down Expand Up @@ -64,6 +70,19 @@ func (scc statsCollectingClient) Increment(key string, delta uint64) (newValue u
return
}

func (scc statsCollectingClient) Decrement(key string, delta uint64) (newValue uint64, err error) {
newValue, err = scc.c.Decrement(key, delta)
switch err {
case memcache.ErrCacheMiss:
scc.decrementMiss.Inc()
case nil:
scc.decrementSuccess.Inc()
default:
scc.decrementError.Inc()
}
return
}

func (scc statsCollectingClient) Add(item *memcache.Item) error {
err := scc.c.Add(item)

Expand Down
65 changes: 41 additions & 24 deletions src/redis/fixed_cache_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,28 @@ func pipelineAppend(client Client, pipeline *Pipeline, key string, hitsAddend ui
*pipeline = client.PipeAppend(*pipeline, nil, "EXPIRE", key, expirationSeconds)
}

func pipelineAppendDecrement(client Client, pipeline *Pipeline, key string, hitsAddend uint64, result *uint64, expirationSeconds int64) {
*pipeline = client.PipeAppend(*pipeline, result, "EVAL", limiter.DecrementScript, 1, key, hitsAddend, expirationSeconds)
}

func (this *fixedRateLimitCacheImpl) selectPipeline(cacheKey limiter.CacheKey, pipeline *Pipeline, perSecondPipeline *Pipeline) (Client, *Pipeline) {
if this.perSecondClient != nil && cacheKey.PerSecond {
if *perSecondPipeline == nil {
*perSecondPipeline = Pipeline{}
}
return this.perSecondClient, perSecondPipeline
}
if *pipeline == nil {
*pipeline = Pipeline{}
}
return this.client, pipeline
}

func pipelineAppendtoGet(client Client, pipeline *Pipeline, key string, result *uint64) {
*pipeline = client.PipeAppend(*pipeline, result, "GET", key)
}

func (this *fixedRateLimitCacheImpl) getHitsAddend(hitsAddend uint64, isCacheKeyOverlimit, isCacheKeyNearlimit,
func (this *fixedRateLimitCacheImpl) getHitsAddendValue(hitsAddend uint64, isCacheKeyOverlimit, isCacheKeyNearlimit,
isNearLimt bool,
) uint64 {
// If stopCacheKeyIncrementWhenOverlimit is false, then we always increment the cache key.
Expand Down Expand Up @@ -94,8 +111,9 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
isCacheKeyNearlimit := false

// Check if any of the keys are already to the over limit in cache.
// Negative hits (decrements) skip this check — they always proceed.
for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" {
if cacheKey.Key == "" || hitsAddends[i].IsNegative {
continue
}

Expand All @@ -116,7 +134,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
// then we check if any of the keys are near limit in redis cache.
if this.stopCacheKeyIncrementWhenOverlimit && !isCacheKeyOverlimit {
for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" {
if cacheKey.Key == "" || hitsAddends[i].IsNegative {
continue
}

Expand All @@ -141,12 +159,12 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
}

for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" {
if cacheKey.Key == "" || hitsAddends[i].IsNegative {
continue
}
// Now fetch the pipeline.
limitBeforeIncrease := currentCount[i]
limitAfterIncrease := limitBeforeIncrease + hitsAddends[i]
limitAfterIncrease := limitBeforeIncrease + hitsAddends[i].Value

limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0)

Expand All @@ -157,7 +175,7 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
}
}

// Now, actually setup the pipeline to increase the usage of cache key, skipping empty cache keys.
// Now, actually setup the pipeline to increase/decrease the usage of cache key, skipping empty cache keys.
for i, cacheKey := range cacheKeys {
if cacheKey.Key == "" || overlimitIndexes[i] {
continue
Expand All @@ -170,19 +188,12 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
expirationSeconds += this.baseRateLimiter.JitterRand.Int63n(this.baseRateLimiter.ExpirationJitterMaxSeconds)
}

// Use the perSecondConn if it is not nil and the cacheKey represents a per second Limit.
if this.perSecondClient != nil && cacheKey.PerSecond {
if perSecondPipeline == nil {
perSecondPipeline = Pipeline{}
}
pipelineAppend(this.perSecondClient, &perSecondPipeline, cacheKey.Key, this.getHitsAddend(hitsAddends[i],
isCacheKeyOverlimit, isCacheKeyNearlimit, nearlimitIndexes[i]), &results[i], expirationSeconds)
client, p := this.selectPipeline(cacheKey, &pipeline, &perSecondPipeline)
if hitsAddends[i].IsNegative {
pipelineAppendDecrement(client, p, cacheKey.Key, hitsAddends[i].Value, &results[i], expirationSeconds)
} else {
if pipeline == nil {
pipeline = Pipeline{}
}
pipelineAppend(this.client, &pipeline, cacheKey.Key, this.getHitsAddend(hitsAddends[i], isCacheKeyOverlimit,
isCacheKeyNearlimit, nearlimitIndexes[i]), &results[i], expirationSeconds)
pipelineAppend(client, p, cacheKey.Key, this.getHitsAddendValue(hitsAddends[i].Value,
isCacheKeyOverlimit, isCacheKeyNearlimit, nearlimitIndexes[i]), &results[i], expirationSeconds)
}
}

Expand All @@ -206,14 +217,20 @@ func (this *fixedRateLimitCacheImpl) DoLimit(
responseDescriptorStatuses := make([]*pb.RateLimitResponse_DescriptorStatus,
len(request.Descriptors))
for i, cacheKey := range cacheKeys {
if hitsAddends[i].IsNegative {
// Negative hits always return OK with the remaining capacity.
responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limiter.NewRateLimitInfo(limits[i], results[i], results[i], 0, 0),
false, 0)
} else {
limitAfterIncrease := results[i]
limitBeforeIncrease := limitAfterIncrease - hitsAddends[i].Value

limitAfterIncrease := results[i]
limitBeforeIncrease := limitAfterIncrease - hitsAddends[i]

limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0)
limitInfo := limiter.NewRateLimitInfo(limits[i], limitBeforeIncrease, limitAfterIncrease, 0, 0)

responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limitInfo, isOverLimitWithLocalCache[i], hitsAddends[i])
responseDescriptorStatuses[i] = this.baseRateLimiter.GetResponseDescriptorStatus(cacheKey.Key,
limitInfo, isOverLimitWithLocalCache[i], hitsAddends[i].Value)
}
}

return responseDescriptorStatuses
Expand Down
1 change: 1 addition & 0 deletions src/stats/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ type RateLimitStats struct {
OverLimitWithLocalCache gostats.Counter
WithinLimit gostats.Counter
ShadowMode gostats.Counter
TotalNegativeHits gostats.Counter
}

// Stats for a domain entry
Expand Down
1 change: 1 addition & 0 deletions src/stats/manager_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func (this *ManagerImpl) NewStats(key string) RateLimitStats {
ret.OverLimitWithLocalCache = this.rlStatsScope.NewCounter(key + ".over_limit_with_local_cache")
ret.WithinLimit = this.rlStatsScope.NewCounter(key + ".within_limit")
ret.ShadowMode = this.rlStatsScope.NewCounter(key + ".shadow_mode")
ret.TotalNegativeHits = this.rlStatsScope.NewCounter(key + ".total_negative_hits")
return ret
}

Expand Down
14 changes: 10 additions & 4 deletions src/utils/utilities.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,25 @@ func SanitizeStatName(s string) string {
})
}

func GetHitsAddends(request *pb.RateLimitRequest) []uint64 {
hitsAddends := make([]uint64, len(request.Descriptors))
type HitsAddend struct {
Value uint64
IsNegative bool
}

func GetHitsAddends(request *pb.RateLimitRequest) []HitsAddend {
hitsAddends := make([]HitsAddend, len(request.Descriptors))

for i, descriptor := range request.Descriptors {
if descriptor.HitsAddend != nil {
// If the per descriptor hits_addend is set, use that. It allows to be zero. The zero value is
// means check only by no increment the hits.
hitsAddends[i] = descriptor.HitsAddend.Value
hitsAddends[i].Value = descriptor.HitsAddend.Value
} else {
// If the per descriptor hits_addend is not set, use the request's hits_addend. If the value is
// zero (default value if not specified by the caller), use 1 for backward compatibility.
hitsAddends[i] = uint64(max(1, uint64(request.HitsAddend)))
hitsAddends[i].Value = uint64(max(1, uint64(request.HitsAddend)))
}
hitsAddends[i].IsNegative = descriptor.GetIsNegativeHits()
}
return hitsAddends
}
11 changes: 11 additions & 0 deletions test/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ func NewRateLimitRequestWithPerDescriptorHitsAddend(domain string, descriptors [
return request
}

func NewRateLimitRequestWithNegativeHits(domain string, descriptors [][][2]string,
hitsAddends []uint64, negativeFlags []bool,
) *pb.RateLimitRequest {
request := NewRateLimitRequest(domain, descriptors, 1)
for i, hitsAddend := range hitsAddends {
request.Descriptors[i].HitsAddend = &wrapperspb.UInt64Value{Value: hitsAddend}
request.Descriptors[i].IsNegativeHits = negativeFlags[i]
}
return request
}

func AssertProtoEqual(assert *assert.Assertions, expected proto.Message, actual proto.Message) {
assert.True(proto.Equal(expected, actual),
fmt.Sprintf("These two protobuf messages are not equal:\nexpected: %v\nactual: %v", expected, actual))
Expand Down
Loading
Loading