Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/auth.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Optional
from fastapi import Header, HTTPException, Depends
from sqlalchemy.orm.exc import NoResultFound
from rrl import RateLimiter, Tier, RateLimitExceeded
from .db import SessionLocal, get_db, models
from .rate_limiter import V3RateLimiter, Tier, RateLimitExceeded

limiter = RateLimiter(
limiter = V3RateLimiter(
prefix="v3",
tiers=[
Tier("default", 10, 0, 250),
Expand Down Expand Up @@ -37,7 +37,7 @@ def apikey_auth(
.one()
)
try:
limiter.check_limit(provided_apikey, key.api_tier)
limiter.check_limit_and_increment_counters(provided_apikey, key.api_tier)
except RateLimitExceeded as e:
raise HTTPException(429, detail=str(e))
except ValueError:
Expand Down
153 changes: 153 additions & 0 deletions api/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import os
import datetime
import typing
from dataclasses import dataclass
from redis import Redis


@dataclass
class Tier:
name: str
per_minute: int
per_hour: int
per_day: int


@dataclass
class DailyUsage:
date: datetime.date
calls: int


class RateLimitExceeded(Exception):
pass


def _get_redis_connection() -> Redis:
host = os.environ.get("RRL_REDIS_HOST", "localhost")
port = int(os.environ.get("RRL_REDIS_PORT", 6379))
db = int(os.environ.get("RRL_REDIS_DB", 0))
return Redis(host=host, port=port, db=db)


class V3RateLimiter:
"""
<prefix>:<key>:<hour><minute> expires in 2 minutes
<prefix>:<key>:<hour> expires in 2 hours
<prefix>:<key>:<day> never expires
"""

def __init__(
self,
tiers: typing.List[Tier],
*,
prefix: str = "",
use_redis_time: bool = True,
track_daily_usage: bool = True,
):
self.redis = _get_redis_connection()
self.tiers = {tier.name: tier for tier in tiers}
self.prefix = prefix
self.use_redis_time = use_redis_time
self.track_daily_usage = track_daily_usage

def check_limit_and_increment_counters(self, key: str, tier_name: str) -> bool:
try:
tier = self.tiers[tier_name]
except KeyError:
raise ValueError(f"unknown tier: {tier_name}")
if self.use_redis_time:
timestamp = self.redis.time()[0]
now = datetime.datetime.fromtimestamp(timestamp)
else:
now = datetime.datetime.utcnow()

# check AND increment usage counters
pipe = self.redis.pipeline()
day = now.strftime("%Y%m%d")
day_key = f"{self.prefix}:{key}:d{day}"
day_requests_key = f"{self.prefix}:{key}:dr{day}"
if tier.per_minute:
minute_key = f"{self.prefix}:{key}:m{now.minute}"
pipe.incr(minute_key)
pipe.expire(minute_key, 60)
if tier.per_hour:
hour_key = f"{self.prefix}:{key}:h{now.hour}"
pipe.incr(hour_key)
pipe.expire(hour_key, 3600)
if tier.per_day or self.track_daily_usage:
# Keep separate day and day-requests keys
# day key: used for aggregate usage tracking, so we want to limit this to
# track allowed requests the user has used
# day-requests key: tracking how many TOTAL (incl blocked) requests made
pipe.incr(day_key)
pipe.incr(day_requests_key)
# keep data around for usage tracking
if not self.track_daily_usage:
pipe.expire(day_key, 86400)
pipe.expire(day_requests_key, 86400)
result = pipe.execute()

# parse redis pipeline results
# the result is pairs of results of incr and expire calls, so if all 3 limits are set
# it looks like [per_minute_calls, True, per_hour_calls, True, per_day_allowed_calls, per_day_raw_calls]
# we increment value_pos as we consume values so we know which location we're looking at
value_pos = 0
minute_calls = hour_calls = day_calls = 0
minute_exceeded = hour_exceeded = day_exceeded = False
if tier.per_minute:
minute_calls = result[value_pos]
if result[value_pos] > tier.per_minute:
minute_exceeded = True
value_pos += 2
if tier.per_hour:
hour_calls = result[value_pos]
if result[value_pos] > tier.per_hour:
hour_exceeded = True
value_pos += 2
if tier.per_day:
# report back the # of raw requests, not just allowed requests
day_calls = result[value_pos + 1]
if result[value_pos] > tier.per_day:
day_exceeded = True
# daily usage numbers are used to report overall usage
# so actually want to decrement back to the prior value
# otherwise the usage count for the day will include all *blocked* requests
self.redis.decr(day_key)

# Raise appropriate exception if limit exceeded
if minute_exceeded:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_minute}/min: {minute_calls}"
)
if hour_exceeded:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_hour}/hour: {hour_calls}"
)
if day_exceeded:
raise RateLimitExceeded(
f"exceeded limit of {tier.per_day}/day: {day_calls}"
)

return True

def get_usage_since(
self,
key: str,
start: datetime.date,
end: typing.Optional[datetime.date] = None,
) -> typing.List[DailyUsage]:
if not self.track_daily_usage:
raise RuntimeError("track_daily_usage is not enabled")
if not end:
end = datetime.date.today()
days = []
day = start
while day <= end:
days.append(day)
day += datetime.timedelta(days=1)
day_keys = [f"{self.prefix}:{key}:d{day.strftime('%Y%m%d')}" for day in days]
return [
DailyUsage(d, int(calls.decode()) if calls else 0)
for d, calls in zip(days, self.redis.mget(day_keys))
]
55 changes: 34 additions & 21 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "openstates-api"
version = "3.0.0"
version = "3.0.1 "
description = "Open States API v3"
authors = ["James Turk <dev@jamesturk.net>"]
license = "MIT"
Expand All @@ -15,7 +15,7 @@ gunicorn = "^20.0.4"
sentry-sdk = "^1.0.0"
pybase62 = "^0.4.3"
python-slugify = "^4.0.1"
rrl = "^0.3.1"
redis = "^6.4.0"
prometheus-fastapi-instrumentator = "^5.8.2"
fastapi = {extras = ["all"], version = "^0.87.0"}

Expand Down