Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/backend/clara/agents/design_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ async def pre_tool_hook(
context: Any
) -> dict[str, Any]:
"""Track tool usage before execution."""
import json
# Log full input_data to understand structure
logger.info(f"[PreToolUse] input_data keys: {list(input_data.keys())}")
logger.info(f"[PreToolUse] full input_data: {input_data}")
Expand Down Expand Up @@ -672,9 +671,10 @@ async def restore_session(
]
session.state.blueprint_preview.agent_count = len(blueprint_state.get("agents", []))

# Restore goal summary
# Restore goal summary (handle both goal_text and primary_goal keys)
if db_session.goal_summary:
session.state.goal_summary = db_session.goal_summary.get("goal_text")
goal = db_session.goal_summary
session.state.goal_summary = goal.get("goal_text") or goal.get("primary_goal")

# Restore agent capabilities
if db_session.agent_capabilities:
Expand Down
111 changes: 32 additions & 79 deletions src/backend/clara/api/context_files.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Context Files API endpoints.

Provides file upload/download/delete endpoints for agent context files.
Files are sandboxed per project and validated for security.
Files are linked to InterviewAgent (canonical source of truth) and sandboxed per project.
"""

import logging
Expand All @@ -13,7 +13,7 @@
from sqlalchemy.ext.asyncio import AsyncSession

from clara.config import settings
from clara.db.models import AgentContextFile, ContextFileStatus, DesignSession
from clara.db.models import AgentContextFile, ContextFileStatus, InterviewAgent
from clara.db.session import get_db
from clara.services.file_service import FileUploadService

Expand Down Expand Up @@ -49,10 +49,9 @@ class UploadResponse(BaseModel):
error: str | None = None


@router.post("/sessions/{session_id}/agents/{agent_index}/upload", response_model=UploadResponse)
@router.post("/agents/{agent_id}/upload", response_model=UploadResponse)
async def upload_context_file(
session_id: str,
agent_index: int,
agent_id: str,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db)
) -> UploadResponse:
Expand All @@ -65,29 +64,20 @@ async def upload_context_file(

The file content is extracted for use in agent context.
"""
# Verify session exists and get project_id
# Verify agent exists and get project_id
result = await db.execute(
select(DesignSession).where(DesignSession.id == session_id)
select(InterviewAgent).where(InterviewAgent.id == agent_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
agent = result.scalar_one_or_none()
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")

project_id = session.project_id

# Check agent index is valid
agents = session.blueprint_state.get("agents", [])
if agent_index < 0 or agent_index >= len(agents):
raise HTTPException(
status_code=400,
detail=f"Invalid agent index {agent_index}. Session has {len(agents)} agents."
)
project_id = agent.project_id

# Check file count limit
count_result = await db.execute(
select(func.count(AgentContextFile.id))
.where(AgentContextFile.session_id == session_id)
.where(AgentContextFile.agent_index == agent_index)
.where(AgentContextFile.agent_id == agent_id)
.where(AgentContextFile.deleted_at.is_(None))
)
current_count = count_result.scalar() or 0
Expand All @@ -109,7 +99,7 @@ async def upload_context_file(
file_content=content,
filename=file.filename or "unnamed",
project_id=project_id,
agent_index=agent_index
agent_index=0 # Not used for path, keeping for compatibility
)

if not upload_result.success:
Expand All @@ -120,9 +110,7 @@ async def upload_context_file(
ext = "." + filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
context_file = AgentContextFile(
id=upload_result.file_id,
session_id=session_id,
project_id=project_id,
agent_index=agent_index,
agent_id=agent_id,
original_filename=file.filename or "unnamed",
stored_filename=upload_result.stored_filename,
file_extension=ext,
Expand All @@ -137,23 +125,6 @@ async def upload_context_file(
db.add(context_file)
await db.commit()

# Also update the agent's context_files in blueprint_state
agents = list(session.blueprint_state.get("agents", []))
if agent_index < len(agents):
agent = agents[agent_index]
context_files = agent.get("context_files", [])
context_files.append({
"id": upload_result.file_id,
"name": file.filename or "unnamed",
"type": upload_result.mime_type,
"size": upload_result.file_size,
"uploaded_at": datetime.now(UTC).isoformat(),
})
agent["context_files"] = context_files
agents[agent_index] = agent
session.blueprint_state = {**session.blueprint_state, "agents": agents}
await db.commit()

return UploadResponse(
success=True,
file=ContextFileResponse(
Expand All @@ -168,26 +139,24 @@ async def upload_context_file(
)


@router.get("/sessions/{session_id}/agents/{agent_index}", response_model=ContextFileListResponse)
@router.get("/agents/{agent_id}", response_model=ContextFileListResponse)
async def list_context_files(
session_id: str,
agent_index: int,
agent_id: str,
db: AsyncSession = Depends(get_db)
) -> ContextFileListResponse:
"""List all context files for an agent."""
# Verify session exists
# Verify agent exists
result = await db.execute(
select(DesignSession).where(DesignSession.id == session_id)
select(InterviewAgent).where(InterviewAgent.id == agent_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
agent = result.scalar_one_or_none()
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")

# Get files
result = await db.execute(
select(AgentContextFile)
.where(AgentContextFile.session_id == session_id)
.where(AgentContextFile.agent_index == agent_index)
.where(AgentContextFile.agent_id == agent_id)
.where(AgentContextFile.deleted_at.is_(None))
.order_by(AgentContextFile.created_at.desc())
)
Expand All @@ -210,28 +179,26 @@ async def list_context_files(
)


@router.delete("/sessions/{session_id}/agents/{agent_index}/files/{file_id}")
@router.delete("/agents/{agent_id}/files/{file_id}")
async def delete_context_file(
session_id: str,
agent_index: int,
agent_id: str,
file_id: str,
db: AsyncSession = Depends(get_db)
):
"""Delete a context file (soft delete)."""
# Verify session exists
# Verify agent exists
result = await db.execute(
select(DesignSession).where(DesignSession.id == session_id)
select(InterviewAgent).where(InterviewAgent.id == agent_id)
)
session = result.scalar_one_or_none()
if not session:
raise HTTPException(status_code=404, detail="Session not found")
agent = result.scalar_one_or_none()
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")

# Find file
result = await db.execute(
select(AgentContextFile)
.where(AgentContextFile.id == file_id)
.where(AgentContextFile.session_id == session_id)
.where(AgentContextFile.agent_index == agent_index)
.where(AgentContextFile.agent_id == agent_id)
)
context_file = result.scalar_one_or_none()
if not context_file:
Expand All @@ -242,27 +209,14 @@ async def delete_context_file(
context_file.status = ContextFileStatus.FAILED.value
context_file.status_message = "Deleted by user"

# Also remove from blueprint_state
agents = list(session.blueprint_state.get("agents", []))
if agent_index < len(agents):
agent = agents[agent_index]
context_files = [
f for f in agent.get("context_files", [])
if f.get("id") != file_id
]
agent["context_files"] = context_files
agents[agent_index] = agent
session.blueprint_state = {**session.blueprint_state, "agents": agents}

await db.commit()

return {"status": "deleted", "file_id": file_id}


@router.get("/sessions/{session_id}/agents/{agent_index}/files/{file_id}/content")
@router.get("/agents/{agent_id}/files/{file_id}/content")
async def get_extracted_content(
session_id: str,
agent_index: int,
agent_id: str,
file_id: str,
db: AsyncSession = Depends(get_db)
):
Expand All @@ -271,8 +225,7 @@ async def get_extracted_content(
result = await db.execute(
select(AgentContextFile)
.where(AgentContextFile.id == file_id)
.where(AgentContextFile.session_id == session_id)
.where(AgentContextFile.agent_index == agent_index)
.where(AgentContextFile.agent_id == agent_id)
.where(AgentContextFile.deleted_at.is_(None))
)
context_file = result.scalar_one_or_none()
Expand Down
Loading