Skip to content

Commit 4c83d48

Browse files
authored
Merge pull request #122 from ComplexData-MILA/main
Adding Main Changes
2 parents c17f83a + f86e890 commit 4c83d48

6 files changed

Lines changed: 68 additions & 7 deletions

File tree

app/api/endpoints/claim_endpoints.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from app.services.analysis_orchestrator import AnalysisOrchestrator
2626
from app.core.exceptions import NotFoundException, NotAuthorizedException
2727
from app.services.interfaces.embedding_generator import EmbeddingGeneratorInterface
28+
from app.core.exceptions import MonthlyLimitExceededError
2829

2930
router = APIRouter(prefix="/claims", tags=["claims"])
3031
logger = logging.getLogger(__name__)
@@ -45,8 +46,12 @@ async def create_claim(
4546
language=data.language,
4647
batch_user_id=data.batch_user_id,
4748
batch_post_id=data.batch_post_id,
49+
auth0_id=current_user.auth0_id,
4850
)
4951
return ClaimRead.model_validate(claim)
52+
except MonthlyLimitExceededError:
53+
# We don't have 'e.limit' anymore, so we just say "Limit reached"
54+
raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.")
5055
except Exception as e:
5156
raise HTTPException(
5257
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create claim: {str(e)}"
@@ -65,12 +70,19 @@ async def create_claims_batch(
6570
raise HTTPException(status_code=400, detail="Maximum of 100 claims allowed.")
6671

6772
try:
68-
created_claims = await claim_service.create_claims_batch(claims, current_user.id)
73+
created_claims = await claim_service.create_claims_batch(
74+
claims,
75+
current_user.id,
76+
auth0_id=current_user.auth0_id,
77+
)
6978
claim_ids = [str(claim.id) for claim in created_claims]
7079
background_tasks.add_task(
7180
claim_service.process_claims_batch_async, created_claims, current_user.id, analysis_orchestrator
7281
)
7382
return {"message": f"Processing {len(created_claims)} claims in the background.", "claim_ids": claim_ids}
83+
except MonthlyLimitExceededError:
84+
# We don't have 'e.limit' anymore, so we just say "Limit reached"
85+
raise HTTPException(status_code=429, detail="You have reached your monthly claim limit.")
7486
except Exception as e:
7587
raise HTTPException(
7688
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to queue batch: {str(e)}"

app/core/exceptions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ class DuplicateUserError(Exception):
3838
pass
3939

4040

41+
"""
42+
Claim exceptions
43+
"""
44+
45+
46+
class MonthlyLimitExceededError(Exception):
47+
"""Raised when a user hits their monthly claim limit."""
48+
49+
pass
50+
51+
4152
"""
4253
Feedback exceptions
4354
"""

app/models/domain/analysis.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from datetime import datetime
3-
from typing import Optional, List, Dict
3+
from typing import Optional, List
44
from uuid import UUID
55
import pickle
66

@@ -11,10 +11,8 @@
1111

1212
@dataclass
1313
class LogProbsData:
14-
anth_conf_score: float
1514
tokens: List[str]
1615
probs: List[float]
17-
alternatives: List[Dict[str, float]]
1816

1917

2018
@dataclass

app/repositories/implementations/claim_repository.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from uuid import UUID
44
from sqlalchemy import select, func, and_
55
from sqlalchemy.ext.asyncio import AsyncSession
6-
from datetime import datetime
6+
from datetime import datetime, UTC
77

88
from app.models.database.models import ClaimModel, ClaimStatus
99
from app.models.domain.claim import Claim
@@ -94,6 +94,22 @@ async def get_claims_in_date_range(self, start_date: datetime, end_date: datetim
9494
result = await self._session.execute(stmt)
9595
return [self._to_domain(claim) for claim in result.scalars().all()]
9696

97+
async def get_monthly_claim_count(self, user_id: str) -> int:
98+
"""Counts how many claims a user has created this month."""
99+
100+
# Calculate the 1st of the current month
101+
now = datetime.now(UTC)
102+
start_of_month = now.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
103+
104+
query = (
105+
select(func.count())
106+
.select_from(ClaimModel)
107+
.where(ClaimModel.user_id == user_id, ClaimModel.created_at >= start_of_month)
108+
)
109+
110+
result = await self._session.execute(query)
111+
return result.scalar_one()
112+
97113
async def insert_many(self, claim: List[Claim]) -> List[Claim]:
98114
models = [self._to_model(claim) for claim in claim]
99115
self._session.add_all(models)

app/services/analysis_orchestrator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,8 @@ async def _generate_analysis(
248248
logger.info(con_score)
249249
current_analysis.confidence_score = float(con_score)
250250
# log_data = await self._get_anth_confidence_score(statement=claim_text, veracity_score=veracity_score)
251-
# current_analysis.log_probs = log_data
251+
log_data = LogProbsData(tokens=analysis_text, probs=log_probs)
252+
current_analysis.log_probs = log_data
252253

253254
updated_analysis = await self._analysis_repo.update(current_analysis)
254255

app/services/claim_service.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from app.repositories.implementations.claim_repository import ClaimRepository
1818
from app.repositories.implementations.analysis_repository import AnalysisRepository
1919
from app.services.analysis_orchestrator import AnalysisOrchestrator
20+
from app.core.exceptions import MonthlyLimitExceededError
2021

2122
from app.core.exceptions import NotFoundException, NotAuthorizedException
2223

@@ -31,6 +32,9 @@
3132
logger = logging.getLogger(__name__)
3233
executor = ThreadPoolExecutor(max_workers=1)
3334

35+
RESTRICTED_CLIENT_ID = "hHRhJr5OoJhWumP87MHk5RldejycVAmC@clients"
36+
MONTHLY_LIMIT = 3000
37+
3438

3539
class ClaimService:
3640
def __init__(self, claim_repository: ClaimRepository, analysis_repository: AnalysisRepository):
@@ -45,9 +49,17 @@ async def create_claim(
4549
language: str,
4650
batch_user_id: str = None,
4751
batch_post_id: str = None,
52+
auth0_id: str = None,
4853
) -> Claim:
4954
"""Create a new claim."""
5055
now = datetime.now(UTC)
56+
57+
if auth0_id is not None:
58+
if auth0_id == RESTRICTED_CLIENT_ID:
59+
current_count = await self._claim_repo.get_monthly_claim_count(user_id)
60+
61+
if current_count >= MONTHLY_LIMIT:
62+
raise MonthlyLimitExceededError()
5163
claim = Claim(
5264
id=uuid4(),
5365
user_id=user_id,
@@ -234,9 +246,20 @@ def _heavy_clustering_math(claims, num_clusters):
234246
result = await loop.run_in_executor(executor, _heavy_clustering_math, claims, num_clusters)
235247
return result
236248

237-
async def create_claims_batch(self, claims: List[Claim], user_id: str) -> List[Claim]:
249+
async def create_claims_batch(
250+
self,
251+
claims: List[Claim],
252+
user_id: str,
253+
auth0_id: str = None,
254+
) -> List[Claim]:
238255
# Map ClaimCreate + user_id → Claim DB objects
239256
now = datetime.now(UTC)
257+
if auth0_id is not None:
258+
if auth0_id == RESTRICTED_CLIENT_ID:
259+
current_count = await self._claim_repo.get_monthly_claim_count(user_id)
260+
261+
if current_count + len(claims) >= MONTHLY_LIMIT:
262+
raise MonthlyLimitExceededError()
240263
claim_models = [
241264
Claim(
242265
id=uuid4(),

0 commit comments

Comments
 (0)