From 4607517d38fc515afa9848c60a980fd464708e96 Mon Sep 17 00:00:00 2001 From: Charles Green Date: Thu, 7 May 2026 10:40:16 +0900 Subject: [PATCH] feat(rag): add JSON and template augmentation strategies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The augmentInput step combined retrieved documents with the user query using a single hard-coded prepend format ("Context:\n…\n\nQuery:\n…"). That works for plain-text chat generators but not for tool-calling agents that expect structured input or for callers who want to control the exact prompt scaffolding. Adds an AugmentationStrategy enum with three modes: - AugmentPrepend (default, unchanged behavior) - AugmentJSON: emits {"context": "...", "query": "..."} - AugmentTemplate: renders a user-supplied text/template with {{.Context}} and {{.Query}} fields Two new options on RAG: - WithAugmentationStrategy(s) — pick a built-in strategy - WithAugmentationTemplate(tmpl) — supplies a template and implicitly flips the strategy to AugmentTemplate augmentInput moves from a package-level helper to a method on *RAG so it can dispatch on the orchestrator's configured strategy. The Execute path now propagates augmentation errors (e.g. a JSON encoding failure, or a missing or unparsable template) instead of silently producing the wrong prompt. TestAugmentInput grows from 5 → 8 subtests covering JSON output, custom template rendering, and the misconfiguration error path. All other RAG tests pass unchanged. 
Closes #56 --- docs/PATTERNS.md | 26 ++++++ internal/orchestration/rag.go | 125 ++++++++++++++++++++++++----- internal/orchestration/rag_test.go | 68 +++++++++++++++- 3 files changed, 199 insertions(+), 20 deletions(-) diff --git a/docs/PATTERNS.md b/docs/PATTERNS.md index 28e49d4..47bbbec 100644 --- a/docs/PATTERNS.md +++ b/docs/PATTERNS.md @@ -671,6 +671,32 @@ rag := orchestration.NewRAG( result, _ := rag.Execute(ctx, userQuestion) ``` +**Augmentation Strategies**: + +The retrieved documents are combined with the original query before being +passed to the generator. Three built-in strategies are available — pick the +one that matches your generator agent's expected input format: + +| Strategy | Output | When to use | +|----------|--------|-------------| +| `AugmentPrepend` (default) | `Context:\n{documents}\n\nQuery:\n{query}` | Plain-text generators / chat models | +| `AugmentJSON` | `{"context": "...", "query": "..."}` | Tool-calling agents that parse structured input | +| `AugmentTemplate` | User-supplied `text/template` with `{{.Context}}` and `{{.Query}}` | Custom prompt templates with system-specific scaffolding | + +```go +// JSON augmentation for a structured generator +rag := orchestration.NewRAG("qa", runtime, "retriever", "generator", + orchestration.WithAugmentationStrategy(orchestration.AugmentJSON), +) + +// Custom template +rag := orchestration.NewRAG("qa", runtime, "retriever", "generator", + orchestration.WithAugmentationTemplate( + "Use the following context to answer:\n{{.Context}}\n\nQ: {{.Query}}\nA:", + ), +) +``` + **Metrics Tracked**: - Retrieval precision/recall - Context usage (% of retrieved context used in answer) diff --git a/internal/orchestration/rag.go b/internal/orchestration/rag.go index 03653ab..3cb6820 100644 --- a/internal/orchestration/rag.go +++ b/internal/orchestration/rag.go @@ -6,6 +6,7 @@ import ( "fmt" "maps" "strings" + "text/template" "time" "github.com/aixgo-dev/aixgo/internal/agent" @@ -15,6 +16,23 @@ import ( 
"go.opentelemetry.io/otel/trace" ) +// AugmentationStrategy controls how the query and retrieved documents are +// combined before being passed to the generator agent. +type AugmentationStrategy int + +const ( + // AugmentPrepend prepends retrieved context to the query (default). + // Renders as: "Context:\n{documents}\n\nQuery:\n{query}" + AugmentPrepend AugmentationStrategy = iota + // AugmentJSON encodes the augmented input as a JSON object so the + // generator can address each field independently: + // {"context": "{documents}", "query": "{query}"} + AugmentJSON + // AugmentTemplate renders a user-supplied text/template with the + // fields .Context and .Query. Use WithAugmentationTemplate to set it. + AugmentTemplate +) + // RAG implements Retrieval-Augmented Generation pattern. // Retrieves relevant documents from a vector store, then generates grounded answers. // Most common enterprise pattern for chatbots and Q&A systems. @@ -26,15 +44,17 @@ import ( // - Context-aware generation type RAG struct { *BaseOrchestrator - retriever string // Agent that retrieves relevant documents - generator string // Agent that generates the answer - topK int // Number of documents to retrieve - rerank bool // Whether to rerank retrieved documents - reranker string // Optional reranker agent - conversationHist []ConversationTurn // For conversational RAG - historyAgent string // Agent for managing history - queryExpander string // For multi-query RAG - keywordRetriever string // For hybrid RAG + retriever string // Agent that retrieves relevant documents + generator string // Agent that generates the answer + topK int // Number of documents to retrieve + rerank bool // Whether to rerank retrieved documents + reranker string // Optional reranker agent + conversationHist []ConversationTurn // For conversational RAG + historyAgent string // Agent for managing history + queryExpander string // For multi-query RAG + keywordRetriever string // For hybrid RAG + augmentStrategy AugmentationStrategy + augmentTemplate 
*template.Template // Compiled template for AugmentTemplate } // ConversationTurn represents a single turn in conversation history @@ -62,6 +82,35 @@ func WithReranker(reranker string) RAGOption { } } +// WithAugmentationStrategy selects how the query and retrieved documents are +// combined before being passed to the generator. AugmentPrepend (the default) +// produces a "Context:\n…\n\nQuery:\n…" string; AugmentJSON wraps the two +// fields as a JSON object; AugmentTemplate renders a text/template (set via +// WithAugmentationTemplate, which also flips the strategy to AugmentTemplate). +func WithAugmentationStrategy(s AugmentationStrategy) RAGOption { + return func(r *RAG) { + r.augmentStrategy = s + } +} + +// WithAugmentationTemplate supplies a text/template used to format augmented +// input. The template may reference {{.Context}} and {{.Query}} fields. +// Setting a template implicitly switches the strategy to AugmentTemplate. +// NOTE: a parse failure is not returned here and the parse error is dropped; +// augmentInput then reports only that no template was configured. +func WithAugmentationTemplate(tmpl string) RAGOption { + return func(r *RAG) { + // A parse failure leaves the template nil; augmentInput errors on a nil template. 
+ parsed, err := template.New("rag_augment").Parse(tmpl) + if err != nil { + r.augmentTemplate = nil + } else { + r.augmentTemplate = parsed + } + r.augmentStrategy = AugmentTemplate + } +} + // NewRAG creates a new RAG orchestrator func NewRAG(name string, runtime agent.Runtime, retriever, generator string, opts ...RAGOption) *RAG { r := &RAG{ @@ -153,7 +202,11 @@ func (r *RAG) Execute(ctx context.Context, input *agent.Message) (*agent.Message } // Step 3: Generate answer with retrieved context - augmentedInput := augmentInput(input, documents) + augmentedInput, err := r.augmentInput(input, documents) + if err != nil { + span.RecordError(err) + return nil, fmt.Errorf("augmentation failed: %w", err) + } generateStart := time.Now() result, err := r.runtime.Call(ctx, r.generator, augmentedInput) @@ -395,21 +448,55 @@ func (r *RAG) hybridRetrieve(ctx context.Context, input *agent.Message) (*agent. }, nil } -// augmentInput combines the original query with retrieved documents -func augmentInput(query, documents *agent.Message) *agent.Message { +// augmentInput combines the original query with retrieved documents using the +// orchestrator's configured AugmentationStrategy. When no documents were +// retrieved (nil or empty payload), the original query is returned unchanged +// so that the generator can still respond — falling through to "ungrounded" +// generation is preferable to failing the whole pipeline on a sparse retriever. 
+func (r *RAG) augmentInput(query, documents *agent.Message) (*agent.Message, error) { if query == nil || query.Message == nil { - return query + return query, nil } if documents == nil || documents.Message == nil || documents.Payload == "" { // No documents retrieved, return original query - return query + return query, nil + } + + var augmentedPayload string + switch r.augmentStrategy { + case AugmentJSON: + buf, err := json.Marshal(struct { + Context string `json:"context"` + Query string `json:"query"` + }{ + Context: documents.Payload, + Query: query.Payload, + }) + if err != nil { + return nil, fmt.Errorf("encode json augmentation: %w", err) + } + augmentedPayload = string(buf) + case AugmentTemplate: + if r.augmentTemplate == nil { + return nil, fmt.Errorf("augmentation strategy is template but no template was configured (use WithAugmentationTemplate)") + } + var buf strings.Builder + err := r.augmentTemplate.Execute(&buf, struct { + Context string + Query string + }{ + Context: documents.Payload, + Query: query.Payload, + }) + if err != nil { + return nil, fmt.Errorf("render augmentation template: %w", err) + } + augmentedPayload = buf.String() + default: // AugmentPrepend + augmentedPayload = fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload) } - // Create augmented message with both query and retrieved context - // Format: "Context:\n{documents}\n\nQuery:\n{query}" - augmentedPayload := fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload) - // Preserve metadata from both messages metadata := make(map[string]any) if query.Metadata != nil { @@ -427,7 +514,7 @@ func augmentInput(query, documents *agent.Message) *agent.Message { Timestamp: query.Timestamp, Metadata: metadata, }, - } + }, nil } // RAG variants diff --git a/internal/orchestration/rag_test.go b/internal/orchestration/rag_test.go index 103d374..57bd43b 100644 --- a/internal/orchestration/rag_test.go +++ b/internal/orchestration/rag_test.go @@ -2,6 
+2,7 @@ package orchestration import ( "context" + "encoding/json" "strings" "testing" "time" @@ -156,8 +157,10 @@ func TestRAGPattern(t *testing.T) { func TestAugmentInput(t *testing.T) { tests := []struct { name string + opts []RAGOption query *agent.Message documents *agent.Message + wantErr bool wantType string checkFunc func(t *testing.T, result *agent.Message) }{ @@ -267,11 +270,74 @@ func TestAugmentInput(t *testing.T) { } }, }, + { + name: "json strategy emits parseable JSON", + opts: []RAGOption{WithAugmentationStrategy(AugmentJSON)}, + query: &agent.Message{ + Message: &pb.Message{Payload: "What is AI?"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "AI is artificial intelligence"}, + }, + wantType: "rag_augmented", + checkFunc: func(t *testing.T, result *agent.Message) { + var got struct { + Context string `json:"context"` + Query string `json:"query"` + } + if err := json.Unmarshal([]byte(result.Payload), &got); err != nil { + t.Fatalf("payload is not valid JSON: %v\npayload: %s", err, result.Payload) + } + if got.Query != "What is AI?" 
{ + t.Errorf("query field = %q, want %q", got.Query, "What is AI?") + } + if got.Context != "AI is artificial intelligence" { + t.Errorf("context field = %q, want %q", got.Context, "AI is artificial intelligence") + } + }, + }, + { + name: "template strategy renders user template", + opts: []RAGOption{ + WithAugmentationTemplate("Use the following context to answer:\n{{.Context}}\n\nQuestion: {{.Query}}\nAnswer:"), + }, + query: &agent.Message{ + Message: &pb.Message{Payload: "What is AI?"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "AI is artificial intelligence"}, + }, + wantType: "rag_augmented", + checkFunc: func(t *testing.T, result *agent.Message) { + want := "Use the following context to answer:\nAI is artificial intelligence\n\nQuestion: What is AI?\nAnswer:" + if result.Payload != want { + t.Errorf("payload mismatch\n got: %q\n want: %q", result.Payload, want) + } + }, + }, + { + name: "template strategy without template configured returns error", + opts: []RAGOption{WithAugmentationStrategy(AugmentTemplate)}, + query: &agent.Message{ + Message: &pb.Message{Payload: "q"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "d"}, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := augmentInput(tt.query, tt.documents) + r := NewRAG("test", NewMockRuntime(), "retriever", "generator", tt.opts...) + result, err := r.augmentInput(tt.query, tt.documents) + if (err != nil) != tt.wantErr { + t.Fatalf("augmentInput() err = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } if tt.wantType != "" && result != nil && result.Type != tt.wantType { t.Errorf("Type = %s, want %s", result.Type, tt.wantType) }