Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions backend/app/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

Expand Down Expand Up @@ -54,9 +53,8 @@
)

from sqlalchemy import create_engine

# sync engine to populate database
SYNC_DATABASE_URL = DATABASE_URL.replace("+asyncpg", "")
#sync engine to populate database
SYNC_DATABASE_URL = DATABASE_URL.replace('+asyncpg','')
sync_engine = create_engine(
SYNC_DATABASE_URL,
echo=DB_ECHO,
Expand Down
8 changes: 6 additions & 2 deletions backend/app/scripts/reset_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os # <--- 1. Was missing
from dotenv import load_dotenv
from sqlalchemy.ext.asyncio import create_async_engine

from sqlalchemy import text
# 3. Correct Import: 'Bet', not 'Bets'
from ..db.models import Base, Tournament, Agent, AgentState, Trade, Bet, PlanItem, AgentResearchArtifact

Expand All @@ -29,8 +29,12 @@ async def reset_database():

async with engine.begin() as conn:
print("🔥 Dropping all tables...")
await conn.run_sync(Base.metadata.drop_all)
await conn.execute(text("DROP SCHEMA public CASCADE"))
await conn.execute(text("CREATE SCHEMA public"))

await conn.execute(text("GRANT ALL ON SCHEMA public TO neondb_owner"))
await conn.execute(text("GRANT ALL ON SCHEMA public TO public"))

print("🏗️ Creating new tables...")
await conn.run_sync(Base.metadata.create_all)

Expand Down
93 changes: 60 additions & 33 deletions backend/app/scripts/seed_db.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
# backend/app/scripts/seed_db.py
import asyncio
import os
from uuid import uuid4
from datetime import datetime, timezone, timedelta
from uuid import UUID, uuid4
from datetime import datetime, timedelta, timezone
from decimal import Decimal

from dotenv import load_dotenv
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

from ..db.models import Tournament, Agent, Trade, Bet, StatusEnum, ActionEnum
from ..db.database import sync_engine, engine
from ..db.models import Tournament, Agent, Trade, Bet, StatusEnum, ActionEnum, AgentState

load_dotenv()

Expand All @@ -20,22 +15,16 @@
async def seed_database():
"""Seed the database with test data"""

connect_args = {} if DB_DISABLE_SSL else {"ssl": "require"}
engine = create_async_engine(DATABASE_URL, connect_args=connect_args)
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

async with async_session() as session:
now = datetime.now(timezone.utc)

with Session(sync_engine) as session:
# Create Tournaments
tournament1 = Tournament(
id=uuid4(),
name="Crypto Trading Championship #1",
name="Q4 2025 Championship",
status=StatusEnum.live,
start_date=now,
end_date=now + timedelta(days=30),
prize_pool=Decimal("10000.00"),
created_at=now,
created_at=datetime.now(timezone.utc),
)

tournament2 = Tournament(
Expand All @@ -45,7 +34,7 @@ async def seed_database():
start_date=now + timedelta(days=7),
end_date=now + timedelta(days=37),
prize_pool=Decimal("5000.00"),
created_at=now,
created_at=datetime.now(timezone.utc),
)

session.add(tournament1)
Expand All @@ -58,9 +47,9 @@ async def seed_database():
personality="aggressive",
strategy_type="momentum",
avatar_url="https://example.com/avatar1.png",
stats={},
memory={},
created_at=now,
stats={"win_rate": 0.65, "total_trades": 150},
memory={"last_analysis": "Bullish on tech stocks"},
created_at=datetime.now(timezone.utc),
)

agent2 = Agent(
Expand All @@ -69,9 +58,9 @@ async def seed_database():
personality="conservative",
strategy_type="value",
avatar_url="https://example.com/avatar2.png",
stats={},
memory={},
created_at=now,
stats={"win_rate": 0.58, "total_trades": 200},
memory={"last_analysis": "Focus on fundamentals"},
created_at=datetime.now(timezone.utc),
)

agent3 = Agent(
Expand All @@ -80,9 +69,9 @@ async def seed_database():
personality="balanced",
strategy_type="quantitative",
avatar_url="https://example.com/avatar3.png",
stats={},
memory={},
created_at=now,
stats={"win_rate": 0.72, "total_trades": 500},
memory={"last_analysis": "Pattern detected in BTC"},
created_at=datetime.now(timezone.utc),
)

session.add(agent1)
Expand Down Expand Up @@ -129,6 +118,44 @@ async def seed_database():
session.add(agent_state3)
session.commit()

agent_state1 = AgentState(
agent_id=agent1.id,
tournament_id=tournament1.id,
portfolio={"USD": 10000.0}, # Starting cash
portfolio_value_usd=Decimal("10000.00"),
rank=1,
trades_count=0,
last_decision="Initial state",
updated_at=datetime.now(timezone.utc),
)

agent_state2 = AgentState(
agent_id=agent2.id,
tournament_id=tournament1.id,
portfolio={"USD": 10000.0},
portfolio_value_usd=Decimal("10000.00"),
rank=2,
trades_count=0,
last_decision="Initial state",
updated_at=datetime.now(timezone.utc),
)

agent_state3 = AgentState(
agent_id=agent3.id,
tournament_id=tournament1.id,
portfolio={"USD": 10000.0},
portfolio_value_usd=Decimal("10000.00"),
rank=3,
trades_count=0,
last_decision="Initial state",
updated_at=datetime.now(timezone.utc),
)

session.add(agent_state1)
session.add(agent_state2)
session.add(agent_state3)
session.commit()

# Create Trades
trade1 = Trade(
id=uuid4(),
Expand All @@ -138,7 +165,7 @@ async def seed_database():
asset="BTC",
amount=Decimal("0.5"),
price=Decimal("45000.00"),
timestamp=now,
timestamp=datetime.now(timezone.utc),
)

trade2 = Trade(
Expand All @@ -149,7 +176,7 @@ async def seed_database():
asset="ETH",
amount=Decimal("5.0"),
price=Decimal("3000.00"),
timestamp=now,
timestamp=datetime.now(timezone.utc),
)

trade3 = Trade(
Expand All @@ -160,7 +187,7 @@ async def seed_database():
asset="BTC",
amount=Decimal("0.25"),
price=Decimal("46000.00"),
timestamp=now,
timestamp=datetime.now(timezone.utc),
)

session.add(trade1)
Expand All @@ -175,7 +202,7 @@ async def seed_database():
tournament_id=tournament1.id,
amount=Decimal("100.00"),
odds=Decimal("2.5"),
placed_at=now,
placed_at=datetime.now(timezone.utc),
settled=False,
)

Expand All @@ -186,7 +213,7 @@ async def seed_database():
tournament_id=tournament1.id,
amount=Decimal("250.00"),
odds=Decimal("3.0"),
placed_at=now,
placed_at=datetime.now(timezone.utc),
settled=False,
)

Expand Down
Loading