Skip to content

Commit 88694b4

Browse files
authored
Merge pull request #82 from pattern-tech/refactor/ai-v2
feat: enhance user whitelist management and automatic entry creation
2 parents 6b2d74b + 42e64e5 commit 88694b4

File tree

7 files changed

+120
-28
lines changed

7 files changed

+120
-28
lines changed

src/auth/services/auth_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def register(self, input: RegisterInput, db: Session) -> str:
8181
existing_user = db.query(UserModel).filter_by(
8282
wallet_address=input.wallet_address).first()
8383
if existing_user:
84-
raise AlreadyExistsError("User already exists")
84+
raise AlreadyExistsError("User already exists with this wallet address")
8585

8686
# # Create a new user record
8787
if input.email and input.password:

src/conversation/services/conversation_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ class ConversationService:
3030
def __init__(self):
3131
self.repository = ConversationRepository()
3232
self.project_repository = ProjectRepository()
33+
34+
self.user_service = UserService()
3335
self.memory_service = MemoryService()
3436
self.project_service = ProjectService()
3537
self.query_usage_service = QueryUsageService()
36-
self.user_service = UserService()
3738

3839
def create_conversation(
3940
self, db_session: Session, name: str, project_id: UUID, user_id: UUID, conversation_id: UUID = None

src/query_usage/repositories/query_usage_repository.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,25 @@ def get_user_query_count_for_today(self, db_session: Session, user_id: UUID, pro
140140

141141
return result_count, result_oldest, next_reset
142142

143+
def get_user_whitelist_allowance(self, db_session: Session, user_id: UUID) -> Optional[int]:
144+
"""
145+
Retrieves the user's whitelist max_query allowance directly from the database.
146+
147+
Args:
148+
db_session (Session): The database session to use.
149+
user_id (UUID): The unique identifier of the user.
150+
151+
Returns:
152+
Optional[int]: The user's max query allowance from the whitelist, or None if not whitelisted.
153+
"""
154+
from src.db.models import WhiteList
155+
156+
whitelist_entry = db_session.query(WhiteList.max_query)\
157+
.filter(WhiteList.user_id == user_id)\
158+
.first()
159+
160+
return whitelist_entry.max_query if whitelist_entry else None
161+
143162
def update(self, db_session, id: UUID) -> None:
144163
pass
145164

src/query_usage/routers/query_usage_router.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class TodayQueryUsage(BaseModel):
5959
today_query_count: int = Field(..., example=2)
6060
remaining_query_allowance: int = Field(..., example=3)
6161
max_query_allowance_per_day: int = Field(..., example=5)
62+
# Base allowance from whitelist
63+
whitelist_allowance: int = Field(..., example=20)
64+
# Additional allowance from staked tokens
65+
token_based_allowance: int = Field(..., example=5)
6266
next_reset_time: datetime = Field(..., example="2025-03-15T15:30:20+03:30")
6367

6468

@@ -142,14 +146,23 @@ def get_user_daily_query_usages(
142146
dict: A dictionary containing:
143147
- today_query_count: Number of queries used today
144148
- max_query_allowance_per_day: Total number of allowed queries per day
149+
- whitelist_allowance: Base query allowance from whitelist (default 20)
150+
- token_based_allowance: Additional allowance from staked tokens
145151
- remaining_queries_today: Number of remaining queries for today
146152
- next_reset_time: The datetime when the query usage will reset to zero
147153
"""
148154
try:
149-
# Get the max query allowance for this user
150-
max_query_allowance = service.get_user_max_query_allowance(db, user_id)
155+
# Get the user's base whitelist allowance
156+
whitelist_allowance = service.get_user_whitelist_allowance(db, user_id)
151157

152-
# Get today's query count using the new repository method
158+
# Get the token-based allowance using our new dedicated method
159+
token_based_allowance = service.get_user_token_based_allowance(
160+
db, user_id)
161+
162+
# Calculate total max query allowance (whitelist + token-based)
163+
max_query_allowance = whitelist_allowance + token_based_allowance
164+
165+
# Get today's query count
153166
today_query_count, _, next_reset_time = service.get_user_query_count_for_today(
154167
db, user_id, provider)
155168

src/query_usage/services/query_usage_service.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,45 +77,76 @@ def get_usage_setting(self, db_session: Session) -> List[QueryUsage]:
7777
def get_user_max_query_allowance(self, db_session: Session, user_id: UUID) -> int:
7878
"""
7979
Retrieves the maximum number of queries a user is allowed to make.
80+
This is calculated as the sum of:
81+
1. User's whitelist allowance
82+
2. Allowance based on staked Morpheus tokens
8083
8184
Args:
8285
db_session (Session): The database session.
8386
user_id (UUID): The ID of the user.
8487
8588
Returns:
8689
int: The maximum number of queries the user is allowed to make.
87-
88-
Raises:
89-
Exception: If the user has not staked any Morpheus tokens.
9090
"""
9191
user = self.user_service.get_user(db_session, user_id)
9292

93-
whitelist = self.user_service.get_whitelist(db_session)
93+
# Get the user's base whitelist allowance using the dedicated method
94+
whitelist_allowance = self.get_user_whitelist_allowance(
95+
db_session, user_id)
96+
97+
# Calculate additional allowance based on staked tokens
98+
token_based_allowance = self.get_user_token_based_allowance(
99+
db_session, user_id)
100+
101+
# Total allowance is the sum of whitelist allowance and token-based allowance
102+
max_allowed_query = whitelist_allowance + token_based_allowance
103+
104+
return max_allowed_query
105+
106+
def get_user_token_based_allowance(self, db_session: Session, user_id: UUID) -> int:
107+
"""
108+
Calculates a user's token-based query allowance based on their staked MOR tokens.
94109
95-
# check user payment
96-
for wl in whitelist:
97-
if str(user_id) == str(wl.user_id):
98-
max_allowed_query = wl.max_query
99-
return max_allowed_query
110+
Args:
111+
db_session (Session): The database session.
112+
user_id (UUID): The ID of the user.
100113
114+
Returns:
115+
int: The token-based query allowance (0 if no tokens are staked)
116+
"""
117+
user = self.user_service.get_user(db_session, user_id)
118+
119+
token_based_allowance = 0
101120
staked_morpheus = get_user_staked_tokens(
102121
wallet_address=user.wallet_address.lower(), provider="morpheus")
103122

104-
if staked_morpheus == 0:
105-
raise RateLimitError(
106-
"You need to stake Morpheus tokens to use this service")
123+
if staked_morpheus > 0:
124+
usage_setting = self.get_usage_setting(db_session)
125+
for setting in usage_setting:
126+
if setting.provider == "morpheus":
127+
token_based_allowance = setting.max_query * \
128+
int(int(staked_morpheus) / 1e18)
129+
break
107130

108-
usage_setting = self.get_usage_setting(
109-
db_session)
131+
return token_based_allowance
110132

111-
max_allowed_query = 0
112-
for setting in usage_setting:
113-
if setting.provider == "morpheus":
114-
max_allowed_query = setting.max_query * \
115-
int(int(staked_morpheus) / 1e18)
116-
break
133+
def get_user_whitelist_allowance(self, db_session: Session, user_id: UUID) -> int:
134+
"""
135+
Retrieves the base query allowance for a user from the whitelist using a direct database query.
117136
118-
return max_allowed_query
137+
Args:
138+
db_session (Session): The database session.
139+
user_id (UUID): The ID of the user.
140+
141+
Returns:
142+
int: The base query allowance from the whitelist (default: 0 if not whitelisted)
143+
"""
144+
# Use the repository to directly query the database for whitelist allowance
145+
allowance = self.repository.get_user_whitelist_allowance(
146+
db_session, user_id)
147+
148+
# Return the allowance if found, otherwise return 0
149+
return allowance if allowance is not None else 0
119150

120151
def get_user_query_count_for_today(self, db_session: Session, user_id: UUID, provider: Optional[str] = None) -> Tuple[int, Optional[datetime]]:
121152
"""

src/user/repositories/user_repository.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,21 @@ def get_whitelist(self, db_session: Session) -> list[WhiteList]:
122122
List[WhiteList]: A list of all whitelisted users.
123123
"""
124124
return db_session.query(WhiteList).all()
125+
126+
def create_whitelist_entry(self, db_session: Session, user_id: UUID, max_query: int = 20) -> WhiteList:
127+
"""
128+
Creates a whitelist entry for a user with a default max query limit of 20.
129+
130+
Args:
131+
db_session (Session): The database session to use.
132+
user_id (UUID): The unique identifier of the user to whitelist.
133+
max_query (int): The maximum number of queries allowed for this user.
134+
135+
Returns:
136+
WhiteList: The created whitelist entry.
137+
"""
138+
whitelist_entry = WhiteList(user_id=user_id, max_query=max_query)
139+
db_session.add(whitelist_entry)
140+
db_session.commit()
141+
db_session.refresh(whitelist_entry)
142+
return whitelist_entry

src/user/services/user_service.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import List
33
from sqlalchemy.orm import Session
44

5-
from src.db.models import UserModel, WhiteList
5+
from src.db.models import UserModel, WhiteList, UsageSetting
66
from src.user.repositories.user_repository import UserRepository
77

88

@@ -13,7 +13,7 @@ def __init__(self):
1313
def create_user(self, db_session: Session, wallet_address: str, chain_id: int, email: str = None, password: str = None,
1414
) -> UserModel:
1515
"""
16-
Creates a new user.
16+
Creates a new user and automatically whitelists them with the default max query value from database.
1717
1818
Args:
1919
db_session (Session): The database session.
@@ -31,6 +31,16 @@ def create_user(self, db_session: Session, wallet_address: str, chain_id: int, e
3131
chain_id=chain_id, email=email, password=password)
3232
user = self.repository.create(db_session, _user)
3333

34+
default_max_query = 0
35+
usage_settings = db_session.query(UsageSetting).filter(
36+
UsageSetting.provider == "pattern").first()
37+
if usage_settings:
38+
default_max_query = usage_settings.max_query
39+
40+
# Automatically whitelist the user with the default query credits from database
41+
self.repository.create_whitelist_entry(
42+
db_session, user.id, max_query=default_max_query)
43+
3444
return user
3545

3646
def get_user(self, db_session: Session, user_id: UUID) -> UserModel:

0 commit comments

Comments
 (0)