-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpostgres_copilot_chat.py
More file actions
1399 lines (1201 loc) · 84.9 KB
/
postgres_copilot_chat.py
File metadata and controls
1399 lines (1201 loc) · 84.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import asyncio
import asyncio
import os
import sys
import subprocess # Keep for StdioServerParameters if server is run as subprocess
# from dotenv import load_dotenv # Will be handled by config_manager
from typing import Optional, Any, Dict, Tuple
from pathlib import Path # Added
from contextlib import AsyncExitStack
import datetime
import litellm
from pydantic import ValidationError # For catching Pydantic errors
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.types import Tool as McpTool
import copy
import re
import json
import memory_module
import initialization_module
import sql_generation_module
import insights_module
import revision_insights_module # Added
import database_navigation_module
import revise_query_module # Added
import vector_store_module # Added for RAG
import model_change_module # Added for /change_model
import error_handler_module
import token_utils
from token_logging_module import log_token_usage
from colorama import Fore, Style, init as colorama_init # Added for Colorama
import config_manager # Changed to absolute import
import inspect
from pydantic_models import ( # Changed to absolute import
SQLGenerationResponse,
FeedbackReportContentModel,
FeedbackIteration,
RevisionReportContentModel, # Added
RevisionIteration # Added
)
# --- Configuration Setup ---
# Configuration, including API keys and paths, is now handled by config_manager.py
# The get_app_config() call in main() will ensure this is loaded/set up.
# LiteLLM doesn't require a global configure call.
# API keys are set as environment variables by config_manager.py based on user's choice.
# LiteLLM will automatically use the appropriate credentials based on the model ID prefix.
class LiteLLMMcpClient:
"""A client that connects to an MCP server and uses LiteLLM for interaction."""
def __init__(self, app_config: dict, system_instruction: Optional[str] = None): # Added app_config
self.session: Optional[ClientSession] = None
self.exit_stack = AsyncExitStack()
self.app_config = app_config # Store app_config
self.model_name = self.app_config.get("model_id") # Get from app_config
self.provider = self.app_config.get("llm_provider")
self.llm_api_key = self.app_config.get("api_key") # Store for potential direct use if needed
# Determine model provider based on model name prefix
# Provider name from config_data['llm_provider'] can also be used directly
self.model_provider = self.app_config.get("llm_provider", "Unknown").capitalize()
# Fallback logic if llm_provider not in config, or to be more specific
if self.model_name.startswith("gemini/"):
self.model_provider = "Google AI (Gemini)"
elif self.model_name.startswith("gpt-") or self.model_name.startswith("openai/"):
self.model_provider = "OpenAI"
elif self.model_name.startswith("bedrock/"):
self.model_provider = "AWS Bedrock"
if "claude" in self.model_name.lower():
self.model_provider += " (Anthropic Claude)"
elif self.model_name.startswith("anthropic/") or self.model_name.startswith("claude"):
self.model_provider = "Anthropic"
elif self.model_name.startswith("ollama/"):
self.model_provider = "Ollama (Local)"
self.system_instruction_content = system_instruction
self.conversation_history: list[Dict[str, Any]] = []
if self.system_instruction_content:
self.conversation_history.append({"role": "system", "content": self.system_instruction_content})
# RAG and Display Thresholds are now hardcoded in vector_store_module.
# Removing client-level attributes for these.
# self.rag_similarity_threshold = vector_store_config.LITELLM_RAG_THRESHOLD (Removed)
# self.display_similarity_threshold = vector_store_config.LITELLM_DISPLAY_THRESHOLD (Removed)
# Memory directories are ensured by memory_module on its import.
# No need for verbose printing of these paths here.
# Core state variables
self.is_initialized: bool = False
self.current_db_connection_string: Optional[str] = None
self.current_db_name_identifier: Optional[str] = None # e.g., "california_schools_db"
# Data used for context
self.db_schema_and_sample_data: Optional[Dict[str, Any]] = None # Loaded from memory/schema/
self.cumulative_insights_content: Optional[str] = None # Loaded from memory/insights/summarized_insights.md
# State for the current query/feedback cycle
self.current_natural_language_question: Optional[str] = None
# Holds the Pydantic model of the full feedback report being built/iterated on
self.current_feedback_report_content: Optional[FeedbackReportContentModel] = None
# Note: last_generated_sql, last_sql_explanation etc. are now part of current_feedback_report_content
# State for query revision feature
self.current_revision_report_content: Optional[RevisionReportContentModel] = None
self.is_in_revision_mode: bool = False
self.feedback_used_in_current_revision_cycle: bool = False
self.feedback_log_in_revision: list[Dict[str, str]] = []
self.active_table_scope: Optional[List[str]] = None
self.table_categories: Optional[Dict[str, List[str]]] = None
# The first_run concept is now implicitly handled by config_manager.get_app_config()
# which triggers initial_setup() if config.json doesn't exist.
# We can remove _get_first_run_flag_path and _check_and_set_first_run.
# The welcome message can be shown if initial_setup was triggered.
async def _cleanup_database_session(self, full_cleanup: bool = True):
# print("Cleaning up database session state...") # User doesn't need to see this
if full_cleanup:
self.is_initialized = False
self.current_db_connection_string = None
self.current_db_name_identifier = None
self.db_schema_and_sample_data = None
self.cumulative_insights_content = None
self._reset_feedback_cycle_state()
self._reset_revision_cycle_state() # Added
# print("Database session state cleaned.")
def _reset_feedback_cycle_state(self):
"""Resets state for a new natural language query and its feedback cycle."""
# print("Resetting feedback cycle state for new SQL generation...") # Internal detail
self.current_natural_language_question = None
self.current_feedback_report_content = None
# If starting a new feedback cycle, revision mode should also reset
self._reset_revision_cycle_state()
def _reset_revision_cycle_state(self):
"""Resets state for a new query revision cycle."""
# print("Resetting revision cycle state...") # Internal detail
self.current_revision_report_content = None
self.is_in_revision_mode = False
self.feedback_used_in_current_revision_cycle = False
self.feedback_log_in_revision = []
def _extract_connection_string_and_db_name(self, query: str) -> Tuple[Optional[str], Optional[str]]:
# Extracts connection string and attempts to derive a db_name from it.
# Example: postgresql://user:password@host:port/dbname
conn_str_match = re.search(r"(postgresql://\S+:\S+@\S+:\d+/(\S+))", query)
if conn_str_match:
full_conn_str = conn_str_match.group(1)
db_name_part = conn_str_match.group(2)
# Sanitize db_name_part to be a valid filename component
sanitized_db_name = re.sub(r'[^\w\-_\.]', '_', db_name_part) if db_name_part else "unknown_db"
return full_conn_str, sanitized_db_name
return None, None
def _extract_mcp_tool_call_output(self, tool_call_result: Any, tool_name: str) -> Any:
    """Pull the usable payload out of an MCP tool call result and log it.

    Lookup order: the ``text`` of the first ``content`` item (when ``content``
    is a non-empty list and that text is not None), then an ``output``
    attribute, then a ``result`` attribute, then the object itself (converted
    with ``str()`` unless it is already a str/int/float/bool/list/dict).
    A ``None`` result yields an error string.

    Args:
        tool_call_result: Object returned by the MCP tool call.
        tool_name: Tool name, passed to the response logger.

    Returns:
        The extracted payload, or an ``"Error: ..."`` string. When no MCP
        session is available the error string is returned without logging.
    """
    if not self.session:
        return "Error: MCP session not available."

    content_items = getattr(tool_call_result, 'content', None)
    first_text = None
    if isinstance(content_items, list) and content_items:
        first_text = getattr(content_items[0], 'text', None)

    if first_text is not None:
        output = first_text
    elif hasattr(tool_call_result, 'output'):
        output = tool_call_result.output
    elif hasattr(tool_call_result, 'result'):
        output = tool_call_result.result
    elif tool_call_result is None:
        output = "Error: Tool output not found or format not recognized."
    elif isinstance(tool_call_result, (str, int, float, bool, list, dict)):
        # Plain values are passed through untouched.
        output = tool_call_result
    else:
        try:
            output = str(tool_call_result)
        except Exception as e:
            output = f"Error: Tool output format not recognized and failed to convert to string. Type: {type(tool_call_result)}, Error: {e}"

    # Log the raw response
    error_handler_module.log_mcp_response(tool_name, output)
    return output
