From ce341e5b16fd0434ac5ef4de230f655fcdd84a91 Mon Sep 17 00:00:00 2001
From: Claude
Date: Thu, 5 Mar 2026 09:08:23 +0000
Subject: [PATCH] Add Anthropic and Ollama LLM provider support alongside
 OpenAI

- Add multi-provider architecture: OpenAI (default), Anthropic, and Ollama
- Ollama uses OpenAI-compatible API, Anthropic uses native SDK
- New CLI flags: --provider and --api-base-url
- Config and env var support: LLM_PROVIDER, API_BASE_URL, ANTHROPIC_API_KEY
- Tiktoken fallback to cl100k_base for non-OpenAI models
- Add anthropic SDK to requirements.txt
- Update README with provider docs, usage examples, and comparison table

https://claude.ai/code/session_01RdrWXrD9rABwBjv2NRYn6u
---
 README.md               |  47 +++++++++++-
 pageindex/config.yaml   |   2 +
 pageindex/page_index.py |  20 ++++-
 pageindex/utils.py      | 165 +++++++++++++++++++++++++++++++---------
 requirements.txt        |   1 +
 run_pageindex.py        |  13 +++-
 6 files changed, 205 insertions(+), 43 deletions(-)

diff --git a/README.md b/README.md
index 7180efd5a..9e60a59ff 100644
--- a/README.md
+++ b/README.md
@@ -147,27 +147,49 @@ You can follow these steps to generate a PageIndex tree from a PDF document.
 pip3 install --upgrade -r requirements.txt
 ```
 
-### 2. Set your OpenAI API key
+### 2. Set your API key
 
-Create a `.env` file in the root directory and add your API key:
+Create a `.env` file in the root directory and add your API key for your chosen provider:
 
 ```bash
+# OpenAI (default)
 CHATGPT_API_KEY=your_openai_key_here
+
+# Anthropic (optional)
+ANTHROPIC_API_KEY=your_anthropic_key_here
+
+# Ollama: no API key needed, just have Ollama running locally
 ```
 
 ### 3. Run PageIndex on your PDF
 
+**OpenAI (default):**
 ```bash
 python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
 ```
 
+**Anthropic:**
+```bash
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \
+    --provider anthropic --model claude-sonnet-4-20250514
+```
+
+**Ollama (local models):**
+```bash
+# Make sure Ollama is running (ollama serve)
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \
+    --provider ollama --model llama3
+```
+
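+If the model is not yet available locally, pull it first with the standard Ollama CLI:
+```bash
+ollama pull llama3
+```
+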
 <details>
 <summary>Optional parameters</summary>
 
 You can customize the processing with additional optional arguments:
 
 ```
---model                  OpenAI model to use (default: gpt-4o-2024-11-20)
+--model                  Model to use (default: gpt-4o-2024-11-20)
+--provider               LLM provider: openai, anthropic, or ollama (default: openai)
+--api-base-url           Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)
 --toc-check-pages        Pages to check for table of contents (default: 20)
 --max-pages-per-node     Max pages per node (default: 10)
 --max-tokens-per-node    Max tokens per node (default: 20000)
@@ -175,6 +197,25 @@ You can customize the processing with additional optional arguments:
 --if-add-node-summary    Add node summary (yes/no, default: yes)
 --if-add-doc-description Add doc description (yes/no, default: yes)
 ```
+
+You can also set the provider via environment variables instead of CLI flags:
+```bash
+export LLM_PROVIDER=ollama                     # or "anthropic"
+export API_BASE_URL=http://localhost:11434/v1  # optional, for custom endpoints
+```
+</details>
+
+<details>
+<summary>Supported LLM Providers</summary>
+
+| Provider | Example Models | API Key Env Var | Notes |
+|----------|----------------|-----------------|-------|
+| **OpenAI** (default) | `gpt-4o-2024-11-20`, `gpt-4o-mini` | `CHATGPT_API_KEY` | Full support, recommended |
+| **Anthropic** | `claude-sonnet-4-20250514`, `claude-haiku-4-5-20251001` | `ANTHROPIC_API_KEY` | Full support |
+| **Ollama** | `llama3`, `mistral`, `qwen2.5` | _(none needed)_ | Requires Ollama running locally. Uses OpenAI-compatible API at `http://localhost:11434/v1` |
+
+**Note:** PageIndex relies on structured JSON output from the LLM. For best results, use capable models (GPT-4o, Claude Sonnet/Opus, or large Ollama models like Llama 3 70B+). Smaller local models may produce lower-quality tree structures.
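+
+The same options are exposed when calling PageIndex from Python. A minimal sketch (the keyword arguments mirror the updated `page_index()` signature in `pageindex/page_index.py`; the import path and the returned structure are illustrative assumptions):
+
+```python
+from pageindex.page_index import page_index
+
+# Build a PageIndex tree with a local Ollama model instead of OpenAI.
+tree = page_index(
+    "document.pdf",                            # path to your PDF
+    model="llama3",                            # any model served by Ollama
+    provider="ollama",                         # "openai" (default), "anthropic", or "ollama"
+    api_base_url="http://localhost:11434/v1",  # optional; this is the Ollama default
+)
+```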
+</details>

diff --git a/pageindex/config.yaml b/pageindex/config.yaml
index fd73e3a2c..d9e86b730 100644
--- a/pageindex/config.yaml
+++ b/pageindex/config.yaml
@@ -1,4 +1,6 @@
 model: "gpt-4o-2024-11-20"
+provider: "openai"  # "openai", "anthropic", or "ollama"
+api_base_url: null  # Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)
 toc_check_page_num: 20
 max_page_num_each_node: 10
 max_token_num_each_node: 20000
diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 39018c4df..2f4165304 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -1057,16 +1057,27 @@ async def tree_parser(page_list, opt, doc=None, logger=None):
 
 def page_index_main(doc, opt=None):
     logger = JsonLogger(doc)
-    
+
+    # Set provider config from opt so all downstream API calls pick it up
+    if hasattr(opt, 'provider') and opt.provider:
+        os.environ['LLM_PROVIDER'] = opt.provider
+        # Also override the module-level value already imported in utils
+        from pageindex import utils
+        utils.LLM_PROVIDER = opt.provider
+    if hasattr(opt, 'api_base_url') and opt.api_base_url:
+        os.environ['API_BASE_URL'] = opt.api_base_url
+        from pageindex import utils
+        utils.API_BASE_URL = opt.api_base_url
+
     is_valid_pdf = (
-        (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or 
+        (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or
         isinstance(doc, BytesIO)
     )
     if not is_valid_pdf:
         raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
 
     print('Parsing PDF...')
-    page_list = get_page_tokens(doc)
+    page_list = get_page_tokens(doc, model=opt.model)
 
     logger.info({'total_page_number': len(page_list)})
     logger.info({'total_token': sum([page[1] for page in page_list])})
@@ -1100,7 +1111,8 @@ async def page_index_builder():
     return asyncio.run(page_index_builder())
 
-def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
+def page_index(doc, model=None, provider=None, api_base_url=None,
+               toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
                if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
 
     user_opt = {
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..12d2de02d 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -18,17 +18,87 @@ from types import SimpleNamespace as config
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai")  # "openai", "anthropic", or "ollama"
+API_BASE_URL = os.getenv("API_BASE_URL")  # e.g. http://localhost:11434/v1 for Ollama
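+# Resolution order for these settings (see _get_provider_config below):
+# explicit function arguments > environment variables > built-in defaults
+# (an explicit base_url wins over API_BASE_URL, which wins over the Ollama
+# default of http://localhost:11434/v1).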
+
+
+def _get_provider_config(provider=None, api_key=None, base_url=None):
+    """Resolve provider, api_key, and base_url from args or environment."""
+    provider = provider or LLM_PROVIDER
+    if provider == "ollama":
+        return {
+            "provider": "ollama",
+            "api_key": api_key or "ollama",
+            "base_url": base_url or API_BASE_URL or "http://localhost:11434/v1",
+        }
+    elif provider == "anthropic":
+        return {
+            "provider": "anthropic",
+            "api_key": api_key or ANTHROPIC_API_KEY,
+        }
+    else:  # openai (default)
+        cfg = {
+            "provider": "openai",
+            "api_key": api_key or CHATGPT_API_KEY,
+        }
+        if base_url or API_BASE_URL:
+            cfg["base_url"] = base_url or API_BASE_URL
+        return cfg
+
+
+def _call_anthropic(model, messages, api_key):
+    """Synchronous Anthropic API call."""
+    import anthropic
+    client = anthropic.Anthropic(api_key=api_key)
+    # Convert OpenAI-style messages: extract the system message if present
+    system_msg = None
+    user_messages = []
+    for m in messages:
+        if m["role"] == "system":
+            system_msg = m["content"]
+        else:
+            user_messages.append(m)
+    kwargs = {"model": model, "max_tokens": 8192, "temperature": 0, "messages": user_messages}
+    if system_msg:
+        kwargs["system"] = system_msg
+    response = client.messages.create(**kwargs)
+    return response.content[0].text
+
+
+async def _call_anthropic_async(model, messages, api_key):
+    """Asynchronous Anthropic API call."""
+    import anthropic
+    client = anthropic.AsyncAnthropic(api_key=api_key)
+    system_msg = None
+    user_messages = []
+    for m in messages:
+        if m["role"] == "system":
+            system_msg = m["content"]
+        else:
+            user_messages.append(m)
+    kwargs = {"model": model, "max_tokens": 8192, "temperature": 0, "messages": user_messages}
+    if system_msg:
+        kwargs["system"] = system_msg
+    response = await client.messages.create(**kwargs)
+    return response.content[0].text
+
 
 def count_tokens(text, model=None):
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        # Fallback for non-OpenAI models (Anthropic, Ollama, etc.)
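+        # cl100k_base (the GPT-4 tokenizer) only approximates Claude/Llama
+        # tokenization; counts here feed page/node size budgeting, so an
+        # approximation is acceptable.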
+        enc = tiktoken.get_encoding("cl100k_base")
     tokens = enc.encode(text)
     return len(tokens)
 
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    pcfg = _get_provider_config(provider, api_key, base_url)
+
     for i in range(max_retries):
         try:
             if chat_history:
@@ -36,31 +106,40 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 messages.append({"role": "user", "content": prompt})
             else:
                 messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-            )
-            if response.choices[0].finish_reason == "length":
-                return response.choices[0].message.content, "max_output_reached"
+
+            if pcfg["provider"] == "anthropic":
+                content = _call_anthropic(model, messages, pcfg["api_key"])
+                return content, "finished"
             else:
-                return response.choices[0].message.content, "finished"
+                client_kwargs = {"api_key": pcfg["api_key"]}
+                if "base_url" in pcfg:
+                    client_kwargs["base_url"] = pcfg["base_url"]
+                client = openai.OpenAI(**client_kwargs)
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    temperature=0,
+                )
+                if response.choices[0].finish_reason == "length":
+                    return response.choices[0].message.content, "max_output_reached"
+                else:
+                    return response.choices[0].message.content, "finished"
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    pcfg = _get_provider_config(provider, api_key, base_url)
+
     for i in range(max_retries):
         try:
             if chat_history:
@@ -68,41 +147,56 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages.append({"role": "user", "content": prompt})
             else:
                 messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=0,
-            )
-
-            return response.choices[0].message.content
+
+            if pcfg["provider"] == "anthropic":
+                return _call_anthropic(model, messages, pcfg["api_key"])
+            else:
+                client_kwargs = {"api_key": pcfg["api_key"]}
+                if "base_url" in pcfg:
+                    client_kwargs["base_url"] = pcfg["base_url"]
+                client = openai.OpenAI(**client_kwargs)
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    temperature=0,
+                )
+                return response.choices[0].message.content
+
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=None, provider=None, base_url=None):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    pcfg = _get_provider_config(provider, api_key, base_url)
+
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            if pcfg["provider"] == "anthropic":
+                return await _call_anthropic_async(model, messages, pcfg["api_key"])
+            else:
+                client_kwargs = {"api_key": pcfg["api_key"]}
+                if "base_url" in pcfg:
+                    client_kwargs["base_url"] = pcfg["base_url"]
+                async with openai.AsyncOpenAI(**client_kwargs) as client:
+                    response = await client.chat.completions.create(
+                        model=model,
+                        messages=messages,
+                        temperature=0,
+                    )
+                    return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                await asyncio.sleep(1)  # Wait for 1s before retrying
+                await asyncio.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
@@ -411,7 +505,10 @@ def add_preface_if_needed(data):
 
 
 def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
-    enc = tiktoken.encoding_for_model(model)
+    try:
+        enc = tiktoken.encoding_for_model(model)
+    except KeyError:
+        enc = tiktoken.get_encoding("cl100k_base")
     if pdf_parser == "PyPDF2":
         pdf_reader = PyPDF2.PdfReader(pdf_path)
         page_list = []
diff --git a/requirements.txt b/requirements.txt
index 463db58f1..cac9713a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 openai==1.101.0
+anthropic>=0.34.0
 pymupdf==1.26.4
 PyPDF2==3.0.1
 python-dotenv==1.1.0
diff --git a/run_pageindex.py b/run_pageindex.py
index 107024505..38ad1b5dd 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -10,9 +10,14 @@
     parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
     parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
-    parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use')
+    parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use (e.g. gpt-4o-2024-11-20, claude-sonnet-4-20250514, llama3)')
+    parser.add_argument('--provider', type=str, default='openai',
+                        choices=['openai', 'anthropic', 'ollama'],
+                        help='LLM provider: openai, anthropic, or ollama (default: openai)')
+    parser.add_argument('--api-base-url', type=str, default=None,
+                        help='Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)')
 
-    parser.add_argument('--toc-check-pages', type=int, default=20, 
+    parser.add_argument('--toc-check-pages', type=int, default=20,
                         help='Number of pages to check for table of contents (PDF only)')
     parser.add_argument('--max-pages-per-node', type=int, default=10,
                         help='Maximum number of pages per node (PDF only)')
@@ -54,6 +59,8 @@
     # Configure options
     opt = config(
         model=args.model,
+        provider=args.provider,
+        api_base_url=args.api_base_url,
        toc_check_page_num=args.toc_check_pages,
         max_page_num_each_node=args.max_pages_per_node,
         max_token_num_each_node=args.max_tokens_per_node,
@@ -98,6 +105,8 @@
     # Create options dict with user args
     user_opt = {
         'model': args.model,
+        'provider': args.provider,
+        'api_base_url': args.api_base_url,
         'if_add_node_summary': args.if_add_node_summary,
         'if_add_doc_description': args.if_add_doc_description,
         'if_add_node_text': args.if_add_node_text,