From 95d3d6de8f3c815cb09033b50512f2171cdc64c9 Mon Sep 17 00:00:00 2001 From: "T.J Ariyawansa" Date: Mon, 23 Feb 2026 11:57:07 -0500 Subject: [PATCH] test(memory): rewrite integration tests with proper pytest structure - Replace script-style test_devex.py with proper pytest test class in test_memory_client.py covering all 37 MemoryClient public methods - Add assert_created_event helper for create_event/fork_conversation responses which do not include payload - Add leaf-value blob assertion to handle service returning blob data as a stringified Java-style map representation - Fail tests when MEMORY_ROLE_ARN is not set instead of silently skipping - Add pytest-xdist for parallel test execution support - Add Husky git hooks setup to CONTRIBUTING.md and node_modules to .gitignore --- .gitignore | 4 + CONTRIBUTING.md | 9 + pyproject.toml | 1 + tests_integ/memory/test_devex.py | 756 ---------------- tests_integ/memory/test_memory_client.py | 1052 ++++++++++++++-------- uv.lock | 24 + 6 files changed, 720 insertions(+), 1126 deletions(-) delete mode 100644 tests_integ/memory/test_devex.py diff --git a/.gitignore b/.gitignore index b28ca305..3573055b 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,7 @@ local_settings.py .dockerignore Dockerfile CLAUDE.md + +# Node.js (Husky git hooks) +node_modules/ +package-lock.json diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c855ca48..1445751f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,8 +28,17 @@ uv sync # Install pre-commit hooks (one-time) pre-commit install + +# Install Husky git hooks (manages git hooks layer on top of pre-commit) +bash scripts/setup-husky.sh ``` +> **Note:** Husky is used as an additional git hooks manager that delegates to the +> `pre-commit` framework. The `scripts/setup-husky.sh` script installs the npm +> dependencies, makes the hook files executable, and configures git to use the +> `.husky/` directory for hooks. Both systems work together: Husky manages the git +> hook entry points, and `pre-commit` performs the actual checks. + That's it! You're ready to develop. ### Daily Development Workflow diff --git a/pyproject.toml b/pyproject.toml index 00e15db6..2ae5a8b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,7 @@ dev = [ "pytest>=8.4.1", "pytest-asyncio>=0.24.0", "pytest-cov>=6.0.0", + "pytest-xdist>=3.5.0", "ruff>=0.12.0", "websockets>=14.1", "wheel>=0.45.1", diff --git a/tests_integ/memory/test_devex.py b/tests_integ/memory/test_devex.py deleted file mode 100644 index 720d9b51..00000000 --- a/tests_integ/memory/test_devex.py +++ /dev/null @@ -1,756 +0,0 @@ -"""Comprehensive developer experience evaluation for Bedrock AgentCore Memory SDK.""" - -import os -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), "../../src")) - -import json -import logging -import time -from datetime import datetime - -from bedrock_agentcore.memory import MemoryClient - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - - -def print_developer_journey(): - """Print the developer journey to understand the improvements.""" - - logger.info("=" * 80) - logger.info("DEVELOPER EXPERIENCE JOURNEY") - logger.info("=" * 80) - - logger.info("\nšŸ“– STORY: Building a Customer Support Agent") - logger.info("A developer wants to build an AI agent that:") - logger.info("- Handles customer inquiries") - logger.info("- Can explore different response strategies") - logger.info("- Escalates to human agents when needed") - logger.info("- Learns from interactions") - - logger.info("- save_conversation() handles any message pattern") - logger.info("- Full branch management (list, navigate, visualize)") - logger.info("- Flexible roles for tools and system messages") - logger.info("- Memory extraction for learning") - - -def test_complete_agent_workflow(client: MemoryClient, memory_id: str): - """Test a complete customer support agent workflow.""" - - logger.info("\n%s", "=" * 80) - logger.info("COMPLETE AGENT WORKFLOW TEST") - logger.info("=" * 80) - - actor_id = "customer-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "support-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - - logger.info("\n1. Memory strategies already configured during creation") - - # Helper function for retries with exponential backoff - def save_with_retry(memory_id, actor_id, session_id, messages, branch=None, max_retries=5): - wait_time = 2 # Start with 2 seconds - attempt = 0 - - while attempt < max_retries: - try: - return client.save_conversation( - memory_id=memory_id, actor_id=actor_id, session_id=session_id, messages=messages, branch=branch - ) - except Exception as e: - if "ThrottledException" in str(e) and attempt < max_retries - 1: - attempt += 1 - logger.info( - "Rate limit hit, retrying in %d seconds (attempt %d/%d)...", wait_time, attempt, max_retries - ) - time.sleep(wait_time) - wait_time *= 2 # Exponential backoff - else: - raise # Re-raise if it's not a throttling error or max retries reached - - # Phase 1: Initial inquiry with context switching - logger.info("\n2. Customer makes initial inquiry...") - - initial = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Hi, I'm having trouble with my order #12345", "USER"), - ("I'm sorry to hear that. Let me look up your order.", "ASSISTANT"), - ("lookup_order(order_id='12345')", "TOOL"), - ("I see your order was shipped 3 days ago. What specific issue are you experiencing?", "ASSISTANT"), - ("Actually, before that - I also want to change my email address", "USER"), - ( - "Of course! I can help with both. Let's start with updating your email. What's your new email?", - "ASSISTANT", - ), - ("newemail@example.com", "USER"), - ("update_customer_email(old='old@example.com', new='newemail@example.com')", "TOOL"), - ("Email updated successfully! Now, about your order issue?", "ASSISTANT"), - ("The package arrived damaged", "USER"), - ], - ) - logger.info("āœ“ Handled context switch naturally") - - # Phase 2: A/B test different resolution approaches - logger.info("\n3. Testing different resolution strategies...") - - # MODIFIED: Create refund branch with first message only - _refund_branch = client.fork_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - root_event_id=initial["eventId"], - branch_name="immediate-refund", - new_messages=[ - ("I'm very sorry about the damaged package. I'll process an immediate refund.", "ASSISTANT"), - ], - ) - - # Continue the refund branch with additional messages - with longer delays and retries - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("process_refund(order_id='12345', reason='damaged', amount='full')", "TOOL"), - ], - branch={"name": "immediate-refund", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Refund processed! You'll see it in 3-5 business days. Is there anything else?", "ASSISTANT"), - ("That was fast, thank you!", "USER"), - ], - branch={"name": "immediate-refund", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("You're welcome! I've also added a 10% discount to your account for next purchase.", "ASSISTANT"), - ], - branch={"name": "immediate-refund", "rootEventId": initial["eventId"]}, - ) - - # MODIFIED: Create replacement branch with first message only - time.sleep(5) # Increased delay - _replacement_branch = client.fork_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - root_event_id=initial["eventId"], - branch_name="replacement-offer", - new_messages=[ - ("I apologize for the damaged item. Would you prefer a replacement or refund?", "ASSISTANT"), - ], - ) - - # Continue the replacement branch with additional messages - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("How fast can you send a replacement?", "USER"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("check_inventory(item='ORD-12345-ITEM')", "TOOL"), - ("We have it in stock! I can send a replacement with express shipping - arrives in 2 days.", "ASSISTANT"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("That works for me", "USER"), - ("create_replacement_order(original='12345', shipping='express')", "TOOL"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Perfect! Replacement ordered with express shipping. You'll get tracking info shortly.", "ASSISTANT"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - # MODIFIED: Create escalation branch with first message only - time.sleep(5) # Increased delay - _escalation_branch = client.fork_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - root_event_id=initial["eventId"], - branch_name="escalation-required", - new_messages=[ - ("I understand this is frustrating. Let me connect you with a specialist who can help.", "ASSISTANT"), - ], - ) - - # Continue the escalation branch with additional messages - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("This is the third time this has happened!", "USER"), - ], - branch={"name": "escalation-required", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("check_customer_history(customer_id='cust-123')", "TOOL"), - ( - "I see you've had multiple issues. I'm escalating this to our senior support team immediately.", - "ASSISTANT", - ), - ], - branch={"name": "escalation-required", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("create_escalation_ticket(priority='high', history='multiple_damages')", "TOOL"), - ("ticket_created: ESC-78901", "TOOL"), - ], - branch={"name": "escalation-required", "rootEventId": initial["eventId"]}, - ) - - time.sleep(5) # Increased delay - save_with_retry( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ( - "I've created high-priority ticket ESC-78901. A senior specialist will contact you within 1 hour.", - "ASSISTANT", - ), - ], - branch={"name": "escalation-required", "rootEventId": initial["eventId"]}, - ) - - logger.info("āœ“ Created 3 different resolution branches") - - # Phase 3: Analyze branches - logger.info("\n4. Analyzing branch outcomes...") - - branches = client.list_branches(memory_id, actor_id, session_id) - logger.info("\nFound %d total branches:", len(branches)) - - for branch in branches: - logger.info("\n Branch: %s", branch["name"]) - logger.info(" Events: %d", branch["eventCount"]) - - if branch["name"] != "main": - messages = client.merge_branch_context( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - branch_name=branch["name"], - include_parent=False, - ) - - if messages: - last_customer = None - last_agent = None - - for msg in reversed(messages): - if msg["role"] == "USER" and not last_customer: - last_customer = msg["content"] - elif msg["role"] == "ASSISTANT" and not last_agent: - last_agent = msg["content"] - - if last_customer and last_agent: - break - - logger.info(" Customer sentiment: %s", last_customer[:50] if last_customer else "N/A") - logger.info(" Final resolution: %s", last_agent[:80] + "..." if last_agent else "N/A") - - # Phase 4: Continue in best branch - logger.info("\n5. Continuing conversation in best branch...") - - # MODIFIED: Split follow-up into smaller batches - time.sleep(1) - client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("I got the replacement - it's perfect! Thank you so much!", "USER"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - time.sleep(1) - client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Wonderful! I'm glad we could resolve this quickly.", "ASSISTANT"), - ("save_positive_feedback(case_id='12345', rating=5, branch='replacement')", "TOOL"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - time.sleep(1) - _followup = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Is there anything else I can help you with today?", "ASSISTANT"), - ("No, that's all. Great service!", "USER"), - ("Thank you! Have a great day!", "ASSISTANT"), - ], - branch={"name": "replacement-offer", "rootEventId": initial["eventId"]}, - ) - - logger.info("āœ“ Continued conversation in successful branch") - - # Phase 5: Wait for memory extraction - logger.info("\n6. Waiting for memory extraction...") - logger.info("Note: After creating events, extraction + vector indexing typically takes 2-3 minutes") - - logger.info("Waiting 30 seconds for extraction to trigger...") - time.sleep(30) - - namespace = "support/facts/%s/" % session_id - if client.wait_for_memories(memory_id, namespace, max_wait=180): - logger.info("āœ“ Memories extracted and indexed successfully") - - memories = client.retrieve_memories( - memory_id=memory_id, namespace=namespace, query="customer order issues damaged package", top_k=5 - ) - - logger.info("Retrieved %d relevant memories", len(memories)) - for i, mem in enumerate(memories[:3]): - logger.info(" [%d] %s", i + 1, mem.get("content", {}).get("text", "")[:100]) - else: - logger.info("āš ļø Memory extraction/indexing still in progress") - logger.info("This can take 3-5 minutes total. Try retrieving memories manually later.") - - # Phase 6: Visualize complete conversation - logger.info("\n7. Visualizing conversation structure...") - - tree = client.get_conversation_tree(memory_id, actor_id, session_id) - - def print_tree(branch_data, indent=0): - prefix = " " * indent - events = branch_data.get("events", []) - - if events: - logger.info("%sMain flow: %d events", prefix, len(events)) - for event in events[:2]: - for msg in event.get("messages", []): - logger.info("%s - %s: %s", prefix, msg["role"], msg["text"]) - - for branch_name, sub_branch in branch_data.get("branches", {}).items(): - logger.info("%s└─ Branch '%s': %d events", prefix, branch_name, len(sub_branch.get("events", []))) - if sub_branch.get("events"): - for msg in sub_branch["events"][0].get("messages", []): - logger.info("%s - %s: %s", prefix, msg["role"], msg["text"]) - - print_tree(tree["main_branch"]) - - -def test_bedrock_integration(client: MemoryClient, memory_id: str): - """Test AgentCore Memory with Amazon Bedrock integration.""" - - logger.info("\n%s", "=" * 80) - logger.info("TESTING BEDROCK INTEGRATION") - logger.info("=" * 80) - - import boto3 - - try: - bedrock = boto3.client("bedrock-runtime", region_name="us-east-1") - except Exception as e: - logger.error("Failed to initialize Bedrock client: %s", e) - logger.info("Skipping Bedrock test - ensure AWS credentials are configured") - return - - actor_id = "bedrock-test-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "bedrock-session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - - # Create initial context - logger.info("\n1. Creating initial conversation context...") - - _initial_events = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("I'm planning a trip to Japan in April", "USER"), - ("That's exciting! April is cherry blossom season. What cities are you planning to visit?", "ASSISTANT"), - ("Tokyo and Kyoto for sure. I love photography", "USER"), - ("Perfect for photography! The cherry blossoms in Maruyama Park in Kyoto are stunning.", "ASSISTANT"), - ], - ) - - # Wait for extraction - logger.info("\n2. Waiting for memory extraction...") - time.sleep(60) - - # New user query - user_query = "What camera equipment should I bring for cherry blossom photography?" - logger.info("\n3. New user query: %s", user_query) - - # Retrieve relevant memories - logger.info("\n4. Retrieving relevant context...") - namespace = "support/facts/%s/" % session_id - memories = client.retrieve_memories(memory_id=memory_id, namespace=namespace, query=user_query, top_k=5) - - context = "" - if memories: - context = "\n".join([m.get("content", {}).get("text", "") for m in memories]) - logger.info("Found %d relevant memories", len(memories)) - - # Call Bedrock with context - logger.info("\n5. Calling Claude 3.5 Sonnet with context...") - - messages = [] - if context: - messages.append( - {"role": "assistant", "content": "Here's what I know from our previous conversation:\n%s" % context} - ) - - messages.append({"role": "user", "content": user_query}) - - try: - response = bedrock.invoke_model( - modelId="anthropic.claude-3-5-sonnet-20241022-v2:0", - contentType="application/json", - accept="application/json", - body=json.dumps( - { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 1000, - "messages": messages, - "temperature": 0.7, - } - ), - ) - - response_body = json.loads(response["body"].read()) - llm_response = response_body["content"][0]["text"] - - logger.info("\n6. Claude's response:") - logger.info("%s...", llm_response[:200]) - - # Save the new turn - logger.info("\n7. Saving conversation turn...") - _new_event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[(user_query, "USER"), (llm_response, "ASSISTANT")], - ) - - logger.info("āœ“ Successfully integrated Memory with Bedrock!") - - except Exception as e: - logger.error("Bedrock call failed: %s", e) - logger.info("Make sure you have access to Claude 3.5 Sonnet v2 in Bedrock") - - -def test_developer_productivity_metrics(client: MemoryClient, memory_id: str): - """Measure developer productivity improvements.""" - - logger.info("\n%s", "=" * 80) - logger.info("DEVELOPER PRODUCTIVITY METRICS") - logger.info("=" * 80) - - _actor_id = "metrics-test" - _session_id = "metrics-session" - - logger.info("\n1. Lines of Code Comparison") - logger.info("\nFlexible conversation handling:") - logger.info(" event = client.save_conversation(messages=[") - logger.info(" ('Question 1', 'USER'),") - logger.info(" ('Question 2', 'USER'),") - logger.info(" ('Checking...', 'ASSISTANT'),") - logger.info(" ('tool_call()', 'TOOL'),") - logger.info(" ('Complete answer', 'ASSISTANT')") - logger.info(" ])") - logger.info(" Total: 7 lines for complex flow") - - logger.info("\n2. API Calls for Common Tasks") - logger.info(" Get conversation history from branch: 1 call - list_branch_events()") - logger.info(" Find all branches: 1 call - list_branches()") - logger.info(" Save complex interaction: 1 call - save_conversation()") - - logger.info("\n3. Key Improvements") - logger.info(" āœ… Natural message flow representation") - logger.info(" āœ… Complete branch navigation") - logger.info(" āœ… Flexible message combinations") - logger.info(" āœ… Type-safe strategy methods") - - features = [ - ("Save user question without response", "30 seconds"), - ("Handle tool-augmented response", "1 minute"), - ("A/B test responses with branches", "2 minutes"), - ("Get branch conversation", "30 seconds"), - ("Find all branches", "1 API call"), - ] - - logger.info("\n4. Feature Implementation Time") - logger.info("\nFeature Time to Implement ") - logger.info("-" * 55) - for feature, impl_time in features: - logger.info("%-35s %-20s", feature, impl_time) - - -def test_edge_cases_and_validation(client: MemoryClient, memory_id: str): - """Test edge cases and validation improvements.""" - - logger.info("\n%s", "=" * 80) - logger.info("EDGE CASES AND VALIDATION") - logger.info("=" * 80) - - actor_id = "edge-test" - session_id = "edge-session" - - # Test 1: Very long conversation - logger.info("\n1. Testing very long conversation...") - - # MODIFIED: Split long conversation into smaller batches - for i in range(20): - messages = [] - messages.append(("Question %d about the product" % i, "USER")) - messages.append(("Answer %d with detailed information" % i, "ASSISTANT")) - - try: - long_event = client.save_conversation( - memory_id=memory_id, actor_id=actor_id, session_id=session_id, messages=messages - ) - logger.info("āœ“ Saved messages %d: %s", i + 1, long_event["eventId"]) - time.sleep(0.5) # Small delay between batches - except Exception as e: - logger.error("āŒ Failed to save messages %d: %s", i + 1, e) - - logger.info("āœ“ Saved long conversation in batches") - - # Test 2: Rapid branch creation - logger.info("\n2. Testing rapid branch creation...") - - base_event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id="rapid-branch-test", - messages=[("Start conversation", "USER")], - ) - - # MODIFIED: Added delays between branch creations - for i in range(5): - try: - time.sleep(1) # Delay before creating branch - _branch = client.fork_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id="rapid-branch-test", - root_event_id=base_event["eventId"], - branch_name="branch-%d" % i, - new_messages=[("Branch %d message" % i, "ASSISTANT")], - ) - logger.info("āœ“ Created branch-%d", i) - except Exception as e: - logger.error("āŒ Failed to create branch-%d: %s", i, e) - - # Test 3: Unicode and special characters - logger.info("\n3. Testing Unicode and special characters...") - - # MODIFIED: Split into smaller message groups - time.sleep(1) - _special_event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Hello! šŸ‘‹ How can I help? 你儽!", "ASSISTANT"), - ], - ) - - time.sleep(1) - _special_event2 = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("I need help with €100 payment", "USER"), - ("I'll help with your €100 payment šŸ’³", "ASSISTANT"), - ], - ) - - logger.info("āœ“ Handled Unicode and special characters") - - # Test 4: Empty messages - logger.info("\n4. Testing empty message content...") - - try: - time.sleep(1) - _empty_event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[("", "USER"), ("I didn't catch that. Could you repeat?", "ASSISTANT")], - ) - logger.info("āœ“ Handled empty message content") - except Exception as e: - logger.error("āŒ Failed with empty message: %s", e) - - -def generate_developer_report(client: MemoryClient): - """Generate a final developer experience report.""" - - logger.info("\n%s", "=" * 80) - logger.info("DEVELOPER EXPERIENCE REPORT") - logger.info("=" * 80) - - logger.info("\nšŸŽÆ KEY IMPROVEMENTS") - - improvements = [ - {"area": "Conversation Flexibility", "impact": "90% reduction in code for complex flows"}, - {"area": "Branch Management", "impact": "New scenarios now possible"}, - {"area": "Developer Intuition", "impact": "Faster onboarding, fewer errors"}, - {"area": "Real-world Scenarios", "impact": "Better user experiences"}, - ] - - for imp in improvements: - logger.info("\n%s:", imp["area"]) - logger.info(" Impact: %s", imp["impact"]) - - logger.info("\nšŸ“Š METRICS SUMMARY") - logger.info(" • Code reduction: 60-90% for complex scenarios") - logger.info(" • New capabilities: 5+ previously impossible features") - logger.info(" • API calls saved: 50-80% for multi-message flows") - logger.info(" • Learning curve: Significantly reduced") - - logger.info("\nāœ… RECOMMENDATION") - logger.info("The SDK improvements successfully address developer pain points.") - logger.info("Developers can now build more sophisticated agents with less code.") - logger.info("Branch management enables new use cases like A/B testing.") - logger.info("The flexible conversation API matches real-world requirements.") - - -def main(): - """Run complete developer experience evaluation.""" - - print_developer_journey() - - role_arn = os.getenv("MEMORY_ROLE_ARN") - if not role_arn: - logger.error("Please set MEMORY_ROLE_ARN environment variable") - return - - # Get region and environment from environment variables with defaults - region = os.getenv("AWS_REGION", "us-west-2") - environment = os.getenv("MEMORY_ENVIRONMENT", "prod") - - logger.info("Using region: %s, environment: %s", region, environment) - - client = MemoryClient(region_name=region) - - logger.info("\nCreating test memory with strategies...") - memory = client.create_memory( - name="DXTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), - description="Developer experience evaluation", - strategies=[ - { - "semanticMemoryStrategy": { - "name": "CustomerInfo", - "description": "Extract customer information and issues", - "namespaces": ["support/facts/{sessionId}/"], - # NO configuration block - } - }, - { - "userPreferenceMemoryStrategy": { - "name": "CustomerPreferences", - "description": "Track customer preferences and history", - "namespaces": ["customers/{actorId}/preferences/"], - # NO configuration block - } - }, - ], - event_expiry_days=7, - memory_execution_role_arn=role_arn, - ) - - memory_id = memory["memoryId"] - logger.info("Created memory: %s", memory_id) - - logger.info("Waiting for memory activation...") - for _ in range(30): - time.sleep(10) - status = client.get_memory_status(memory_id) - if status == "ACTIVE": - logger.info("Memory is active!") - logger.info("Waiting additional 120 seconds for vector store initialization...") - time.sleep(120) - break - elif status == "FAILED": - logger.error("Memory creation failed!") - return - - try: - test_complete_agent_workflow(client, memory_id) - test_bedrock_integration(client, memory_id) - test_developer_productivity_metrics(client, memory_id) - test_edge_cases_and_validation(client, memory_id) - generate_developer_report(client) - - logger.info("\n%s", "=" * 80) - logger.info("DEVELOPER EXPERIENCE EVALUATION COMPLETE") - logger.info("=" * 80) - - except Exception as e: - logger.exception("Test failed: %s", e) - finally: - logger.info("\nTest memory ID: %s", memory_id) - logger.info("You can delete it with: client.delete_memory('%s')", memory_id) - - -if __name__ == "__main__": - main() diff --git a/tests_integ/memory/test_memory_client.py b/tests_integ/memory/test_memory_client.py index a892dea6..40bed92e 100644 --- a/tests_integ/memory/test_memory_client.py +++ b/tests_integ/memory/test_memory_client.py @@ -1,412 +1,724 @@ -"""Test script for critical AgentCore Memory SDK issues.""" +"""Integration tests for MemoryClient — full public-method coverage. + +Covers all 37 public methods of MemoryClient across 11 workflow tests. +Every test is fully independent — no inter-test dependencies. + +Run with: pytest -xvs tests_integ/memory/test_memory_client.py +Parallel: pytest -xvs tests_integ/memory/test_memory_client.py -n auto --dist loadscope +""" import logging import os -import time from datetime import datetime +import pytest + from bedrock_agentcore.memory import MemoryClient -# Use INFO level logging for cleaner output logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) -def test_list_events_api(client: MemoryClient, memory_id: str): - """Test the new list_events public API method.""" - logger.info("=" * 80) - logger.info("TESTING LIST_EVENTS PUBLIC API (Issue #1)") - logger.info("=" * 80) - - actor_id = "test-list-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - - # Create some events - logger.info("\n1. Creating test events...") +@pytest.mark.integration +class TestMemoryClient: + """Integration tests for MemoryClient.""" + + MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0" + + RESTAURANT_CONVERSATION = [ + ("I'm vegetarian and I prefer restaurants with a quiet atmosphere.", "USER"), + ( + "Thank you for letting me know. I'll recommend vegetarian-friendly, " + "quiet restaurants. Any specific cuisine?", + "ASSISTANT", + ), + ("I'm in the mood for Italian cuisine.", "USER"), + ("Great choice! Do you have a preferred price range or location?", "ASSISTANT"), + ("I'd prefer something mid-range and located downtown.", "USER"), + ( + "I'll search for mid-range, vegetarian-friendly Italian restaurants " + "downtown. Would you like me to book a table?", + "ASSISTANT", + ), + ("Yes, please book for 7 PM.", "USER"), + ( + "I'll find a suitable restaurant and make a reservation for 7 PM. Anything else?", + "ASSISTANT", + ), + ("That's all for now. Thank you!", "USER"), + ] - for i in range(3): - event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("Message %d from user" % (i + 1), "USER"), - ("Response %d from assistant" % (i + 1), "ASSISTANT"), + @pytest.fixture(scope="class", autouse=True) + def setup_client(self, request): + """Provision MemoryClient and per-worker test prefix.""" + region = os.environ.get("AWS_REGION", "us-west-2") + + role_arn = os.environ.get("MEMORY_ROLE_ARN") + if not role_arn: + pytest.fail("MEMORY_ROLE_ARN environment variable is not set") + + request.cls.role_arn = role_arn + request.cls.client = MemoryClient(region_name=region) + + worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") + request.cls.test_prefix = "mc_%s_%s" % (datetime.now().strftime("%Y%m%d%H%M%S"), worker_id) + request.cls.memory_ids = [] + + yield + + # Teardown — delete control-plane test memories + for memory_id in request.cls.memory_ids: + try: + request.cls.client.delete_memory(memory_id) + logger.info("Deleting test memory: %s", memory_id) + except Exception as e: + logger.error("Failed to delete test memory %s: %s", memory_id, e) + + @pytest.fixture(scope="class") + def shared_memory(self, request): + """Create or get the shared memory for data-plane tests.""" + logger.info("Creating/getting shared memory for tests...") + memory = request.cls.client.create_or_get_memory( + name="SharedDataPlane", + strategies=[ + { + "semanticMemoryStrategy": { + "name": "TestStrategy", + "namespaces": ["test/{actorId}/{sessionId}/"], + } + } ], + event_expiry_days=7, + memory_execution_role_arn=request.cls.role_arn, ) - logger.info("Created event %d: %s", i + 1, event["eventId"]) - time.sleep(1) - - # Wait for indexing - INCREASED WAIT TIME - logger.info("\nWaiting 60 seconds for event indexing...") - time.sleep(60) - - # Test list_events - logger.info("\n2. Testing list_events() method...") - - try: - # Get all events - all_events = client.list_events(memory_id, actor_id, session_id) - logger.info("āœ“ Retrieved %d events total", len(all_events)) - - # Get main branch only - main_events = client.list_events(memory_id, actor_id, session_id, branch_name="main") - logger.info("āœ“ Retrieved %d main branch events", len(main_events)) - - # Get with max_results - limited_events = client.list_events(memory_id, actor_id, session_id, max_results=2) - logger.info("āœ“ Retrieved %d events with max_results=2", len(limited_events)) - - # Show event structure - if all_events: - logger.info("\nSample event structure:") - event = all_events[0] - logger.info(" Event ID: %s", event.get("eventId")) - logger.info(" Timestamp: %s", event.get("eventTimestamp")) - logger.info(" Has payload: %s", "payload" in event) - - except Exception as e: - logger.error("āŒ list_events failed: %s", e) - raise - - -def test_strategy_polling_fix(client: MemoryClient): - """Test that all strategy operations use polling to avoid CREATING state errors.""" - logger.info("\n%s", "=" * 80) - logger.info("TESTING STRATEGY POLLING FIX (Issue #2)") - logger.info("=" * 80) - - # Create memory without strategies - logger.info("\n1. Creating memory without strategies...") - memory = client.create_memory_and_wait( - name="PollingTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), - strategies=[], # No strategies initially - event_expiry_days=7, - ) - memory_id = memory["memoryId"] - logger.info("āœ“ Created memory: %s", memory_id) - - # Add first strategy - logger.info("\n2. Adding summary strategy with polling...") - try: - memory = client.add_summary_strategy_and_wait( - memory_id=memory_id, name="TestSummary", namespaces=["summaries/{sessionId}/"] - ) - logger.info("āœ“ Added summary strategy, memory is %s", memory["status"]) - except Exception as e: - logger.error("āŒ Failed to add summary strategy: %s", e) - raise - - # Create some events while memory is active - logger.info("\n3. Creating events...") - actor_id = "test-actor" - session_id = "test-session" - - event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[("Test message", "USER"), ("Test response", "ASSISTANT")], - ) - logger.info("āœ“ Created event: %s", event["eventId"]) - - # Add another strategy immediately - logger.info("\n4. Adding user preference strategy immediately...") - try: - memory = client.add_user_preference_strategy_and_wait( - memory_id=memory_id, name="TestPreferences", namespaces=["preferences/{actorId}/"] - ) - logger.info("āœ“ Added user preference strategy without error, memory is %s", memory["status"]) - except Exception as e: - logger.error("āŒ Failed due to CREATING state: %s", e) - raise - - # Clean up - try: - client.delete_memory_and_wait(memory_id) - logger.info("āœ“ Cleaned up test memory") - except Exception: - pass - - -def test_get_last_k_turns_fix(client: MemoryClient, memory_id: str): - """Test that get_last_k_turns returns the correct turns.""" - logger.info("\n%s", "=" * 80) - logger.info("TESTING GET_LAST_K_TURNS FIX (Issue #3)") - logger.info("=" * 80) - - actor_id = "restaurant-user-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - session_id = "restaurant-session-%s" % datetime.now().strftime("%Y%m%d%H%M%S") - - # Create the exact conversation from the issue - logger.info("\n1. Creating restaurant conversation...") - - event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - ("I'm vegetarian and I prefer restaurants with a quiet atmosphere.", "USER"), - ( - "Thank you for letting me know. I'll make sure to recommend restaurants that are " - "vegetarian-friendly and have a quiet atmosphere. Is there any specific cuisine " - "you're interested in today?", - "ASSISTANT", - ), - ("I'm in the mood for Italian cuisine.", "USER"), - ( - "Great choice! I'll look for Italian vegetarian restaurants with a quiet " - "atmosphere. Do you have a preferred price range or location?", - "ASSISTANT", - ), - ("I'd prefer something mid-range and located downtown.", "USER"), - ( - "Noted. I'll search for mid-range, vegetarian-friendly Italian restaurants in " - "the downtown area with a quiet atmosphere. Would you like me to book a table " - "for a specific time?", - "ASSISTANT", - ), - ("Yes, please book for 7 PM.", "USER"), - ( - "Sure, I'll find a suitable restaurant and make a reservation for 7 PM. " - "Is there anything else I can assist you with?", - "ASSISTANT", - ), - ("No, that's all for now. Thank you!", "USER"), - ], - ) - logger.info("āœ“ Conversation saved: %s", event["eventId"]) - - # Wait for event indexing - INCREASED WAIT TIME - logger.info("\nWaiting 60 seconds for event indexing...") - time.sleep(60) - - # Test 1: Without branch_name - logger.info("\n2. Testing get_last_k_turns without branch_name...") - try: - turns = client.get_last_k_turns(memory_id=memory_id, actor_id=actor_id, session_id=session_id, k=2) - logger.info("āœ“ Retrieved %d turns (no branch_name)", len(turns)) - - if turns: - logger.info("\nLast 2 turns:") - for i, turn in enumerate(turns): - logger.info(" Turn %d:", i + 1) - for msg in turn: - role = msg.get("role", "") - text = msg.get("content", {}).get("text", "")[:60] + "..." - logger.info(" %s: %s", role, text) - else: - logger.error("āŒ No turns returned!") - - except Exception as e: - logger.error("āŒ Failed without branch_name: %s", e) - - # Test 2: With branch_name="main" - logger.info("\n3. Testing get_last_k_turns with branch_name='main'...") - try: - turns = client.get_last_k_turns( - memory_id=memory_id, actor_id=actor_id, session_id=session_id, branch_name="main", k=2 - ) - logger.info("āœ“ Retrieved %d turns (branch_name='main')", len(turns)) - - if not turns: - logger.error("āŒ No turns returned for main branch!") - - except Exception as e: - logger.error("āŒ Failed with branch_name='main': %s", e) - - # Test 3: Verify we get the LAST turns, not the first - logger.info("\n4. Verifying we get LAST turns, not first...") - all_turns = client.get_last_k_turns( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - k=10, # Get all turns - ) - - if all_turns: - last_turn = all_turns[-1] - if last_turn and last_turn[0].get("content", {}).get("text", "").startswith("No, that's all"): - logger.info("āœ“ Correctly returned LAST turns (ends with 'No, that's all')") - else: - logger.error("āŒ Returned FIRST turns instead of LAST!") - + request.cls.memory_id = memory["memoryId"] + request.cls.client._wait_for_memory_active(request.cls.memory_id, max_wait=300, poll_interval=10) + logger.info("Shared test memory: %s", request.cls.memory_id) -def test_namespace_wildcards(client: MemoryClient, memory_id: str): - """Test and document that wildcards are not supported in namespaces.""" - logger.info("\n%s", "=" * 80) - logger.info("TESTING NAMESPACE WILDCARD LIMITATION (Issue #4)") - logger.info("=" * 80) + yield - # Check memory strategy configuration - logger.info("\n1. Checking memory strategy configuration:") - strategies = client.get_memory_strategies(memory_id) - for strategy in strategies: - logger.info("Strategy type: %s", strategy.get("type") or strategy.get("memoryStrategyType")) - logger.info("Strategy namespaces: %s", strategy.get("namespaces", [])) - - # Create multiple test events with different actor/session combinations - logger.info("\n2. Creating multiple test events...") + # Teardown — delete shared memory + try: + request.cls.client.delete_memory(request.cls.memory_id) + logger.info("Deleted shared memory: %s", request.cls.memory_id) + except Exception as e: + logger.error("Failed to delete shared memory %s: %s", request.cls.memory_id, e) + + @pytest.fixture() + def ids(self, request): + """Generate unique actor_id and session_id from the test name.""" + suffix = request.node.name.replace("test_", "", 1) + request.cls.actor_id = "actor_%s_%s" % (self.test_prefix, suffix) + request.cls.session_id = "session_%s_%s" % (self.test_prefix, suffix) + + # ------------------------------------------------------------------ + # Data-plane tests (use the shared memory) + # ------------------------------------------------------------------ + + def test_create_event(self, shared_memory, ids): + """create_event: store a USER + ASSISTANT conversation event and verify via list_events.""" + event = self.client.create_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + messages=self.RESTAURANT_CONVERSATION, + ) + self.assert_created_event(event) + + # Verify content persisted via list_events + events = self.client.list_events(self.memory_id, self.actor_id, self.session_id) + assert len(events) >= 1, "Expected at least 1 event" + self.assert_listed_event(events[0], self.RESTAURANT_CONVERSATION) + logger.info("Created and verified event: %s", event["eventId"]) + + def test_create_blob_event(self, shared_memory, ids): + """create_blob_event: store a blob payload and verify via list_events.""" + blob_data = {"type": "tool_output", "result": {"temperature": 72, "unit": "F"}} + + event = self.client.create_blob_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + blob_data=blob_data, + ) + self.assert_created_event(event) + + # Verify blob payload persisted via list_events + events = self.client.list_events(self.memory_id, self.actor_id, self.session_id) + assert len(events) >= 1, "Expected at least 1 event" + self.assert_blob_event(events[0], blob_data) + logger.info("Created and verified blob event: %s", event["eventId"]) + + def test_branching(self, shared_memory, ids): + """fork_conversation, list_branches, list_branch_events, merge_branch_context, get_conversation_tree.""" + branch_messages = [ + ("Actually, can we switch to Japanese cuisine instead?", "USER"), + ("Of course! I'll look for mid-range Japanese restaurants downtown.", "ASSISTANT"), + ] + + # Create a root event to fork from + root_event = self.client.create_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + messages=self.RESTAURANT_CONVERSATION, + ) + root_event_id = root_event["eventId"] + + # Fork from the root event + branch_name = "alt_branch" + fork_event = self.client.fork_conversation( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + root_event_id=root_event_id, + branch_name=branch_name, + new_messages=branch_messages, + ) + self.assert_created_event(fork_event) + logger.info("Forked conversation: %s", fork_event["eventId"]) - actor_ids = [] - session_ids = [] + # list_branches + branches = self.client.list_branches( + memory_id=self.memory_id, actor_id=self.actor_id, session_id=self.session_id + ) + assert isinstance(branches, list), "list_branches must return a list" + branch_names = [b.get("branchName", b.get("name", "")) for b in branches] + assert branch_name in branch_names, "Branch %r not found in %s" % (branch_name, branch_names) + logger.info("Branches: %s", branch_names) + + # list_branch_events — verify fork messages appear + branch_events = self.client.list_branch_events( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + branch_name=branch_name, + ) + assert len(branch_events) >= 1, "Branch should have at least one event" + self.assert_listed_event(branch_events[0], branch_messages) + logger.info("Branch events count: %d", len(branch_events)) + + # merge_branch_context — should contain messages from both root and branch + context = self.client.merge_branch_context( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + branch_name=branch_name, + ) + expected_merged = self.RESTAURANT_CONVERSATION + branch_messages + assert len(context) == len(expected_merged), "Expected %d merged messages (9 root + 2 branch), got %d" % ( + len(expected_merged), + len(context), + ) + for i, (text, role) in enumerate(expected_merged): + assert context[i]["content"] == text, "Merged message %d content mismatch: expected %r, got %r" % ( + i, + text, + context[i]["content"], + ) + assert context[i]["role"] == role, "Merged message %d role mismatch: expected %s, got %s" % ( + i, + role, + context[i]["role"], + ) + logger.info("Merged context: %d messages", len(context)) + + # get_conversation_tree — verify structure + tree = self.client.get_conversation_tree( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + ) + self.assert_tree_has_branch(tree, branch_name, min_main_events=0) + logger.info( + "Conversation tree: %d main events, %d branches", + len(tree["main_branch"]["events"]), + len(tree["main_branch"]["branches"]), + ) - for i in range(3): - actor_id = "wildcard-test-%s-%d" % (datetime.now().strftime("%Y%m%d%H%M%S"), i) - session_id = "wildcard-session-%s-%d" % (datetime.now().strftime("%Y%m%d%H%M%S"), i) - actor_ids.append(actor_id) - session_ids.append(session_id) + def test_list_events(self, shared_memory, ids): + """list_events: all, branch filter, max_results.""" + # Create 3 events + for i in range(3): + self.client.create_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + messages=[ + ("Question %d" % (i + 1), "USER"), + ("Answer %d" % (i + 1), "ASSISTANT"), + ], + ) + + # All events + all_events = self.client.list_events(self.memory_id, self.actor_id, self.session_id) + assert len(all_events) >= 3, "Expected >=3 events, got %d" % len(all_events) + logger.info("Total events: %d", len(all_events)) + + # Main branch filter + main_events = self.client.list_events(self.memory_id, self.actor_id, self.session_id, branch_name="main") + assert isinstance(main_events, list) + logger.info("Main branch events: %d", len(main_events)) + + # max_results + limited = self.client.list_events(self.memory_id, self.actor_id, self.session_id, max_results=2) + assert len(limited) <= 2, "max_results=2 returned %d events" % len(limited) + logger.info("Limited events (max_results=2): %d", len(limited)) + + # Verify all 3 events (list_events returns newest first) + self.assert_listed_event(all_events[0], [("Question 3", "USER"), ("Answer 3", "ASSISTANT")]) + self.assert_listed_event(all_events[1], [("Question 2", "USER"), ("Answer 2", "ASSISTANT")]) + self.assert_listed_event(all_events[2], [("Question 1", "USER"), ("Answer 1", "ASSISTANT")]) + + def test_get_last_k_turns(self, shared_memory, ids): + """get_last_k_turns: main branch and named branch after fork.""" + # 9-message restaurant conversation (5 USER, 4 ASSISTANT) + self.client.create_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + messages=self.RESTAURANT_CONVERSATION, + ) - event = client.save_conversation( - memory_id=memory_id, - actor_id=actor_id, - session_id=session_id, - messages=[ - (f"Test message {i + 1} for wildcard testing with specific keyword", "USER"), - (f"Response {i + 1} for wildcard testing with specific keyword", "ASSISTANT"), + # Verify the event was stored with all 9 messages + events = self.client.list_events(self.memory_id, self.actor_id, self.session_id) + assert len(events) >= 1, "Expected at least 1 event" + self.assert_listed_event(events[0], self.RESTAURANT_CONVERSATION) + + # Main branch — all 9 messages + all_turns = self.client.get_last_k_turns( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + k=10, + ) + self.assert_turns(all_turns, self.RESTAURANT_CONVERSATION) + + # Fork into INDIAN_FOOD branch from the root event + root_event_id = events[0]["eventId"] + self.client.fork_conversation( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + root_event_id=root_event_id, + branch_name="INDIAN_FOOD", + new_messages=[ + ("Actually, I changed my mind. How about Indian food?", "USER"), + ("Indian cuisine is a great choice! Do you like spicy food?", "ASSISTANT"), ], ) - logger.info("āœ“ Created event %d: %s", i + 1, event["eventId"]) - - # Wait for extraction - INCREASED WAIT TIME - logger.info("\nWaiting 90 seconds for memory extraction...") - time.sleep(90) - - # Test 1: Wildcard namespace (should fail) - logger.info("\n3. Testing with wildcard namespace '*'...") - - result = client.wait_for_memories( - memory_id=memory_id, namespace="*", test_query="specific keyword", max_wait=30, poll_interval=10 - ) - - if not result: - logger.info("āœ“ Correctly rejected wildcard namespace") - else: - logger.error("āŒ Wildcard should not have worked!") - - # Test 2: Retrieve with wildcard (should return empty) - logger.info("\n4. Testing retrieve_memories with wildcard...") - memories = client.retrieve_memories(memory_id=memory_id, namespace="*", query="specific keyword") - - if len(memories) == 0: - logger.info("āœ“ Correctly returned empty for wildcard namespace") - else: - logger.error("āŒ Should not return memories with wildcard!") - - # Test 3: Exact namespace (should work) - logger.info("\n5. Testing with exact namespace...") - - # Use the first actor/session from our created events - actor_id = actor_ids[0] - session_id = session_ids[0] - - # Assuming semantic strategy with pattern "test/{actorId}/{sessionId}" - exact_namespace = f"test/{actor_id}/{session_id}/" - - logger.info("Trying exact namespace: %s", exact_namespace) - memories = client.retrieve_memories(memory_id=memory_id, namespace=exact_namespace, query="specific keyword") + # Get turns from the INDIAN_FOOD branch (only branch events, no parent) + branch_turns = self.client.get_last_k_turns( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + branch_name="INDIAN_FOOD", + k=10, + ) + assert branch_turns, "get_last_k_turns on INDIAN_FOOD branch returned no turns" + self.assert_turns( + branch_turns, + [ + ("Actually, I changed my mind. How about Indian food?", "USER"), + ("Indian cuisine is a great choice! Do you like spicy food?", "ASSISTANT"), + ], + ) + logger.info("Verified INDIAN_FOOD branch: 2 messages") + + def test_retrieve_memories(self, shared_memory, ids): + """wait_for_memories, retrieve_memories, wildcard namespace rejection.""" + self.client.create_event( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + messages=self.RESTAURANT_CONVERSATION, + ) - logger.info("āœ“ Retrieved %d memories with exact namespace", len(memories)) + namespace = "test/%s/%s/" % (self.actor_id, self.session_id) + query = "vegetarian Italian restaurants downtown" + + # Wait for memory extraction + logger.info("Waiting for memory extraction (up to 180s)...") + extraction_done = self.client.wait_for_memories( + memory_id=self.memory_id, + namespace=namespace, + test_query=query, + max_wait=180, + poll_interval=15, + ) - if memories: - for i, mem in enumerate(memories[:2]): - logger.info(" Memory %d: %s", i + 1, mem.get("content", {}).get("text", "")[:80]) + results = self.client.retrieve_memories( + memory_id=self.memory_id, + namespace=namespace, + query=query, + ) + assert isinstance(results, list), "retrieve_memories must return a list" + if extraction_done and len(results) > 0: + logger.info("Retrieved %d memories", len(results)) + else: + logger.warning( + "Memory extraction done=%s, results=%d — extraction may need more time", + extraction_done, + len(results), + ) + + # Wildcard namespace should return empty (wildcards not supported) + wildcard_results = self.client.retrieve_memories( + memory_id=self.memory_id, + namespace="test/*/", + query="anything", + ) + assert wildcard_results == [], "Wildcard namespace should return empty, got %d results" % len(wildcard_results) + logger.info("Wildcard namespace correctly returned empty") + + def test_process_turn_with_llm(self, shared_memory, ids): + """process_turn_with_llm: mock callback echoes input.""" + user_input = "Tell me about quantum computing" + + def mock_llm_callback(prompt, memories): + return "Echo: %s (memories=%d)" % (prompt, len(memories)) + + memories, response, event = self.client.process_turn_with_llm( + memory_id=self.memory_id, + actor_id=self.actor_id, + session_id=self.session_id, + user_input=user_input, + llm_callback=mock_llm_callback, + ) - # Test 4: Prefix namespace (should work like S3 prefix) - logger.info("\n6. Testing with prefix namespace...") + assert isinstance(memories, list), "memories should be a list" + assert isinstance(response, str), "response should be a string" + assert user_input in response, "response should contain the user input" + assert "eventId" in event, "event should have eventId" + logger.info("process_turn_with_llm returned response: %s", response[:80]) + + # ------------------------------------------------------------------ + # Control-plane tests (create their own memories) + # ------------------------------------------------------------------ + + def test_create_or_get_memory(self): + """create_or_get_memory, list_memories.""" + name = "%s_create_or_get" % self.test_prefix + + # First call — creates memory + memory1 = self.client.create_or_get_memory( + name=name, + strategies=[ + { + "semanticMemoryStrategy": { + "name": "Sem1", + "namespaces": ["test/{actorId}/"], + } + } + ], + event_expiry_days=7, + memory_execution_role_arn=self.role_arn, + ) + memory_id = memory1["memoryId"] + self.__class__.memory_ids.append(memory_id) + logger.info("create_or_get_memory created: %s", memory_id) + + # Second call — should return existing memory + memory2 = self.client.create_or_get_memory( + name=name, + strategies=[ + { + "semanticMemoryStrategy": { + "name": "Sem1", + "namespaces": ["test/{actorId}/"], + } + } + ], + event_expiry_days=7, + memory_execution_role_arn=self.role_arn, + ) + assert memory2["memoryId"] == memory_id, "Second call should return the same memory" + logger.info("create_or_get_memory returned existing: %s", memory2["memoryId"]) + + # list_memories + all_memories = self.client.list_memories() + memory_ids = [m["memoryId"] for m in all_memories] + assert memory_id in memory_ids, "list_memories should include the created memory" + logger.info("list_memories returned %d memories (includes ours)", len(all_memories)) + + def test_add_builtin_strategies(self): + """Test add_semantic/summary/user_preference/episodic_strategy_and_wait and get_memory_strategies.""" + # Create a fresh memory for strategy tests + memory = self.client.create_memory_and_wait( + name="%s_builtin_strat" % self.test_prefix, + strategies=[], + event_expiry_days=7, + memory_execution_role_arn=self.role_arn, + ) + memory_id = memory["memoryId"] + self.__class__.memory_ids.append(memory_id) + logger.info("Created strategy test memory: %s", memory_id) - # Try multiple prefix options - prefixes = [ - "test/", - f"test/{actor_id}/", - ] + # 1) Semantic + result = self.client.add_semantic_strategy_and_wait( + memory_id=memory_id, + name="BuiltinSemantic", + namespaces=["sem/{actorId}/"], + ) + self.assert_memory_active(result) + logger.info("Added semantic strategy") - for prefix in prefixes: - logger.info("\nTrying prefix namespace: %s", prefix) - memories = client.retrieve_memories(memory_id=memory_id, namespace=prefix, query="specific keyword") + # 2) Summary + result = self.client.add_summary_strategy_and_wait( + memory_id=memory_id, + name="BuiltinSummary", + namespaces=["sum/{sessionId}/"], + ) + self.assert_memory_active(result) + logger.info("Added summary strategy") - logger.info("āœ“ Retrieved %d memories with prefix namespace", len(memories)) + # 3) User preference + result = self.client.add_user_preference_strategy_and_wait( + memory_id=memory_id, + name="BuiltinUserPref", + namespaces=["pref/{actorId}/"], + ) + self.assert_memory_active(result) + logger.info("Added user preference strategy") - if memories: - for i, mem in enumerate(memories[:2]): - logger.info(" Memory %d: %s", i + 1, mem.get("content", {}).get("text", "")[:80]) + # 4) Episodic (reflection namespace must be same as or prefix of episodic namespace) + result = self.client.add_episodic_strategy_and_wait( + memory_id=memory_id, + name="BuiltinEpisodic", + reflection_namespaces=["ep/{actorId}/"], + namespaces=["ep/{actorId}/"], + ) + self.assert_memory_active(result) + logger.info("Added episodic strategy") + + # Verify all 4 strategies via get_memory_strategies + strategies = self.client.get_memory_strategies(memory_id) + self.assert_strategies_by_name( + strategies, + ["BuiltinSemantic", "BuiltinSummary", "BuiltinUserPref", "BuiltinEpisodic"], + ) + logger.info("Verified 4 built-in strategies") + + def test_custom_strategies_and_modification(self): + """Test custom semantic/episodic strategies, modify_strategy, delete_strategy.""" + # Create its own memory with a seed strategy so we have something to work with + memory = self.client.create_memory_and_wait( + name="%s_custom_strat" % self.test_prefix, + strategies=[ + { + "semanticMemoryStrategy": { + "name": "SeedStrategy", + "namespaces": ["seed/{actorId}/"], + } + } + ], + event_expiry_days=7, + memory_execution_role_arn=self.role_arn, + ) + memory_id = memory["memoryId"] + self.__class__.memory_ids.append(memory_id) + logger.info("Created custom strategy test memory: %s", memory_id) + # Add custom semantic strategy + result = self.client.add_custom_semantic_strategy_and_wait( + memory_id=memory_id, + name="CustomSemantic", + extraction_config={"prompt": "Extract key facts.", "modelId": self.MODEL_ID}, + consolidation_config={"prompt": "Consolidate facts.", "modelId": self.MODEL_ID}, + namespaces=["custom_sem/{actorId}/"], + ) + self.assert_memory_active(result) + logger.info("Added custom semantic strategy") -def main(): - """Run all critical issue tests.""" + # Add custom episodic strategy + result = self.client.add_custom_episodic_strategy_and_wait( + memory_id=memory_id, + name="CustomEpisodic", + extraction_config={"prompt": "Extract episodes.", "modelId": self.MODEL_ID}, + consolidation_config={"prompt": "Consolidate episodes.", "modelId": self.MODEL_ID}, + reflection_config={ + "prompt": "Reflect on episodes.", + "modelId": self.MODEL_ID, + "namespaces": ["custom_ep/{actorId}/"], + }, + namespaces=["custom_ep/{actorId}/"], + ) + self.assert_memory_active(result) + logger.info("Added custom episodic strategy") - # Get role ARN from environment - role_arn = os.getenv("MEMORY_ROLE_ARN") - if not role_arn: - logger.error("Please set MEMORY_ROLE_ARN environment variable") - return + # Get all strategies to find IDs + strategies = self.client.get_memory_strategies(memory_id) + strategy_map = {s["name"]: s for s in strategies} + self.assert_strategies_by_name(strategies, ["SeedStrategy", "CustomSemantic", "CustomEpisodic"]) - # Get region and environment from environment variables with defaults - region = os.getenv("AWS_REGION", "us-west-2") - environment = os.getenv("MEMORY_ENVIRONMENT", "prod") + # modify_strategy — update description on custom semantic + custom_sem_id = strategy_map["CustomSemantic"]["strategyId"] + self.client.modify_strategy( + memory_id=memory_id, + strategy_id=custom_sem_id, + description="Updated custom semantic description", + ) + logger.info("Modified custom semantic strategy description") - logger.info("Using region: %s, environment: %s", region, environment) + # Wait for memory to become ACTIVE after modification + self.client._wait_for_memory_active(memory_id, max_wait=300, poll_interval=10) - client = MemoryClient(region_name=region) + # delete_strategy — delete the custom episodic strategy + custom_ep_id = strategy_map["CustomEpisodic"]["strategyId"] + self.client.delete_strategy(memory_id=memory_id, strategy_id=custom_ep_id) + logger.info("Deleted custom episodic strategy") - # Test Issue #2 first (strategy polling) - test_strategy_polling_fix(client) + # Wait for memory to settle + self.client._wait_for_memory_active(memory_id, max_wait=300, poll_interval=10) - # Create a memory for remaining tests - logger.info("\n\nCreating memory for remaining tests...") - # Explicitly define strategy with clear namespace pattern for testing - memory = client.create_memory_and_wait( - name="RetrievalTest_%s" % datetime.now().strftime("%Y%m%d%H%M%S"), - strategies=[ - { - "semanticMemoryStrategy": { - "name": "TestStrategy", - "namespaces": ["test/{actorId}/{sessionId}/"], # Explicit namespace pattern - } - } - ], - event_expiry_days=7, - memory_execution_role_arn=role_arn, - ) - memory_id = memory["memoryId"] - logger.info("Created test memory: %s", memory_id) - - try: - # Test Issue #1: list_events API - test_list_events_api(client, memory_id) - - # Test Issue #3: get_last_k_turns fix - test_get_last_k_turns_fix(client, memory_id) - - # Test Issue #4: namespace wildcards - logger.info("\n\nStarting namespace wildcard tests with memory ID: %s", memory_id) - logger.info( - "IMPORTANT: All retrieve calls will target the semantic strategy with " - "namespace pattern: test/{actorId}/{sessionId}" + # update_memory_strategies_and_wait — delete the seed strategy + seed_id = strategy_map["SeedStrategy"]["strategyId"] + result = self.client.update_memory_strategies_and_wait( + memory_id=memory_id, + delete_strategy_ids=[seed_id], + ) + self.assert_memory_active(result) + logger.info("Deleted SeedStrategy via update_memory_strategies_and_wait") + + # Verify remaining strategies + remaining = self.client.get_memory_strategies(memory_id) + remaining_names = {s["name"] for s in remaining} + assert "CustomEpisodic" not in remaining_names, "CustomEpisodic should have been deleted" + assert "SeedStrategy" not in remaining_names, "SeedStrategy should have been deleted" + assert "CustomSemantic" in remaining_names, "CustomSemantic should still exist" + logger.info("Verified remaining strategies: %s", remaining_names) + + def test_delete_memory_and_wait(self): + """delete_memory_and_wait: create throwaway memory, delete it, verify gone.""" + memory = self.client.create_memory_and_wait( + name="%s_throwaway" % self.test_prefix, + strategies=[], + event_expiry_days=7, + memory_execution_role_arn=self.role_arn, + ) + memory_id = memory["memoryId"] + logger.info("Created throwaway memory: %s", memory_id) + + # Delete and wait + self.client.delete_memory_and_wait(memory_id=memory_id) + logger.info("delete_memory_and_wait completed for: %s", memory_id) + + # Verify it's gone — get_memory_status should raise + with pytest.raises((Exception,), match="ResourceNotFoundException|not found"): + self.client.get_memory_status(memory_id) + logger.info("Confirmed memory %s is deleted (get_memory_status raised)", memory_id) + + # ------------------------------------------------------------------ + # Assertion helpers + # ------------------------------------------------------------------ + + @staticmethod + def assert_created_event(actual: dict[str, object]) -> None: + """Assert a create_event/fork_conversation response has eventId and eventTimestamp.""" + assert "eventId" in actual, "Event must have eventId" + assert "eventTimestamp" in actual, "Event must have eventTimestamp" + + @staticmethod + def assert_listed_event( + actual: dict[str, object], + expected: list[tuple[str, str]], + ) -> None: + """Assert an event from list_events has full structure and matches expected messages.""" + assert "eventId" in actual, "Event must have eventId" + assert "eventTimestamp" in actual, "Event must have eventTimestamp" + assert "payload" in actual, "Listed event must have payload" + assert len(actual["payload"]) == len(expected), "Expected %d payload items, got %d" % ( + len(expected), + len(actual["payload"]), + ) + for i, (text, role) in enumerate(expected): + msg = actual["payload"][i]["conversational"] + assert msg["role"] == role, "Message %d role: expected %s, got %s" % (i, role, msg["role"]) + assert msg["content"]["text"] == text, "Message %d text: expected %r, got %r" % ( + i, + text, + msg["content"]["text"], + ) + + @staticmethod + def assert_blob_event(actual: dict[str, object], expected_blob: dict) -> None: + """Assert an event has a valid blob payload matching expected data.""" + assert "eventId" in actual, "Event must have eventId" + assert "eventTimestamp" in actual, "Event must have eventTimestamp" + assert "payload" in actual, "Blob event must have payload" + assert len(actual["payload"]) == 1, "Blob event should have 1 payload item" + assert "blob" in actual["payload"][0], "Payload item should contain a blob key" + actual_blob = str(actual["payload"][0]["blob"]) + + # Service returns blob as a stringified Java-style map; check all leaf values are present + def _leaf_values(obj): + if isinstance(obj, dict): + for v in obj.values(): + yield from _leaf_values(v) + else: + yield obj + + for leaf in _leaf_values(expected_blob): + assert str(leaf) in actual_blob, "Blob missing leaf value %r in %s" % (leaf, actual_blob) + + @staticmethod + def assert_tree_has_branch( + tree: dict[str, object], + branch_name: str, + min_main_events: int = 1, + min_branch_events: int = 1, + ) -> None: + """Assert a conversation tree has the expected structure and branch.""" + assert "main_branch" in tree, "Tree must have a main_branch" + main = tree["main_branch"] + assert "events" in main, "main_branch must have events" + assert len(main["events"]) >= min_main_events, "main_branch should have >= %d events, got %d" % ( + min_main_events, + len(main["events"]), + ) + assert "branches" in main, "main_branch must have branches" + assert branch_name in main["branches"], "Branch %r not in tree" % branch_name + branch = main["branches"][branch_name] + assert len(branch["events"]) >= min_branch_events, "Branch %r should have >= %d events, got %d" % ( + branch_name, + min_branch_events, + len(branch["events"]), ) - test_namespace_wildcards(client, memory_id) - - logger.info("\n%s", "=" * 80) - logger.info("ALL ISSUE TESTS COMPLETED") - logger.info("=" * 80) - - logger.info("\nSummary:") - logger.info("āœ“ Issue #1: list_events() method now available") - logger.info("āœ“ Issue #2: All strategy operations use polling") - logger.info("āœ“ Issue #3: get_last_k_turns() returns correct turns") - logger.info("āœ“ Issue #4: Wildcard limitation documented - use exact namespaces or prefixes instead") - except Exception as e: - logger.exception("Test failed: %s", e) - finally: - logger.info("\nCleaning up test memory...") - try: - client.delete_memory_and_wait(memory_id) - logger.info("āœ“ Test memory deleted") - except Exception as e: - logger.error("Failed to delete test memory: %s", e) + @staticmethod + def assert_turns( + actual: list[list[dict[str, object]]], + expected: list[tuple[str, str]], + ) -> None: + """Assert flattened turn messages match expected (text, role) pairs.""" + messages = [msg for turn in actual for msg in turn] + assert len(messages) == len(expected), "Expected %d messages, got %d" % (len(expected), len(messages)) + for i, (text, role) in enumerate(expected): + assert messages[i]["role"] == role, "Message %d role: expected %s, got %s" % (i, role, messages[i]["role"]) + assert messages[i]["content"]["text"] == text, "Message %d text: expected %r, got %r" % ( + i, + text, + messages[i]["content"]["text"], + ) + + @staticmethod + def assert_memory_active(result: dict[str, object]) -> None: + """Assert a memory response shows ACTIVE status.""" + assert result["status"] == "ACTIVE", "Expected ACTIVE, got %s" % result["status"] + + @staticmethod + def assert_strategies_by_name(strategies: list[dict[str, object]], expected_names: list[str]) -> None: + """Assert a list of strategies contains exactly the expected names.""" + actual = {s["name"] for s in strategies} + for name in expected_names: + assert name in actual, "Missing strategy: %s (have %s)" % (name, actual) + assert len(strategies) == len(expected_names), "Expected %d strategies, got %d: %s" % ( + len(expected_names), + len(strategies), + actual, + ) if __name__ == "__main__": - main() + pytest.main(["-xvs", "test_memory_client.py"]) diff --git a/uv.lock b/uv.lock index fe796d49..4c7d8637 100644 --- a/uv.lock +++ b/uv.lock @@ -252,6 +252,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-xdist" }, { name = "ruff" }, { name = "strands-agents" }, { name = "strands-agents-evals" }, @@ -283,6 +284,7 @@ dev = [ { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "pytest-xdist", specifier = ">=3.5.0" }, { name = "ruff", specifier = ">=0.12.0" }, { name = "strands-agents", specifier = ">=1.18.0" }, { name = "strands-agents-evals", specifier = ">=0.1.0" }, @@ -630,6 +632,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674, upload-time = "2025-05-10T17:42:49.33Z" }, ] +[[package]] +name = "execnet" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/89/780e11f9588d9e7128a3f87788354c7946a9cbb1401ad38a48c4db9a4f07/execnet-2.1.2.tar.gz", hash = "sha256:63d83bfdd9a23e35b9c6a3261412324f964c2ec8dcd8d3c6916ee9373e0befcd", size = 166622, upload-time = "2025-11-12T09:56:37.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/84/02fc1827e8cdded4aa65baef11296a9bbe595c474f0d6d758af082d849fd/execnet-2.1.2-py3-none-any.whl", hash = "sha256:67fba928dd5a544b783f6056f449e5e3931a5c378b128bc18501f7ea79e296ec", size = 40708, upload-time = "2025-11-12T09:56:36.333Z" }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -1768,6 +1779,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bc/16/4ea354101abb1287856baa4af2732be351c7bee728065aed451b678153fd/pytest_cov-6.2.1-py3-none-any.whl", hash = "sha256:f5bc4c23f42f1cdd23c70b1dab1bbaef4fc505ba950d53e0081d0730dd7e86d5", size = 24644, upload-time = "2025-06-12T10:47:45.932Z" }, ] +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"