async def connect_to_mcp_server(self, server_script_path: Path):
    """Launch the MCP server script over stdio and register its tools for LiteLLM.

    The script is started with the current Python interpreter
    (``sys.executable``). After the session is initialised, every tool that
    has both a name and a description is converted to the OpenAI-style
    function format and stored in ``self.litellm_tools``.

    Args:
        server_script_path: Location of the server script. It is converted
            with ``str()`` for the subprocess arguments.

    Any exception raised while connecting or converting tools is reported via
    ``error_handler_module.display_message`` at ``FATAL`` level, followed by
    ``self.cleanup()``.
    """
    # Reset up front so the attribute is always defined, and never stale,
    # if the connection attempt below fails part-way through.
    self.litellm_tools = []

    # StdioServerParameters takes the command and then its arguments.
    server_params = StdioServerParameters(
        command=sys.executable,
        args=[str(server_script_path)],
        env=None,
    )

    try:
        read_stream, write_stream = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.session = await self.exit_stack.enter_async_context(ClientSession(read_stream, write_stream))
        await self.session.initialize()

        response = await self.session.list_tools()
        mcp_tools_list: list[McpTool] = response.tools
        if not mcp_tools_list:
            error_handler_module.display_message("No tools discovered from the MCP server.", level="FATAL")

        for mcp_tool_obj in mcp_tools_list:
            tool_name = getattr(mcp_tool_obj, 'name', None)
            tool_desc = getattr(mcp_tool_obj, 'description', None)
            # Work on a copy so the MCP tool object's own schema is not modified.
            tool_schema = copy.deepcopy(getattr(mcp_tool_obj, 'inputSchema', {}))

            if tool_schema:
                if 'properties' in tool_schema:
                    tool_schema['type'] = 'object'
                # Strip 'title' keys at the top level and on each property.
                tool_schema.pop('title', None)
                properties = tool_schema.get('properties')
                if isinstance(properties, dict):
                    for prop_schema in properties.values():
                        if isinstance(prop_schema, dict):
                            prop_schema.pop('title', None)

            if tool_name and tool_desc:
                self.litellm_tools.append({
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "description": tool_desc,
                        "parameters": tool_schema or {"type": "object", "properties": {}}
                    }
                })
            else:
                error_handler_module.display_message(f"Skipping MCP tool '{tool_name}' due to missing attributes for LiteLLM conversion.", level="WARNING")

        if not self.litellm_tools and mcp_tools_list:
            error_handler_module.display_message("No MCP tools could be converted to LiteLLM format.", level="FATAL")
    except Exception as e:
        error_handler_module.display_message(f"Fatal error during MCP server connection or LiteLLM setup: {e}. Please ensure the MCP server script is correct and executable.", level="FATAL")
        await self.cleanup()
