diff --git a/common/chunkers/character_chunker.py b/common/chunkers/character_chunker.py index 165c9cb..6d4138a 100644 --- a/common/chunkers/character_chunker.py +++ b/common/chunkers/character_chunker.py @@ -1,16 +1,16 @@ from common.chunkers.base_chunker import BaseChunker +_DEFAULT_FALLBACK_SIZE = 4096 + class CharacterChunker(BaseChunker): - def __init__(self, chunk_size=1024, overlap_size=0): - if chunk_size <= overlap_size: - raise ValueError("Chunk size must be larger than overlap size") - self.chunk_size = chunk_size + def __init__(self, chunk_size=0, overlap_size=0): + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE self.overlap_size = overlap_size def chunk(self, input_string): - if self.chunk_size <= 0: - return [] + if self.chunk_size <= self.overlap_size: + raise ValueError("Chunk size must be larger than overlap size") chunks = [] i = 0 diff --git a/common/chunkers/html_chunker.py b/common/chunkers/html_chunker.py index e598605..326dff8 100644 --- a/common/chunkers/html_chunker.py +++ b/common/chunkers/html_chunker.py @@ -15,7 +15,12 @@ from typing import Optional, List, Tuple import re from common.chunkers.base_chunker import BaseChunker +from common.chunkers.separators import TEXT_SEPARATORS from langchain_text_splitters import HTMLSectionSplitter +from langchain.text_splitter import RecursiveCharacterTextSplitter + + +_DEFAULT_FALLBACK_SIZE = 4096 class HTMLChunker(BaseChunker): @@ -25,12 +30,20 @@ class HTMLChunker(BaseChunker): - Automatically detects which headers (h1-h6) are present in the HTML - Uses only the headers that exist in the document for optimal chunking - If custom headers are provided, uses those instead of auto-detection + - Supports chunk_size / chunk_overlap: when chunk_size > 0, oversized + header-based chunks are further split with RecursiveCharacterTextSplitter + - When chunk_size is 0 (default), a fallback of 4096 is used so that + headerless HTML documents are still split into reasonable chunks """ def __init__( self, - headers: Optional[List[Tuple[str, str]]] = None # e.g. [("h1", "Header 1"), ("h2", "Header 2")] + chunk_size: int = 0, + chunk_overlap: int = 0, + headers: Optional[List[Tuple[str, str]]] = None, ): + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE + self.chunk_overlap = chunk_overlap self.headers = headers def _detect_headers(self, html_content: str) -> List[Tuple[str, str]]: @@ -77,8 +90,23 @@ def chunk(self, input_string: str) -> List[str]: splitter = HTMLSectionSplitter(headers_to_split_on=headers_to_use) docs = splitter.split_text(input_string) - # Extract text content from Document objects - return [doc.page_content for doc in docs] + initial_chunks = [doc.page_content for doc in docs] + + if any(len(chunk) > self.chunk_size for chunk in initial_chunks): + recursive_splitter = RecursiveCharacterTextSplitter( + separators=TEXT_SEPARATORS, + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + ) + final_chunks = [] + for chunk in initial_chunks: + if len(chunk) > self.chunk_size: + final_chunks.extend(recursive_splitter.split_text(chunk)) + else: + final_chunks.append(chunk) + return final_chunks + + return initial_chunks def __call__(self, input_string: str) -> List[str]: return self.chunk(input_string) diff --git a/common/chunkers/markdown_chunker.py b/common/chunkers/markdown_chunker.py index aabc0fc..2d4c4ce 100644 --- a/common/chunkers/markdown_chunker.py +++ b/common/chunkers/markdown_chunker.py @@ -17,6 +17,11 @@ from langchain_text_splitters.markdown import ExperimentalMarkdownSyntaxTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter +# When chunk_size is not configured, cap any heading-section that exceeds this +# so that form-based PDFs (tables/bold but no # headings) are not left as a +# single multi-thousand-character chunk. +_DEFAULT_FALLBACK_SIZE = 4096 + class MarkdownChunker(BaseChunker): @@ -25,31 +30,33 @@ def __init__( chunk_size: int = 0, chunk_overlap: int = 0 ): - self.chunk_size = chunk_size + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE self.chunk_overlap = chunk_overlap def chunk(self, input_string): md_splitter = ExperimentalMarkdownSyntaxTextSplitter() + # ExperimentalMarkdownSyntaxTextSplitter splits on # headings only. + # Documents without headings (e.g. form PDFs with tables/bold but no #) + # are returned as a single section, so a recursive fallback is always + # applied when any section exceeds the configured (or default) limit. initial_chunks = [x.page_content for x in md_splitter.split_text(input_string)] - md_chunks = [] - if self.chunk_size > 0: + if any(len(chunk) > self.chunk_size for chunk in initial_chunks): recursive_splitter = RecursiveCharacterTextSplitter( separators=TEXT_SEPARATORS, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, ) - - if any(len(chunk) > self.chunk_size for chunk in initial_chunks): - for chunk in initial_chunks: - if len(chunk) > self.chunk_size: - # Split oversized chunks further - md_chunks.extend(recursive_splitter.split_text(chunk)) - else: - md_chunks.append(chunk) - - return md_chunks if md_chunks else initial_chunks + md_chunks = [] + for chunk in initial_chunks: + if len(chunk) > self.chunk_size: + md_chunks.extend(recursive_splitter.split_text(chunk)) + else: + md_chunks.append(chunk) + return md_chunks + + return initial_chunks def __call__(self, input_string): return self.chunk(input_string) diff --git a/common/chunkers/recursive_chunker.py b/common/chunkers/recursive_chunker.py index a69bdd4..4c8a324 100644 --- a/common/chunkers/recursive_chunker.py +++ b/common/chunkers/recursive_chunker.py @@ -16,10 +16,12 @@ from common.chunkers.separators import TEXT_SEPARATORS from langchain.text_splitter import RecursiveCharacterTextSplitter +_DEFAULT_FALLBACK_SIZE = 4096 + class RecursiveChunker(BaseChunker): - def __init__(self, chunk_size=1024, overlap_size=0): - self.chunk_size = chunk_size + def __init__(self, chunk_size=0, overlap_size=0): + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE self.overlap_size = overlap_size def chunk(self, input_string): diff --git a/common/config.py b/common/config.py index 18a4288..2b58581 100644 --- a/common/config.py +++ b/common/config.py @@ -259,8 +259,9 @@ def get_multimodal_service() -> LLM_Model: gsPort=db_config.get("gsPort", "14240"), restppPort=db_config.get("restppPort", "9000"), graphname=db_config.get("graphname", ""), + apiToken=db_config.get("apiToken", ""), ) - if db_config.get("getToken"): + if not db_config.get("apiToken") and db_config.get("getToken"): conn.getToken() embedding_store = TigerGraphEmbeddingStore( diff --git a/common/db/connections.py b/common/db/connections.py index 4cf21d9..fab87c3 100644 --- a/common/db/connections.py +++ b/common/db/connections.py @@ -120,6 +120,34 @@ def get_db_connection_pwd_manual( return conn def elevate_db_connection_to_token(host, username, password, graphname, async_conn: bool = False) -> TigerGraphConnectionProxy: + # If a pre-existing apiToken is provided in config, use it directly + # and skip the getToken() call to avoid conflicts. + static_token = db_config.get("apiToken", "") + + if static_token: + LogWriter.info("Using pre-configured apiToken from db_config") + if async_conn: + conn = AsyncTigerGraphConnection( + host=host, + username=username, + password=password, + graphname=graphname, + apiToken=static_token, + restppPort=db_config.get("restppPort", "9000"), + gsPort=db_config.get("gsPort", "14240"), + ) + else: + conn = TigerGraphConnection( + host=host, + username=username, + password=password, + graphname=graphname, + apiToken=static_token, + restppPort=db_config.get("restppPort", "9000"), + gsPort=db_config.get("gsPort", "14240"), + ) + return conn + conn = TigerGraphConnection( host=host, username=username, @@ -129,7 +157,7 @@ def elevate_db_connection_to_token(host, username, password, graphname, async_co gsPort=db_config.get("gsPort", "14240") ) - if db_config["getToken"]: + if db_config.get("getToken"): try: apiToken = conn.getToken()[0] except HTTPError: diff --git a/common/llm_services/base_llm.py b/common/llm_services/base_llm.py index bf159fb..1dafd3d 100644 --- a/common/llm_services/base_llm.py +++ b/common/llm_services/base_llm.py @@ -109,13 +109,19 @@ def route_response_prompt(self): prompt = """\ You are an expert at routing a user question to a vectorstore, function calls, or conversation history. Use the conversation history for questions that are similar to previous ones or that reference earlier answers or responses. -Use the vectorstore for questions on that would be best suited by text documents. +Use the vectorstore for questions that would be best suited by text documents. Use the function calls for questions that ask about structured data, or operations on structured data. Questions referring to same entities in a previous, earlier, or above answer or response should be routed to the conversation history. Keep in mind that some questions about documents such as "how many documents are there?" can be answered by function calls. The function calls can be used to answer questions about these entities: {v_types} and relationships: {e_types}. +IMPORTANT: Questions about graph database statistics or metadata MUST be routed to function calls. This includes: +- Counting vertices/nodes/edges (e.g. "how many vertices are there", "how many edges in the graph") +- Listing or describing vertex/edge types, schema, or graph structure +- Aggregations, totals, or summaries of data stored in the graph database +- Any question mentioning "graph", "graph db", "graph database", "vertices", "nodes", or "edges" in the context of statistics or counts +These are database queries, NOT document lookups — always route them to function calls. Otherwise, use vectorstore. Choose one of 'functions', 'vectorstore', or 'history' based on the question and conversation history. -Return the a JSON with a single key 'datasource' and no premable or explaination. +Return a JSON with a single key 'datasource' and no preamble or explanation. Question to route: {question} Conversation history: {conversation} Format: {format_instructions}\ diff --git a/common/prompts/aws_bedrock_claude3haiku/generate_function.txt b/common/prompts/aws_bedrock_claude3haiku/generate_function.txt index 00ae7f0..359b46c 100644 --- a/common/prompts/aws_bedrock_claude3haiku/generate_function.txt +++ b/common/prompts/aws_bedrock_claude3haiku/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/aws_bedrock_titan/generate_function.txt b/common/prompts/aws_bedrock_titan/generate_function.txt index 22ddb60..b0be05c 100644 --- a/common/prompts/aws_bedrock_titan/generate_function.txt +++ b/common/prompts/aws_bedrock_titan/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/azure_open_ai_gpt35_turbo_instruct/generate_function.txt b/common/prompts/azure_open_ai_gpt35_turbo_instruct/generate_function.txt index c42e466..e0a83d0 100644 --- a/common/prompts/azure_open_ai_gpt35_turbo_instruct/generate_function.txt +++ b/common/prompts/azure_open_ai_gpt35_turbo_instruct/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/custom/aml/generate_function.txt b/common/prompts/custom/aml/generate_function.txt index 00ae7f0..359b46c 100644 --- a/common/prompts/custom/aml/generate_function.txt +++ b/common/prompts/custom/aml/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/gcp_vertexai_palm/generate_function.txt b/common/prompts/gcp_vertexai_palm/generate_function.txt index d8fda1f..fe7d3cc 100644 --- a/common/prompts/gcp_vertexai_palm/generate_function.txt +++ b/common/prompts/gcp_vertexai_palm/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/google_gemini/generate_function.txt b/common/prompts/google_gemini/generate_function.txt index 781d63d..a7e4ee0 100644 --- a/common/prompts/google_gemini/generate_function.txt +++ b/common/prompts/google_gemini/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/llama_70b/generate_function.txt b/common/prompts/llama_70b/generate_function.txt index 781d63d..a7e4ee0 100644 --- a/common/prompts/llama_70b/generate_function.txt +++ b/common/prompts/llama_70b/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/prompts/openai_gpt4/generate_function.txt b/common/prompts/openai_gpt4/generate_function.txt index 00ae7f0..359b46c 100644 --- a/common/prompts/openai_gpt4/generate_function.txt +++ b/common/prompts/openai_gpt4/generate_function.txt @@ -1,5 +1,5 @@ Use the vertex types, edge types, and their attributes and IDs below to write the pyTigerGraph function call to answer the question using a pyTigerGraph connection. -When the question asks for "How many", make sure to always select a function that contains "Count" in the description/function call. Make sure never to generate a function that is not listed below. +When the question asks for "How many", counts, totals, or statistics about vertices/nodes/edges in the graph or graph database, make sure to always select a function that contains "Count" in the description/function call. For example, questions like "how many vertices are there in the graph" or "how many vertices are there in the graph db" should use getVertexCount or getEdgeCount. Make sure never to generate a function that is not listed below. When certain entities are mapped to vertex attributes, may consider to generate a WHERE clause. If a WHERE clause is generated, please follow the instruction with proper quoting. To construct a WHERE clause string. Ensure that string attribute values are properly quoted. For example, if the generated function contains "('Person', where='name=William Torres')", Expected Output: "('Person', where='name="William Torres"')", This rule applies to all types of attributes. e.g., name, email, address and so on. diff --git a/common/utils/text_extractors.py b/common/utils/text_extractors.py index e2bd6df..2b3e78d 100644 --- a/common/utils/text_extractors.py +++ b/common/utils/text_extractors.py @@ -5,10 +5,10 @@ import os import json import logging -import uuid import base64 import io import re +import tempfile import threading from pathlib import Path import shutil @@ -21,7 +21,54 @@ _pymupdf4llm_lock = threading.Lock() # regex for markdown images: ![alt](path) -_md_pattern = re.compile(r'!\[([^\]]*)\]\(([^)\s]+)\)') +# [^)]+ (not [^)\s]+) so that paths containing spaces are captured correctly. +# pymupdf4llm can generate image filenames with spaces; the narrower \s exclusion +# caused extract_images() to silently return [] for those files, deleting the temp +# folder and leaving broken references in the markdown. +_md_pattern = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)') + +# Matches a ColN placeholder header cell produced by pymupdf4llm when it +# cannot detect a column header from the PDF structure (common in form PDFs). +_coln_pattern = re.compile(r'\bCol\d+\b') + + +def _clean_pdf_markdown(markdown: str) -> str: + """Apply post-processing to markdown produced by pymupdf4llm for form PDFs. + + Two specific artefacts are fixed: + + 1. **Duplicate table rows** — complex form PDFs (e.g. IRS forms) often have + overlapping text layers (a rendered background layer plus a searchable text + layer). pymupdf4llm can emit the same row twice: once from the background + layer (no formatting, missing spaces) and once from the text layer (bold, + correct spacing). The duplicate row that appears immediately after the + original is removed; when the content is identical after stripping bold + markers, the richer (longer) version is kept. + + 2. **ColN placeholder headers** — pymupdf4llm uses "Col1", "Col2", … when it + cannot derive a header from the PDF's column structure. These are replaced + with empty strings so the table is still valid markdown but does not expose + internal artefacts to downstream consumers. + """ + # --- Pass 1: remove ColN placeholders --- + markdown = _coln_pattern.sub('', markdown) + + # --- Pass 2: deduplicate consecutive table rows --- + lines = markdown.splitlines() + cleaned: list[str] = [] + for line in lines: + if cleaned and line.startswith('|') and cleaned[-1].startswith('|'): + prev = cleaned[-1] + norm_cur = re.sub(r'\*+', '', line).strip() + norm_prev = re.sub(r'\*+', '', prev).strip() + if norm_cur == norm_prev: + if len(line) > len(prev): + cleaned[-1] = line + continue + cleaned.append(line) + + return '\n'.join(cleaned) + def extract_images(md_text): """ @@ -291,15 +338,44 @@ def extract_text_from_file_with_images_as_docs(file_path, graphname=None): "position": 0 }] +def _sanitize_image_filenames(image_folder, markdown_content): + """Rename image files that contain spaces (replace with underscores). + + pymupdf4llm can produce filenames with spaces. Renaming them avoids + downstream issues with path parsing and markdown rendering. + + Returns the updated markdown_content with paths adjusted to match the + renamed files. + """ + if not image_folder.exists(): + return markdown_content + + for img_file in image_folder.iterdir(): + if not img_file.is_file() or ' ' not in img_file.name: + continue + new_name = img_file.name.replace(' ', '_') + new_path = img_file.with_name(new_name) + img_file.rename(new_path) + old_ref = str(img_file) + new_ref = str(new_path) + markdown_content = markdown_content.replace(old_ref, new_ref) + + return markdown_content + + def _extract_pdf_with_images_as_docs(file_path, base_doc_id, graphname=None): """ Extract PDF as ONE markdown document with inline image references using pymupdf4llm. Uses unique temporary folder per PDF to allow parallel processing. After processing, delete the extracted image folder. """ - # Use unique folder per PDF to allow parallel processing without conflicts - unique_folder_id = uuid.uuid4().hex[:12] - image_output_folder = Path(f"tg_temp_{unique_folder_id}") + # Use a unique ABSOLUTE temp folder per PDF. + # A relative path would resolve to whatever the process CWD happens to be at + # call time (varies across ThreadPoolExecutor threads in container deployments). + # pymupdf4llm embeds os.path.join(image_path, filename) in the markdown, so an + # absolute image_path produces absolute embedded paths that PIL can always open + # regardless of CWD. + image_output_folder = Path(tempfile.mkdtemp(prefix="tg_pdf_")) try: import pymupdf4llm @@ -346,7 +422,24 @@ def _extract_pdf_with_images_as_docs(file_path, base_doc_id, graphname=None): }] if not markdown_content or not markdown_content.strip(): - logger.warning(f"No content extracted from PDF: {file_path}") + logger.warning( + f"No text layer found in PDF: {file_path}. " + "The file may be a scanned image-only PDF — consider enabling OCR." + ) + if image_output_folder.exists(): + shutil.rmtree(image_output_folder, ignore_errors=True) + return [{ + "doc_id": base_doc_id, + "doc_type": "markdown", + "content": f"[Scanned PDF — no text layer extracted: {file_path.name}]", + "position": 0 + }] + + # Clean up artefacts common in form PDFs (duplicate rows, ColN headers) + markdown_content = _clean_pdf_markdown(markdown_content) + + # Rename image files that contain spaces to avoid path-parsing issues + markdown_content = _sanitize_image_filenames(image_output_folder, markdown_content) # Extract image references from markdown image_refs = extract_images(markdown_content) diff --git a/docker-compose.yml b/docker-compose.yml index 8be754b..b228151 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ services: depends_on: - graphrag-ecc - chat-history - - tigergraph +# - tigergraph environment: SERVER_CONFIG: "/code/configs/server_config.json" LOGLEVEL: "INFO" @@ -73,14 +73,14 @@ services: - graphrag-ui - graphrag - tigergraph: - image: tigergraph/community:4.2.1 - container_name: tigergraph - platform: linux/amd64 - ports: - - "14240:14240" - volumes: - - tigergraph_data:/home/tigergraph/tigergraph/data - -volumes: - tigergraph_data: +# tigergraph: +# image: tigergraph/community:4.2.1 +# container_name: tigergraph +# platform: linux/amd64 +# ports: +# - "14240:14240" +# volumes: +# - tigergraph_data:/home/tigergraph/tigergraph/data +# +# volumes: +# tigergraph_data: diff --git a/ecc/app/ecc_util.py b/ecc/app/ecc_util.py index 75a3f87..35bbcaa 100644 --- a/ecc/app/ecc_util.py +++ b/ecc/app/ecc_util.py @@ -28,7 +28,7 @@ def get_chunker(chunker_type: str = ""): ) elif chunker_type == "character": chunker = character_chunker.CharacterChunker( - chunk_size=chunker_config.get("chunk_size", 1024), + chunk_size=chunker_config.get("chunk_size", 4096), overlap_size=chunker_config.get("overlap_size", 0), ) elif chunker_type == "markdown": @@ -38,11 +38,13 @@ def get_chunker(chunker_type: str = ""): ) elif chunker_type == "html": chunker = html_chunker.HTMLChunker( - headers=chunker_config.get("headers", None) + chunk_size=chunker_config.get("chunk_size", 0), + chunk_overlap=chunker_config.get("overlap_size", 0), + headers=chunker_config.get("headers", None), ) elif chunker_type == "recursive": chunker = recursive_chunker.RecursiveChunker( - chunk_size=chunker_config.get("chunk_size", 1024), + chunk_size=chunker_config.get("chunk_size", 4096), overlap_size=chunker_config.get("overlap_size", 0), ) elif chunker_type == "single" or chunker_type == "image": diff --git a/ecc/app/graphrag/graph_rag.py b/ecc/app/graphrag/graph_rag.py index 49ef496..5544789 100644 --- a/ecc/app/graphrag/graph_rag.py +++ b/ecc/app/graphrag/graph_rag.py @@ -179,7 +179,8 @@ async def upsert(upsert_chan: Channel): async def load(conn: AsyncTigerGraphConnection): logger.info("Data Loading Start") dd = lambda: defaultdict(dd) # infinite default dict - batch_size = 500 + batch_size = graphrag_config.get("load_batch_size", 500) + upsert_delay = graphrag_config.get("upsert_delay", 0) # while the load q is still open or has contents while not load_q.closed() or not load_q.empty(): if load_q.closed(): @@ -227,11 +228,12 @@ async def load(conn: AsyncTigerGraphConnection): f"Upserting batch size of {size}. ({n_verts} verts | {n_edges} edges. {len(data.encode())/1000:,} kb)" ) - loading_event.clear() - if n_verts >0 or n_edges >0: + if n_verts > 0 or n_edges > 0: + loading_event.clear() await upsert_batch(conn, data) - await asyncio.sleep(5) - loading_event.set() + loading_event.set() + if upsert_delay > 0: + await asyncio.sleep(upsert_delay) else: await asyncio.sleep(1) diff --git a/graphrag/app/supportai/supportai_ingest.py b/graphrag/app/supportai/supportai_ingest.py index afe4a67..4ba69f1 100644 --- a/graphrag/app/supportai/supportai_ingest.py +++ b/graphrag/app/supportai/supportai_ingest.py @@ -53,7 +53,9 @@ def chunk_document(self, document, chunker, chunker_params): from common.chunkers.html_chunker import HTMLChunker chunker = HTMLChunker( - headers=chunker_params.get("headers", None) + chunk_size=chunker_params.get("chunk_size", 0), + chunk_overlap=chunker_params.get("overlap_size", 0), + headers=chunker_params.get("headers", None), ) elif chunker.lower() == "markdown": from common.chunkers.markdown_chunker import MarkdownChunker diff --git a/graphrag/tests/test_connections.py b/graphrag/tests/test_connections.py new file mode 100644 index 0000000..40fdbb8 --- /dev/null +++ b/graphrag/tests/test_connections.py @@ -0,0 +1,117 @@ +"""Unit tests for common.db.connections apiToken support.""" + +import sys +import unittest +from unittest.mock import patch, MagicMock + +# Mock heavy dependencies before importing the module under test +_mock_modules = {} +for mod_name in [ + "common.config", + "common.logs", + "common.logs.logwriter", + "common.logs.log", + "common.metrics", + "common.metrics.tg_proxy", + "common.metrics.prometheus_metrics", + "common.embeddings", + "common.embeddings.embedding_services", + "common.embeddings.tigergraph_embedding_store", + "common.llm_services", + "common.session", + "common.status", + "langchain", + "langchain.schema", + "langchain.schema.embeddings", + "prometheus_client", +]: + if mod_name not in sys.modules: + _mock_modules[mod_name] = MagicMock() + sys.modules[mod_name] = _mock_modules[mod_name] + +# Provide the values that connections.py reads at import time +sys.modules["common.config"].security = MagicMock() + +# Provide TigerGraphConnectionProxy +mock_proxy_cls = MagicMock() +sys.modules["common.metrics.tg_proxy"].TigerGraphConnectionProxy = mock_proxy_cls + +# Provide LogWriter +mock_logwriter = MagicMock() +sys.modules["common.logs.logwriter"].LogWriter = mock_logwriter + +from pyTigerGraph import TigerGraphConnection, AsyncTigerGraphConnection + + +MOCK_DB_CONFIG_BASE = { + "hostname": "http://test-host", + "restppPort": "9000", + "gsPort": "14240", + "getToken": False, + "default_timeout": 300, +} + + +class TestElevateDbConnectionWithApiToken(unittest.TestCase): + """Test that elevate_db_connection_to_token honours apiToken from db_config.""" + + def _import_and_patch(self, db_config_override): + """Import elevate_db_connection_to_token with a patched db_config.""" + sys.modules["common.config"].db_config = db_config_override + # Re-import to pick up fresh db_config reference + import importlib + import common.db.connections as conn_mod + importlib.reload(conn_mod) + return conn_mod.elevate_db_connection_to_token + + def test_static_api_token_used_directly(self): + cfg = {**MOCK_DB_CONFIG_BASE, "apiToken": "static_tok_123"} + elevate = self._import_and_patch(cfg) + + conn = elevate("http://test-host", "user", "pass", "TestGraph") + + self.assertEqual(conn.apiToken, "static_tok_123") + self.assertEqual(conn.host, "http://test-host") + self.assertEqual(conn.graphname, "TestGraph") + + def test_api_token_skips_get_token(self): + cfg = {**MOCK_DB_CONFIG_BASE, "apiToken": "tok", "getToken": True} + elevate = self._import_and_patch(cfg) + + with patch.object(TigerGraphConnection, "getToken") as mock_get: + conn = elevate("http://test-host", "user", "pass", "TestGraph") + mock_get.assert_not_called() + + self.assertEqual(conn.apiToken, "tok") + + def test_static_api_token_async_conn(self): + cfg = {**MOCK_DB_CONFIG_BASE, "apiToken": "async_tok"} + elevate = self._import_and_patch(cfg) + + conn = elevate("http://test-host", "user", "pass", "TestGraph", async_conn=True) + + self.assertIsInstance(conn, AsyncTigerGraphConnection) + self.assertEqual(conn.apiToken, "async_tok") + + def test_empty_api_token_falls_through(self): + cfg = {**MOCK_DB_CONFIG_BASE, "apiToken": ""} + elevate = self._import_and_patch(cfg) + + conn = elevate("http://test-host", "user", "pass", "TestGraph") + + # Empty token treated as not set — password auth + self.assertIsInstance(conn, TigerGraphConnection) + self.assertEqual(conn.username, "user") + + def test_no_api_token_key_falls_through(self): + cfg = {**MOCK_DB_CONFIG_BASE} + elevate = self._import_and_patch(cfg) + + conn = elevate("http://test-host", "user", "pass", "TestGraph") + + self.assertIsInstance(conn, TigerGraphConnection) + self.assertEqual(conn.username, "user") + + +if __name__ == "__main__": + unittest.main()