Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions ai/integrations/bedrock/src/unitycatalog/ai/bedrock/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def get_stream(self) -> Iterator[str]:
if chunk:
yield chunk

# TODO Move to bedrock utils.
def extract_tool_calls(response: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Extracts tool calls from Bedrock response."""
tool_calls = []
Expand All @@ -73,20 +74,23 @@ def extract_tool_calls(response: Dict[str, Any]) -> List[Dict[str, Any]]:
},
'invocation_id': control_data['invocationId']
})
#TODO: - Invocation ID is common for all functions
return tool_calls

def execute_tool_calls(tool_calls, client):
def execute_tool_calls(tool_calls, client, catalog_name, schema_name):
results = []
for tool_call in tool_calls:
try:
full_function_name = tool_call.get("function_name")
print(f"Attempting to execute function: {full_function_name} with parameters: {tool_call.get('parameters')}")

# Derive the three level namespace from toolkit session
function_name = full_function_name.split('__')[1]
full_function_name_override = f'{catalog_name}.{schema_name}.{function_name}'
# Attempt to retrieve function info explicitly and log it
full_function_name_override = 'AICatalog.AISchema.location_weather_in_c'
function_info = client.get_function(full_function_name_override)
print(f"Retrieved function info Override: {function_info}")

# We are checking the existence of the function and assuming it's safe to call with parameters given by LLM
result = client.execute_function(
full_function_name_override,
tool_call['parameters']
Expand All @@ -109,7 +113,7 @@ def generate_tool_call_session_state(tool_result: Dict[str, Any],
action_group, function = tool_call['function_name'].split('__')
return {
'invocationId': tool_result['invocation_id'],
'returnControlInvocationResults': [{
'returnControlInvocationResults': [{ #TODO Need to iterate all tools
'functionResult': {
'actionGroup': action_group,
'function': function,
Expand All @@ -126,18 +130,21 @@ def generate_tool_call_session_state(tool_result: Dict[str, Any],
class BedrockSession:
"""Manages a session with AWS Bedrock agent runtime."""

def __init__(self, agent_id: str, agent_alias_id: str):
def __init__(self, agent_id: str, agent_alias_id: str, catalog_name: str, schema_name: str):
"""Initialize a Bedrock session."""
self.agent_id = agent_id
self.agent_alias_id = agent_alias_id
self.client = boto3.client('bedrock-agent-runtime')
self.catalog_name = catalog_name
self.schema_name = schema_name

def invoke_agent(
self,
input_text: str,
enable_trace: bool = None,
session_id: str = None,
session_state: dict = None,
streaming_configurations: dict = None,
uc_client: Optional[UnitycatalogFunctionClient] = None
) -> BedrockToolResponse:
"""Invoke the Bedrock agent with the given input text."""
Expand All @@ -153,6 +160,8 @@ def invoke_agent(
params['sessionId'] = session_id
if session_state is not None:
params['sessionState'] = session_state
if streaming_configurations is not None:
params['streamingConfigurations'] = streaming_configurations

response = self.client.invoke_agent(**params)
tool_calls = extract_tool_calls(response)
Expand All @@ -162,11 +171,11 @@ def invoke_agent(
print(f"Response from invoke agent: {response}") #Debugging
print(f"Tool Call Results: {tool_calls}") #Debugging

tool_results = execute_tool_calls(tool_calls, uc_client)
tool_results = execute_tool_calls(tool_calls, uc_client, self.catalog_name, self.schema_name)
print(f"ToolResults: {tool_results}") #Debugging
if tool_results:
session_state = generate_tool_call_session_state(
tool_results[0], tool_calls[0])
tool_results[0], tool_calls[0]) # TODO - Need to pass list of tool results
print(f"SessionState: {session_state}") #Debugging

if 'returnControlInvocationResults' in session_state:
Expand Down Expand Up @@ -195,7 +204,12 @@ def invoke_agent(
return self.invoke_agent(input_text="",
session_id=session_id,
enable_trace=enable_trace,
session_state=session_state)
session_state=session_state,
streaming_configurations={
'applyGuardrailInterval': 123, # TODO: Test variations
'streamFinalResponse': True
}
)

return BedrockToolResponse(raw_response=response, tool_calls=tool_calls)

Expand Down
89 changes: 89 additions & 0 deletions ai/integrations/bedrock/tests/ResponseText.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
Response from invoke agent:
{'ResponseMetadata': {
'RequestId': '80be282b-8f01-4a97-9a5c-7efb94f2e2a1',
'HTTPStatusCode': 200,
'HTTPHeaders': {
'date': 'Tue, 04 Feb 2025 21:20:29 GMT',
'content-type': 'application/vnd.amazon.eventstream',
'transfer-encoding': 'chunked',
'connection': 'keep-alive',
'x-amzn-requestid': '80be282b-8f01-4a97-9a5c-7efb94f2e2a1',
'x-amz-bedrock-agent-session-id': 'b9666c3a-e33c-11ef-9098-4e3e215d14d1',
'x-amzn-bedrock-agent-content-type': 'application/json'
},
'RetryAttempts': 0
},
'contentType': 'application/json',
'sessionId': 'b9666c3a-e33c-11ef-9098-4e3e215d14d1',
'completion': <botocore.eventstream.EventStream object at 0x115cd5cd0>
}
Tool Call Results: [
{
'function_name': 'tbd-bda-action-group-name__location_weather_in_c',
'parameters': {'fetch_date': '2024-11-19', 'location_id': '1234'},
'invocation_id': 'c752ccf6-45c6-4ae2-a27a-5dd61378896e-uc-result'
}
]
Attempting to execute function: tbd-bda-action-group-name__location_weather_in_c with parameters: {'fetch_date': '2024-11-19', 'location_id': '1234'}
Retrieved function info Override:
name='location_weather_in_c'
catalog_name='AICatalog'
schema_name='AISchema'
input_params=FunctionParameterInfos(parameters=[
FunctionParameterInfo(name='location_id', type_text='STRING',
type_json='{
"name": "location_id",
"type": "string",
"nullable": false,
"metadata": {
"comment": "The name to be included in the greeting message."
}
}',
type_name=<ColumnTypeName.STRING: 'STRING'>,
type_precision=None,
type_scale=None,
type_interval_type=None,
position=0,
parameter_mode=None,
parameter_type=None,
parameter_default=None,
comment='The name to be included in the greeting message.'
),
FunctionParameterInfo(name='fetch_date', type_text='STRING', type_json='{"name": "fetch_date", "type": "string", "nullable": false, "metadata": {"comment": "The date with the location"}}', type_name=<ColumnTypeName.STRING: 'STRING'>, type_precision=None, type_scale=None, type_interval_type=None, position=1, parameter_mode=None, parameter_type=None, parameter_default=None, comment='The date with the location')]) data_type=<ColumnTypeName.STRING: 'STRING'> full_data_type='STRING' return_params=None routine_body='EXTERNAL' routine_definition='try:\n # Fetch from Databricks SQL Warehouse based UC function execution \n return "23"\nexcept Exception as e:\n raise Exception(f"Error occurred: {e}")' routine_dependencies=None parameter_style='S' is_deterministic=True sql_data_access='NO_SQL' is_null_call=False security_type='DEFINER' specific_name='location_weather_in_c' comment='Test function for AWS Bedrock integration.' properties=None full_name='AICatalog.AISchema.location_weather_in_c' owner=None created_at=1738703541746 created_by=None updated_at=1738703541746 updated_by=None function_id='5937d4a7-1d10-4ae2-87ff-1b3be92df25d' external_language='PYTHON'
ToolResults: [
{
'invocation_id': 'c752ccf6-45c6-4ae2-a27a-5dd61378896e-uc-result',
'result': '23'
}
]
SessionState: {
'invocationId': 'c752ccf6-45c6-4ae2-a27a-5dd61378896e-uc-result',
'returnControlInvocationResults': [
{
'functionResult': {
'actionGroup': 'tbd-bda-action-group-name',
'function': 'location_weather_in_c',
'confirmationState': 'CONFIRM',
'responseBody': {
'TEXT': {'body': '23'}
}
}
}
]
}
Results: [
{
'functionResult': {
'actionGroup': 'tbd-bda-action-group-name',
'function': 'location_weather_in_c',
'confirmationState': 'CONFIRM',
'responseBody': {
'TEXT': {
'body': '23'
}
}
}
}
]
response_body_obj: {'TEXT': {'body': '23'}}
result value: 23
Loading