async def _send_message_to_llm(self, messages: list, user_query: str, schema_tokens: int = 0, tools: Optional[list] = None, tool_choice: str = "auto",
                               response_format: Optional[dict] = None) -> Any:
    """Send *messages* to the configured model through ``litellm.acompletion``.

    The request carries AWS credentials when the configured provider is
    ``"bedrock"`` and an ``api_key`` otherwise. ``tools``/``tool_choice`` are
    attached only when *tools* is non-empty, and *response_format* only for
    model ids starting with ``gpt-`` or ``openai/``.

    Token usage is logged after a successful call and after a
    ``litellm.RateLimitError``; the log entry records the caller's file and
    line. On success the last message is appended to
    ``self.conversation_history`` when its role is ``"user"``.

    Args:
        messages: Chat messages to send; the last one is treated as the prompt.
        user_query: Original user question, recorded in the token log.
        schema_tokens: Schema token count, recorded in the token log.
        tools: Tool definitions to offer to the model, if any.
        tool_choice: Tool selection mode sent together with *tools*.
        response_format: Response format, forwarded for OpenAI-style model ids only.

    Returns:
        The response object returned by ``litellm.acompletion``.

    Raises:
        Exception: Any error raised while preparing, sending or logging the
            call is re-raised after the pending user message (if not already
            recorded) and a placeholder assistant message are added to the
            conversation history.
    """
    try:
        # Identify the caller for the token log.
        caller = inspect.currentframe().f_back
        origin_script = caller.f_code.co_filename
        origin_line = caller.f_lineno

        request = {"model": self.model_name, "messages": messages}

        # Provider-specific credentials come from app_config.
        if self.app_config.get("llm_provider") == "bedrock":
            request["aws_access_key_id"] = self.app_config.get("access_key_id")
            request["aws_secret_access_key"] = self.app_config.get("secret_access_key")
            request["aws_region_name"] = self.app_config.get("region")
        else:
            request["api_key"] = self.app_config.get("api_key")

        if tools:
            request["tools"] = tools
            request["tool_choice"] = tool_choice

        # JSON response formats are only forwarded to OpenAI-style model ids.
        if response_format and self.model_name.startswith(("gpt-", "openai/")):
            request["response_format"] = response_format

        try:
            response = await litellm.acompletion(**request)

            last_message = messages[-1]
            prompt_text = last_message['content']
            log_token_usage(
                origin_script=origin_script,
                origin_line=origin_line,
                user_query=user_query,
                prompt=prompt_text,
                prompt_tokens=token_utils.count_tokens(prompt_text, self.model_name, self.provider),
                schema_tokens=schema_tokens,
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
                llm_response=response.choices[0].message.content,
                model_id=self.model_name
            )

            # Record the user prompt; the assistant reply is added by the
            # caller once the response has been processed.
            if messages[-1]["role"] == "user":
                self.conversation_history.append(messages[-1])
            return response
        except litellm.RateLimitError as rate_limit_error:
            prompt_text = messages[-1]['content']
            estimated_tokens = token_utils.count_tokens(prompt_text, self.model_name, self.provider)
            log_token_usage(
                origin_script=origin_script,
                origin_line=origin_line,
                user_query=user_query,
                prompt=prompt_text,
                prompt_tokens=estimated_tokens,
                schema_tokens=schema_tokens,
                input_tokens=estimated_tokens,  # No usage data: fall back to the local count
                output_tokens=0,
                llm_response=f"RateLimitError: {rate_limit_error}",
                model_id=self.model_name
            )
            raise
    except Exception:
        # Other failures are not logged here; the caller handles and logs
        # them. Only leave a placeholder exchange in the history.
        pending = messages[-1]
        if pending["role"] == "user" and pending not in self.conversation_history:
            self.conversation_history.append(pending)
        self.conversation_history.append({"role": "assistant", "content": "I'm having trouble processing your request."})
        raise
