Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 51 additions & 47 deletions assistants/knowledge-transfer-assistant/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asyncio
import pathlib
from enum import Enum
from typing import Any

from assistant_extensions import attachments, dashboard_card, navigator
Expand Down Expand Up @@ -34,7 +33,7 @@
load_text_include,
)

from .common import detect_assistant_role
from .common import detect_assistant_role, detect_conversation_type, get_shared_conversation_id, ConversationType
from .config import assistant_config
from .conversation_share_link import ConversationKnowledgePackageManager
from .data import InspectorTab, LogEntryType
Expand Down Expand Up @@ -98,13 +97,6 @@ async def content_evaluator_factory(

app = assistant.fastapi_app()


class ConversationType(Enum):
COORDINATOR = "coordinator"
TEAM = "team"
SHAREABLE_TEMPLATE = "shareable_template"


@assistant.events.conversation.on_created_including_mine
async def on_conversation_created(context: ConversationContext) -> None:
"""
Expand All @@ -113,50 +105,26 @@ async def on_conversation_created(context: ConversationContext) -> None:
2. Shareable Team Conversation: A template conversation that has a share URL and is never directly used
3. Team Conversation(s): Individual conversations for team members created when they redeem the share URL
"""
# Get conversation to access metadata

conversation = await context.get_conversation()
conversation_metadata = conversation.metadata or {}
share_id = conversation_metadata.get("share_id")

config = await assistant_config.get(context.assistant)
conversation_type = detect_conversation_type(conversation)

##
## Figure out what type of conversation this is.
##

conversation_type = ConversationType.COORDINATOR

# Coordinator conversations will not have a share_id or
# is_team_conversation flag in the metadata. So, if they are there, we just
# need to decide if it's a shareable template or a team conversation.
share_id = conversation_metadata.get("share_id")
if conversation_metadata.get("is_team_conversation", False) and share_id:
# If this conversation was imported from another, it indicates it's from
# share redemption.
if conversation.imported_from_conversation_id:
conversation_type = ConversationType.TEAM
# TODO: This might work better for detecting a redeemed link, but
# hasn't been validated.

# if conversation_metadata.get("share_redemption") and conversation_metadata.get("share_redemption").get(
# "conversation_share_id"
# ):
# conversation_type = ConversationType.TEAM
else:
conversation_type = ConversationType.SHAREABLE_TEMPLATE

##
## Handle the conversation based on its type
##
match conversation_type:
case ConversationType.SHAREABLE_TEMPLATE:

# Associate the shareable template with a share ID
if not share_id:
logger.error("No share ID found for shareable team conversation.")
return

await ConversationKnowledgePackageManager.associate_conversation_with_share(context, share_id)
return

case ConversationType.TEAM:

if not share_id:
logger.error("No share ID found for team conversation.")
return
Expand All @@ -170,13 +138,9 @@ async def on_conversation_created(context: ConversationContext) -> None:
)

await ConversationKnowledgePackageManager.associate_conversation_with_share(context, share_id)
# Set the conversation role for team conversations
await ConversationKnowledgePackageManager.set_conversation_role(context, share_id, ConversationRole.TEAM)

# Synchronize files.
await ShareManager.synchronize_files_to_team_conversation(context=context, share_id=share_id)

# Generate a welcome message.
welcome_message, debug = await generate_team_welcome_message(context)
await context.send_messages(
NewConversationMessage(
Expand All @@ -202,11 +166,10 @@ async def on_conversation_created(context: ConversationContext) -> None:

case ConversationType.COORDINATOR:
try:
# In the beginning, we created a share...
share_id = await KnowledgeTransferManager.create_share(context)

# No default brief - let the state inspector handle displaying instructional content

# Create a team conversation with a share URL
# And it was good. So we then created a sharable conversation that we use as a template.
share_url = await KnowledgeTransferManager.create_shareable_team_conversation(
context=context, share_id=share_id
)
Expand All @@ -218,14 +181,52 @@ async def on_conversation_created(context: ConversationContext) -> None:
except Exception as e:
welcome_message = f"I'm having trouble setting up your knowledge transfer. Please try again or contact support if the issue persists. {str(e)}"

# Send the welcome message
await context.send_messages(
NewConversationMessage(
content=welcome_message,
message_type=MessageType.chat,
)
)

# Pop open the inspector panel.
await context.send_conversation_state_event(
AssistantStateEvent(
state_id="brief",
event="focus",
state=None,
)
)

@assistant.events.conversation.on_updated
async def on_conversation_updated(context: ConversationContext) -> None:
"""
Handle conversation updates (including title changes) and sync with shareable template.
"""
try:
conversation = await context.get_conversation()
conversation_type = detect_conversation_type(conversation)
if conversation_type != ConversationType.COORDINATOR:
return

shared_conversation_id = await get_shared_conversation_id(context)
if not shared_conversation_id:
return

# Update the shareable template conversation's title if needed.
try:
target_context = context.for_conversation(shared_conversation_id)
target_conversation = await target_context.get_conversation()
if target_conversation.title != conversation.title:
await target_context.update_conversation_title(conversation.title)
logger.debug(f"Updated conversation {shared_conversation_id} title from '{target_conversation.title}' to '{conversation.title}'")
else:
logger.debug(f"Conversation {shared_conversation_id} title already matches: '{conversation.title}'")
except Exception as title_update_error:
logger.error(f"Error updating conversation {shared_conversation_id} title: {title_update_error}")

except Exception as e:
logger.error(f"Error syncing conversation title: {e}")


@assistant.events.conversation.message.chat.on_created
async def on_message_created(
Expand Down Expand Up @@ -545,3 +546,6 @@ async def on_participant_joined(

except Exception as e:
logger.exception(f"Error handling participant join event: {e}")



57 changes: 57 additions & 0 deletions assistants/knowledge-transfer-assistant/assistant/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
helping to reduce code duplication and maintain consistency.
"""

from enum import Enum
from typing import Dict, Optional

from semantic_workbench_assistant.assistant_app import ConversationContext
Expand All @@ -14,7 +15,35 @@
from .logging import logger
from .storage import ShareStorage
from .storage_models import ConversationRole
from semantic_workbench_api_model.workbench_model import Conversation

class ConversationType(Enum):
COORDINATOR = "coordinator"
TEAM = "team"
SHAREABLE_TEMPLATE = "shareable_template"

def detect_conversation_type(conversation: Conversation) -> ConversationType:
conversation_metadata = conversation.metadata or {}
conversation_type = ConversationType.COORDINATOR
# Coordinator conversations will not have a share_id or
# is_team_conversation flag in the metadata. So, if they are there, we just
# need to decide if it's a shareable template or a team conversation.
share_id = conversation_metadata.get("share_id")
if conversation_metadata.get("is_team_conversation", False) and share_id:
# If this conversation was imported from another, it indicates it's from
# share redemption.
if conversation.imported_from_conversation_id:
conversation_type = ConversationType.TEAM
# TODO: This might work better for detecting a redeemed link, but
# hasn't been validated.

# if conversation_metadata.get("share_redemption") and conversation_metadata.get("share_redemption").get(
# "conversation_share_id"
# ):
# conversation_type = ConversationType.TEAM
else:
conversation_type = ConversationType.SHAREABLE_TEMPLATE
return conversation_type

async def detect_assistant_role(context: ConversationContext) -> ConversationRole:
"""
Expand Down Expand Up @@ -45,6 +74,34 @@ async def detect_assistant_role(context: ConversationContext) -> ConversationRol
return ConversationRole.COORDINATOR


async def get_shared_conversation_id(context: ConversationContext) -> Optional[str]:
"""
Get the shared conversation ID for a coordinator conversation.

This utility function retrieves the share ID and finds the associated
shareable template conversation ID from the knowledge package.

Args:
context: The conversation context (should be a coordinator conversation)

Returns:
The shared conversation ID if found, None otherwise
"""
try:
share_id = await ConversationKnowledgePackageManager.get_associated_share_id(context)
if not share_id:
return None

knowledge_package = ShareStorage.read_share(share_id)
if not knowledge_package or not knowledge_package.shared_conversation_id:
return None

return knowledge_package.shared_conversation_id
except Exception as e:
logger.error(f"Error getting shared conversation ID: {e}")
return None


async def log_transfer_action(
context: ConversationContext,
entry_type: LogEntryType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ async def update_conversation(self, metadata: dict[str, Any]) -> workbench_model
http_response.raise_for_status()
return workbench_model.Conversation.model_validate(http_response.json())

async def update_conversation_title(self, title: str) -> workbench_model.Conversation:
update_data = workbench_model.UpdateConversation(title=title)
http_response = await self._client.patch(
f"/conversations/{self._conversation_id}",
json=update_data.model_dump(mode="json", exclude_unset=True, exclude_defaults=True),
headers=self._headers,
)
http_response.raise_for_status()
return workbench_model.Conversation.model_validate(http_response.json())

async def get_participant_me(self) -> workbench_model.ConversationParticipant:
http_response = await self._client.get(
f"/conversations/{self._conversation_id}/participants/me", headers=self._headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ async def get_conversation(self) -> workbench_model.Conversation:
async def update_conversation(self, metadata: dict[str, Any]) -> workbench_model.Conversation:
return await self._conversation_client.update_conversation(metadata)

async def update_conversation_title(self, title: str) -> workbench_model.Conversation:
return await self._conversation_client.update_conversation_title(title)

async def get_participants(self, include_inactive=False) -> workbench_model.ConversationParticipantList:
return await self._conversation_client.get_participants(include_inactive=include_inactive)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ async def _forward_event(
file,
)

case workbench_model.ConversationEventType.conversation_updated:
# Conversation metadata updates (title, metadata, etc.)
await self.assistant_app.events.conversation._on_updated_handlers(
True, # event_originated_externally (always True for workbench updates)
conversation_context,
)

@translate_assistant_errors
async def get_conversation_state_descriptions(
self, assistant_id: str, conversation_id: str
Expand Down