async def _process_llm_response(self, llm_response: Any) -> Tuple[str, bool]:
"""Processes LiteLLM response, handles tool calls, and returns text response and if a tool was called."""
assistant_response_content = ""
tool_calls_made = False
if not llm_response or not llm_response.choices or not llm_response.choices[0].message:
error_handler_module.display_message("Empty or invalid response from LLM.", level="ERROR")
assistant_response_content = "Error: Empty or invalid response from LLM." # Keep for history
self.conversation_history.append({"role": "assistant", "content": assistant_response_content})
return assistant_response_content, tool_calls_made
message = llm_response.choices[0].message
# Storing the raw assistant message (potentially with tool_calls)
assistant_message_for_history = {"role": "assistant"}
if message.content:
assistant_message_for_history["content"] = message.content
assistant_response_content = message.content # Initial text part
if hasattr(message, 'tool_calls') and message.tool_calls:
tool_calls_made = True
assistant_message_for_history["tool_calls"] = [] # For history
# LiteLLM returns tool_calls in OpenAI format
# [{"id": "call_abc", "type": "function", "function": {"name": "tool_name", "arguments": "{...}"}}]
tool_call_responses_for_next_llm_call = []
for tool_call in message.tool_calls:
tool_call_id = tool_call.id
tool_name = tool_call.function.name
tool_args_str = tool_call.function.arguments
# Add the tool call itself to history
assistant_message_for_history["tool_calls"].append({
"id": tool_call_id,
"type": "function",
"function": {"name": tool_name, "arguments": tool_args_str}
})
try:
tool_args = json.loads(tool_args_str)
mcp_tool_result_obj = await self.session.call_tool(tool_name, tool_args)
tool_output = self._extract_mcp_tool_call_output(mcp_tool_result_obj, tool_name)
except json.JSONDecodeError as e_json:
tool_output = f"Error: Invalid JSON arguments for tool {tool_name}: {e_json}. Arguments received: {tool_args_str}"
error_handler_module.display_message(f"MCP Tool Error: Invalid JSON arguments for tool {tool_name}: {e_json}. Args: {tool_args_str}", level="ERROR")
except Exception as e_tool:
tool_output = f"Error executing tool {tool_name}: {e_tool}"
error_handler_module.display_message(f"Error executing tool {tool_name}: {e_tool}", level="ERROR")
tool_call_responses_for_next_llm_call.append({
"tool_call_id": tool_call_id,
"role": "tool",
"name": tool_name,
"content": str(tool_output) # Ensure content is string
})
# Add the assistant's message (that included the tool call request) to history
self.conversation_history.append(assistant_message_for_history)
# Add all tool responses to history
for resp in tool_call_responses_for_next_llm_call:
self.conversation_history.append(resp)
# Make a follow-up call to LLM with tool responses
# print("Sending tool responses back to LLM...") # Debug
# The conversation history already contains:
# ..., user_prompt, assistant_tool_call_request, tool_response_1, ...
# So we can just send the current self.conversation_history
follow_up_llm_response = await self._send_message_to_llm(self.conversation_history, user_query, tools=self.litellm_tools)
# Process this new response (which should be the final text from LLM after tools)
# This recursive call is safe as long as LLM doesn't loop infinitely on tool calls
assistant_response_content, _ = await self._process_llm_response(follow_up_llm_response)
# The _process_llm_response will handle adding the final assistant message to history.
else: # No tool calls, just a direct text response
if message.content: # Ensure there's content
assistant_response_content = message.content
# Add assistant's direct response to history
self.conversation_history.append({"role": "assistant", "content": assistant_response_content})
else: # Should not happen if the first check passed, but as a safeguard
error_handler_module.display_message("LLM response message has no content.", level="ERROR")
assistant_response_content = "Error: LLM response message has no content." # Keep for history
self.conversation_history.append({"role": "assistant", "content": assistant_response_content})
return assistant_response_content, tool_calls_made
async def _handle_initialization(self, connection_string: str, db_name_id: str):
    """Run the full initialization flow for a database and update client state.

    Existing database state is cleared first. On success the connection
    details, schema data and any stored insights for the database are loaded
    onto the client and the status message is displayed. On failure the
    state is cleared again and the message is shown as an error.

    Args:
        connection_string: Connection string handed to the initialization module.
        db_name_id: Identifier for the database, also used to look up stored insights.
    """
    # Start from a clean slate for the new database.
    await self._cleanup_database_session(full_cleanup=True)

    succeeded, status_message, schema_payload = await initialization_module.perform_initialization(
        self, connection_string, db_name_id
    )

    if not succeeded:
        await self._cleanup_database_session(full_cleanup=True)
        error_handler_module.display_message(status_message, level="ERROR")
        return

    self.is_initialized = True
    self.current_db_connection_string = connection_string
    self.current_db_name_identifier = db_name_id
    self.db_schema_and_sample_data = schema_payload
    # Load cumulative insights for this specific DB if they exist.
    self.cumulative_insights_content = memory_module.read_insights_file(self.current_db_name_identifier)
    # Ensure clean query/revision state for the new DB.
    self._reset_feedback_cycle_state()
    self._reset_revision_cycle_state()
    error_handler_module.display_response(status_message)
async def dispatch_command(self, query: str):
if not self.session:
error_handler_module.display_message("Client not fully initialized (MCP session missing).", level="ERROR")
return
# Flexible command parsing: command can be followed by ':' or space, then argument
command_match = re.match(r"/(\w+)(?:\s*:?\s*)(.*)", query, re.IGNORECASE)
base_command_lower = ""
argument_text = ""
if command_match:
base_command_lower = command_match.group(1).lower()
argument_text = command_match.group(2).strip()
elif query.startswith("/"): # Handles commands like /approved with no args
base_command_lower = query[1:].lower()
# If not a command starting with "/", it will be handled by implicit init or navigation query later
# Command: /change_model
if base_command_lower == "change_model":
# The new handle_change_model_interactive saves the config file itself.
# We need to reload the config if it returns success.
success, message = await model_change_module.handle_change_model_interactive(self.app_config)
if success:
print("Reloading configuration with new profile...")
# Reload the app_config from the updated file
self.app_config = config_manager.get_app_config()
# Update client's internal LLM settings from the newly loaded app_config
self.model_name = self.app_config.get("model_id")
self.llm_api_key = self.app_config.get("api_key")
new_provider_name = self.app_config.get("llm_provider", "Unknown").capitalize()
# This logic correctly re-evaluates the provider display name
if self.model_name.startswith("gemini/"): self.model_provider = "Google AI (Gemini)"
elif self.model_name.startswith("gpt-") or self.model_name.startswith("openai/"): self.model_provider = "OpenAI"
elif self.model_name.startswith("bedrock/"):
self.model_provider = "AWS Bedrock"
if "claude" in self.model_name.lower(): self.model_provider += " (Anthropic Claude)"
elif self.model_name.startswith("anthropic/") or self.model_name.startswith("claude"): self.model_provider = "Anthropic"
elif self.model_name.startswith("ollama/"): self.model_provider = "Ollama (Local)"
else: self.model_provider = new_provider_name
# Reset conversation history as the context might be irrelevant for a new model/provider
self.conversation_history = []
if self.system_instruction_content:
self.conversation_history.append({"role": "system", "content": self.system_instruction_content})
final_message = (
f"{message}\n"
f"Client updated to use new profile. Active model: {self.model_name} from {self.model_provider}.\n"
f"Conversation history has been reset."
)
error_handler_module.display_response(final_message)
else:
# Display cancellation or error message from the module
error_handler_module.display_response(message)
return
# Command: /change_profile
if base_command_lower == "change_profile":
config = config_manager.load_config()
profiles = config.get("llm_profiles", {})
if len(profiles) <= 1:
error_handler_module.display_response("Only one profile exists. Add more profiles to use this command.")
return
print("Please choose a profile to switch to:")
profile_aliases = list(profiles.keys())
for i, alias in enumerate(profile_aliases):
print(f"{i+1}. {alias}")
choice = -1
while choice < 1 or choice > len(profile_aliases):
try:
raw_choice = input(f"Enter your choice (1-{len(profile_aliases)}): ")
choice = int(raw_choice)
except ValueError:
print("Invalid input. Please enter a number.")
chosen_alias = profile_aliases[choice - 1]
config["active_llm_profile_alias"] = chosen_alias
config_manager.save_config(config)
print(f"Reloading configuration with new profile '{chosen_alias}'...")
# Manually rebuild app_config to avoid the double prompt from get_app_config()
active_profile = config.get("llm_profiles", {}).get(chosen_alias)
if not active_profile:
error_handler_module.display_message(f"Error: Could not find the selected profile '{chosen_alias}' after saving.", level="ERROR")
return
app_config = {
"memory_base_dir": config.get("memory_base_dir"),
"approved_queries_dir": config.get("approved_queries_dir"),
"nl2sql_vector_store_base_dir": config.get("nl2sql_vector_store_base_dir"),
"llm_provider": active_profile.get("provider"),
"model_id": f"{active_profile.get('provider')}/{active_profile.get('model_id')}",
"active_database_alias": config.get("active_database_alias"),
"active_database_connection_string": config.get("database_connections", {}).get(config.get("active_database_alias"))
}
credentials = active_profile.get("credentials", {})
app_config.update(credentials)
self.app_config = app_config
self.model_name = self.app_config.get("model_id")
self.llm_api_key = self.app_config.get("api_key")
new_provider_name = self.app_config.get("llm_provider", "Unknown").capitalize()
if self.model_name.startswith("gemini/"): self.model_provider = "Google AI (Gemini)"
elif self.model_name.startswith("gpt-") or self.model_name.startswith("openai/"): self.model_provider = "OpenAI"
elif self.model_name.startswith("bedrock/"):
self.model_provider = "AWS Bedrock"
if "claude" in self.model_name.lower(): self.model_provider += " (Anthropic Claude)"
elif self.model_name.startswith("anthropic/") or self.model_name.startswith("claude"): self.model_provider = "Anthropic"
elif self.model_name.startswith("ollama/"): self.model_provider = "Ollama (Local)"
else: self.model_provider = new_provider_name
self.conversation_history = []
if self.system_instruction_content:
self.conversation_history.append({"role": "system", "content": self.system_instruction_content})
final_message = (
f"Active profile switched to '{chosen_alias}'.\n"
f"Client updated to use new profile. Active model: {self.model_name} from {self.model_provider}.\n"
f"Conversation history has been reset."
)
error_handler_module.display_response(final_message)
return
# Handle initialization attempts first
elif base_command_lower == "change_database":
raw_conn_str = argument_text
if raw_conn_str.startswith('"') and raw_conn_str.endswith('"'): # Handle quoted string
raw_conn_str = raw_conn_str[1:-1]
parsed_conn_str, parsed_db_name_id = self._extract_connection_string_and_db_name(raw_conn_str)
if not parsed_conn_str:
error_handler_module.display_message("Invalid connection string format provided with /change_database. Expected: postgresql://user:pass@host:port/dbname", level="ERROR")
return
await self._handle_initialization(parsed_conn_str, parsed_db_name_id)
# Check for implicit initialization if not already initialized
elif not self.is_initialized:
parsed_conn_str, parsed_db_name_id = self._extract_connection_string_and_db_name(query) # Use original query for implicit check
if parsed_conn_str:
await self._handle_initialization(parsed_conn_str, parsed_db_name_id)
return # Exit after handling implicit initialization
else:
# If not initialized and not an attempt to initialize (via /change_database or raw string),
# prompt specifically for connection.
error_handler_module.display_message("Database not initialized. Please provide a connection string (e.g., postgresql://user:pass@host:port/dbname) or use '/change_database {connection_string}' to connect.", level="ERROR")
return # Exit after showing the error
# If we reach here, self.is_initialized must be True.
# Perform a redundant check just in case, though the logic above should ensure it.
elif not self.is_initialized or not self.current_db_name_identifier:
# This state should ideally not be reached if the above logic is correct.
error_handler_module.display_message("Critical Error: Database initialization state is inconsistent. Please try /change_database again or provide a connection string.", level="ERROR")
return
# Command: /reload_scope
elif base_command_lower == "reload_scope":
if not self.is_initialized or not self.current_db_name_identifier:
error_handler_module.display_message("Please connect to a database first with /change_database.", level="ERROR")
return
print("Reloading database scope...")
filtered_schema, message = initialization_module.load_and_filter_schema(self.current_db_name_identifier)
if filtered_schema is not None:
self.db_schema_and_sample_data = filtered_schema
error_handler_module.display_response(message)
else:
error_handler_module.display_message(f"Failed to reload scope: {message}", level="ERROR")
return
# Command: /generate_sql
elif base_command_lower == "generate_sql":
nl_question = argument_text
if not nl_question:
error_handler_module.display_message("Please provide a natural language question after /generate_sql.", level="ERROR")
return
self._reset_feedback_cycle_state()
self._reset_revision_cycle_state()
self.current_natural_language_question = nl_question
print("Generating SQL, please wait...")
sql_gen_result_dict = await sql_generation_module.generate_sql_query(
self, nl_question, self.db_schema_and_sample_data, self.cumulative_insights_content,
row_limit_for_preview=1 # Ensure 1 row for preview from sql_generation_module
)
if sql_gen_result_dict.get("sql_query"):
self.current_feedback_report_content = FeedbackReportContentModel(
natural_language_question=nl_question,
initial_sql_query=sql_gen_result_dict["sql_query"],
initial_explanation=sql_gen_result_dict.get("explanation", "N/A"),
final_corrected_sql_query=sql_gen_result_dict["sql_query"],
final_explanation=sql_gen_result_dict.get("explanation", "N/A")
)
else:
self.current_feedback_report_content = None
base_message_to_user = sql_gen_result_dict.get("message_to_user", "Error: No message from SQL generation.")
# Append execution result or error to the message_to_user
exec_result = sql_gen_result_dict.get("execution_result")
exec_error = sql_gen_result_dict.get("execution_error")
# If there's an execution error from the generation loop (like ContextWindowExceeded),
# it should be displayed directly, and we shouldn't proceed with other formatting.
if exec_error and not sql_gen_result_dict.get("sql_query"):
error_handler_module.display_message(exec_error, level="ERROR")
return
if exec_error:
base_message_to_user += f"\nExecution Error: {exec_error}\n"
elif exec_result is not None:
preview_str = ""
if isinstance(exec_result, dict) and exec_result.get("status") == "success":
data = exec_result.get("data")
if data and isinstance(data, list) and len(data) > 0:
preview_str = str(data[0])
elif 'message' in exec_result:
preview_str = exec_result['message']
else:
preview_str = "Query executed successfully, but no rows were returned."
else:
preview_str = str(exec_result)
if len(preview_str) > 200: # Truncate if too long
preview_str = preview_str[:197] + "..."
base_message_to_user += f"\nExecution successful. Result preview (1 row): {preview_str}\n"
# --- Augment message with display few-shot examples ---
# Ensure base_message_to_user is a string before appending.
if isinstance(base_message_to_user, str):
if self.current_db_name_identifier and self.current_natural_language_question:
try:
# Using hardcoded display threshold from vector_store_module
display_threshold_from_module = vector_store_module.LITELLM_DISPLAY_THRESHOLD
display_examples = vector_store_module.search_similar_nlqs(
db_name_identifier=self.current_db_name_identifier,
query_nlq=self.current_natural_language_question,
k=3, # Show up to 3 examples
threshold=display_threshold_from_module
)
if display_examples:
display_message_parts = ["\n\n--- Relevant Approved Examples (Similarity >= " f"{display_threshold_from_module}" ") ---"]
for i, ex in enumerate(display_examples):
display_message_parts.append(f"Example {i+1} (Similarity: {ex['similarity_score']:.2f}):")
display_message_parts.append(f" Q: \"{ex['nlq']}\"")
display_message_parts.append(f" A: ```sql\n{ex['sql']}\n```")
base_message_to_user += "\n" + "\n".join(display_message_parts)
else:
# This is not an error, just informational.
base_message_to_user += f"\n\n--- No similar approved examples found (Similarity >= {display_threshold_from_module}) ---"
except Exception as e_display_rag:
# Log the exception but don't crash the main flow.
error_handler_module.handle_exception(e_display_rag, self.current_natural_language_question, {"context": "Display RAG examples"})
base_message_to_user += "\n\n--- Could not retrieve similar examples due to an error. ---"
# --- End Augment ---
error_handler_module.display_response(base_message_to_user)
# Command: /feedback
elif base_command_lower == "feedback":
user_feedback_text = argument_text
if not user_feedback_text:
error_handler_module.display_message("Please provide your feedback text after /feedback.", level="ERROR")
return
print("Processing feedback, please wait...")
if self.is_in_revision_mode and self.current_revision_report_content and self.current_revision_report_content.final_revised_sql_query:
# Apply feedback to the current revision
current_sql = self.current_revision_report_content.final_revised_sql_query
current_explanation = self.current_revision_report_content.final_revised_explanation or "N/A"
# Prompt for correcting SQL based on feedback (within revision context)
feedback_prompt_for_revision = (
f"You are an expert PostgreSQL SQL assistant. A user is providing feedback on a previously revised SQL query.\n"
f"CURRENT REVISED SQL QUERY:\n```sql\n{current_sql}\n```\n"
f"ITS EXPLANATION:\n{current_explanation}\n\n"
f"USER FEEDBACK: \"{user_feedback_text}\"\n\n"
f"Based on this feedback, please provide a corrected SQL query (must start with SELECT) and a brief explanation for the correction.\n"
f"Respond ONLY with a single JSON object matching this structure: "
f"{{ \"sql_query\": \"<Your corrected SELECT SQL query>\", \"explanation\": \"<Your explanation for the correction>\" }}\n"
)
MAX_FEEDBACK_RETRIES_IN_REVISION = 1
corrected_sql_from_feedback = None
corrected_explanation_from_feedback = None
for attempt in range(MAX_FEEDBACK_RETRIES_IN_REVISION + 1):
try:
messages_for_llm = self.conversation_history + [{"role": "user", "content": feedback_prompt_for_revision}]
response_format = None
if self.model_name.startswith("gpt-") or self.model_name.startswith("openai/"):
response_format = {"type": "json_object"}
llm_response_obj = await self._send_message_to_llm(messages=messages_for_llm, user_query=user_feedback_text, response_format=response_format)
response_text, _ = await self._process_llm_response(llm_response_obj)
if response_text.startswith("```json"): response_text = response_text[7:]
if response_text.endswith("```"): response_text = response_text[:-3]
parsed_correction = SQLGenerationResponse.model_validate_json(response_text.strip())
if not parsed_correction.sql_query or not parsed_correction.sql_query.strip().upper().startswith("SELECT"):
raise ValueError("Corrected SQL from feedback must start with SELECT.")
corrected_sql_from_feedback = parsed_correction.sql_query
corrected_sql_from_feedback = parsed_correction.sql_query
corrected_explanation_from_feedback = parsed_correction.explanation or "N/A"
# Log feedback for potential report generation on approve_revision
self.feedback_used_in_current_revision_cycle = True
self.feedback_log_in_revision.append({
"user_feedback_text": user_feedback_text,
"sql_before_feedback": current_sql, # SQL before this feedback
"explanation_before_feedback": current_explanation, # Explanation for current_sql
"corrected_sql_attempt": corrected_sql_from_feedback, # SQL after this feedback
"corrected_explanation": corrected_explanation_from_feedback # Explanation for new SQL
})
break
except (ValidationError, json.JSONDecodeError, ValueError) as e:
if attempt == MAX_FEEDBACK_RETRIES_IN_REVISION:
error_handler_module.display_message(f"Error processing feedback on revised query: {e}. Please try rephrasing your feedback.", level="ERROR")
return
feedback_prompt_for_revision = f"Your previous attempt to correct the SQL based on feedback was invalid (Error: {e}). Please try again, ensuring the JSON output has 'sql_query' (starting with SELECT) and 'explanation'."
except Exception as e_gen:
if attempt == MAX_FEEDBACK_RETRIES_IN_REVISION:
error_handler_module.display_message(f"Unexpected error processing feedback on revised query: {e_gen}.", level="ERROR")
return
feedback_prompt_for_revision = "An unexpected error occurred. Please try to regenerate the corrected SQL and explanation based on the feedback."
if corrected_sql_from_feedback and self.current_revision_report_content:
new_iteration = RevisionIteration(
user_revision_prompt=f"Feedback: {user_feedback_text}",
revised_sql_attempt=corrected_sql_from_feedback,
revised_explanation=corrected_explanation_from_feedback
)
self.current_revision_report_content.revision_iterations.append(new_iteration)
self.current_revision_report_content.final_revised_sql_query = corrected_sql_from_feedback
self.current_revision_report_content.final_revised_explanation = corrected_explanation_from_feedback
exec_result, exec_error = None, None
try:
exec_obj = await self.session.call_tool("execute_postgres_query", {"query": corrected_sql_from_feedback, "row_limit": 1})
raw_output = self._extract_mcp_tool_call_output(exec_obj, "execute_postgres_query")
if isinstance(raw_output, str) and "Error:" in raw_output:
exec_error = raw_output
else:
exec_result = raw_output
except Exception as e_exec:
exec_error = str(e_exec)
user_msg = f"Feedback applied to revised query. New SQL attempt:\n```sql\n{corrected_sql_from_feedback}\n```\n"
user_msg += f"Explanation:\n{corrected_explanation_from_feedback}\n"
if exec_error:
user_msg += f"\nExecution Error for new SQL: {exec_error}\n"
elif exec_result is not None:
preview_str = ""
if isinstance(exec_result, list) and len(exec_result) == 1 and isinstance(exec_result[0], dict):
single_row_dict = exec_result[0]
preview_str = str(single_row_dict)
elif isinstance(exec_result, str) and exec_result.endswith(".md"):
preview_str = f"Query result saved to {os.path.basename(exec_result)}"
else:
preview_str = str(exec_result)
if len(preview_str) > 200:
preview_str = preview_str[:197] + "..."
user_msg += f"\nExecution of new SQL successful. Result preview (1 row): {preview_str}\n"
user_msg += "Use `/revise Your new prompt`, more `/feedback`, or `/approve_revision`."
error_handler_module.display_response(user_msg)
else:
error_handler_module.display_message("Failed to apply feedback to the revised query.", level="ERROR")
elif self.current_feedback_report_content and self.current_natural_language_question:
self._reset_revision_cycle_state()
current_report_json = self.current_feedback_report_content.model_dump_json(indent=2)
feedback_model_schema_dict = FeedbackReportContentModel.model_json_schema()
feedback_model_schema_str = json.dumps(feedback_model_schema_dict, indent=2)
feedback_prompt = (
f"You are refining a SQL query based on user feedback and updating a detailed report.\n"
f"The user's original question was: \"{self.current_natural_language_question}\"\n"
f"The current state of the feedback report (JSON format) is:\n```json\n{current_report_json}\n```\n"
f"The user has provided new feedback: \"{user_feedback_text}\"\n\n"
f"Your tasks:\n"
f"1. Generate a new `corrected_sql_attempt` and `corrected_explanation` based on this latest feedback and the *previous* `final_corrected_sql_query` from the report.\n"
f"2. Create a new `FeedbackIteration` object containing this `user_feedback_text`, your new `corrected_sql_attempt`, and `corrected_explanation`.\n"
f"3. Append this new `FeedbackIteration` to the `feedback_iterations` list in the report.\n"
f"4. Update the report's `final_corrected_sql_query` and `final_explanation` to your latest attempt.\n"
f"5. Re-evaluate and update the LLM analysis sections: `why_initial_query_was_wrong_or_suboptimal` (if applicable, comparing to initial), `why_final_query_works_or_is_improved` (explaining your latest correction), `database_insights_learned_from_this_query`, and `sql_lessons_learned_from_this_query` based on the *entire* history including this new iteration.\n\n"
f"Respond ONLY with the complete, updated JSON object for the `FeedbackReportContentModel`, conforming to this schema:\n"
f"```json\n{feedback_model_schema_str}\n```\n"
f"Ensure the `corrected_sql_attempt` and `final_corrected_sql_query` start with SELECT."
)
MAX_FEEDBACK_RETRIES = 1
for attempt in range(MAX_FEEDBACK_RETRIES + 1):
try:
messages_for_llm = self.conversation_history + [{"role": "user", "content": feedback_prompt}]
response_format = None
if self.model_name.startswith("gpt-") or self.model_name.startswith("openai/"):
response_format = {"type": "json_object"}
llm_response_obj = await self._send_message_to_llm(
messages=messages_for_llm,
user_query=user_feedback_text,
response_format=response_format
)
response_text, _ = await self._process_llm_response(llm_response_obj)
if response_text.startswith("```json"): response_text = response_text[7:]
if response_text.endswith("```"): response_text = response_text[:-3]
updated_report_model = FeedbackReportContentModel.model_validate_json(response_text.strip())
if not updated_report_model.final_corrected_sql_query or \
not updated_report_model.final_corrected_sql_query.strip().upper().startswith("SELECT"):
raise ValueError("Corrected SQL in feedback report must start with SELECT.")
self.current_feedback_report_content = updated_report_model
exec_result, exec_error = None, None
try:
exec_obj = await self.session.call_tool("execute_postgres_query", {"query": updated_report_model.final_corrected_sql_query, "row_limit": 1})
raw_output = self._extract_mcp_tool_call_output(exec_obj, "execute_postgres_query")
if isinstance(raw_output, str) and "Error:" in raw_output:
exec_error = raw_output
else:
exec_result = raw_output
except Exception as e_exec:
exec_error = str(e_exec)
user_msg = f"Feedback processed. New SQL attempt:\n```sql\n{updated_report_model.final_corrected_sql_query}\n```\n"
user_msg += f"Explanation:\n{updated_report_model.final_explanation}\n"
if exec_error:
user_msg += f"\nExecution Error for new SQL: {exec_error}\n"
elif exec_result is not None:
preview_str = ""
if isinstance(exec_result, list) and len(exec_result) == 1 and isinstance(exec_result[0], dict):
single_row_dict = exec_result[0]
preview_str = str(single_row_dict)
elif isinstance(exec_result, str) and exec_result.endswith(".md"):
preview_str = f"Query result saved to {os.path.basename(exec_result)}"
else:
preview_str = str(exec_result)
if len(preview_str) > 200:
preview_str = preview_str[:197] + "..."
user_msg += f"\nExecution of new SQL successful. Result preview (1 row): {preview_str}\n"
user_msg += "Provide more /feedback or use /approved to save."
error_handler_module.display_response(user_msg)
return
except (ValidationError, json.JSONDecodeError, ValueError) as e:
if attempt == MAX_FEEDBACK_RETRIES:
error_handler_module.display_message(f"I'm having trouble processing your feedback. Error: {e}. Could you provide your feedback again, perhaps with more specific details?", level="ERROR")
return
feedback_prompt = f"Your previous response for updating the feedback report was invalid. Error: {e}. Please try again."
except Exception as e_gen:
if attempt == MAX_FEEDBACK_RETRIES:
error_handler_module.display_message(f"I encountered an issue while processing your feedback: {e_gen}. Could you rephrase your feedback?", level="ERROR")
return
feedback_prompt = "An unexpected error occurred. Please try to regenerate the updated feedback report JSON."
else:
error_handler_module.display_message("No SQL query generated yet for feedback. Use /generate_sql first.", level="ERROR")
# Command: /approved
elif base_command_lower == "approved":
if not self.current_feedback_report_content:
error_handler_module.display_message("No feedback report to approve. Use /generate_sql first.", level="ERROR")
return
print("Processing approval and updating memory, please wait...") # User-facing wait message
try:
# 1. Save Feedback Markdown
feedback_filepath = memory_module.save_feedback_markdown(
self.current_feedback_report_content,
self.current_db_name_identifier
)
# print(f"Feedback report saved to: {feedback_filepath}") # User sees final message
saved_feedback_md_content = memory_module.read_feedback_file(feedback_filepath)
if not saved_feedback_md_content:
error_handler_module.display_message("Could not read back saved feedback file for insights processing.", level="ERROR")
return
insights_success = await insights_module.generate_and_update_insights(
self,
saved_feedback_md_content,
self.current_db_name_identifier
)
if insights_success:
# print("Insights successfully generated/updated.") # User sees final message
self.cumulative_insights_content = memory_module.read_insights_file(self.current_db_name_identifier)
else:
error_handler_module.display_message("Failed to generate or update insights from this feedback.", level="WARNING")
if self.current_natural_language_question and self.current_feedback_report_content.final_corrected_sql_query:
try:
memory_module.save_nl2sql_pair(
self.current_db_name_identifier,
self.current_natural_language_question,
self.current_feedback_report_content.final_corrected_sql_query
)
# print("NLQ-SQL pair saved.") # User sees final message
except Exception as e_nl2sql:
error_handler_module.display_message(f"Failed to save NLQ-SQL pair: {e_nl2sql}", level="WARNING")
else:
error_handler_module.display_message("Could not save NLQ-SQL pair due to missing NLQ or final SQL query in the report.", level="WARNING")
final_user_message = f"Approved. Feedback report, insights, and NLQ-SQL pair for '{self.current_db_name_identifier}' saved."
# Remove technical error message about insights update failure
self._reset_feedback_cycle_state()
self._reset_revision_cycle_state() # Also reset revision state on approval
error_handler_module.display_response(f"{final_user_message}\nYou can start a new query with /generate_sql.")
except Exception as e:
error_handler_module.display_message(f"Error during approval process: {e}", level="ERROR")
# Command: /revise
elif base_command_lower == "revise":
revision_prompt = argument_text
if not revision_prompt:
error_handler_module.display_message("Please provide a revision prompt after /revise.", level="ERROR")
return
sql_to_start_revision_with = None
if self.is_in_revision_mode and self.current_revision_report_content and self.current_revision_report_content.final_revised_sql_query:
sql_to_start_revision_with = self.current_revision_report_content.final_revised_sql_query
elif self.current_feedback_report_content and self.current_feedback_report_content.final_corrected_sql_query:
sql_to_start_revision_with = self.current_feedback_report_content.final_corrected_sql_query
if not sql_to_start_revision_with:
error_handler_module.display_message("No SQL query available to revise. Use /generate_sql first, or ensure a previous revision/feedback cycle completed with a query.", level="ERROR")
return
if not self.is_in_revision_mode:
self.is_in_revision_mode = True
# self.current_feedback_report_content = None # Decided against this to allow revising feedback results
self.current_revision_report_content = RevisionReportContentModel(
initial_sql_for_revision=sql_to_start_revision_with
)
print(f"Revising SQL based on: \"{revision_prompt}\", please wait...")
revision_history_for_llm_context = []
if self.current_revision_report_content: