diff --git a/README.md b/README.md
index 7180efd5a..9e60a59ff 100644
--- a/README.md
+++ b/README.md
@@ -147,27 +147,49 @@ You can follow these steps to generate a PageIndex tree from a PDF document.
pip3 install --upgrade -r requirements.txt
```
-### 2. Set your OpenAI API key
+### 2. Set your API key
-Create a `.env` file in the root directory and add your API key:
+Create a `.env` file in the root directory and add your API key for your chosen provider:
```bash
+# OpenAI (default)
CHATGPT_API_KEY=your_openai_key_here
+
+# Anthropic (optional)
+ANTHROPIC_API_KEY=your_anthropic_key_here
+
+# Ollama — no API key needed, just have Ollama running locally
```
### 3. Run PageIndex on your PDF
+**OpenAI (default):**
```bash
python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
```
+**Anthropic:**
+```bash
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \
+ --provider anthropic --model claude-sonnet-4-20250514
+```
+
+**Ollama (local models):**
+```bash
+# Make sure Ollama is running (ollama serve)
+python3 run_pageindex.py --pdf_path /path/to/your/document.pdf \
+ --provider ollama --model llama3
+```
+
Optional parameters
You can customize the processing with additional optional arguments:
```
---model OpenAI model to use (default: gpt-4o-2024-11-20)
+--model Model to use (default: gpt-4o-2024-11-20)
+--provider LLM provider: openai, anthropic, or ollama (default: openai)
+--api-base-url Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)
--toc-check-pages Pages to check for table of contents (default: 20)
--max-pages-per-node Max pages per node (default: 10)
--max-tokens-per-node Max tokens per node (default: 20000)
@@ -175,6 +197,25 @@ You can customize the processing with additional optional arguments:
--if-add-node-summary Add node summary (yes/no, default: yes)
--if-add-doc-description Add doc description (yes/no, default: yes)
```
+
+You can also set the provider via environment variables when calling the package API directly (note: `run_pageindex.py` always passes its `--provider` value, which takes precedence over these variables):
+```bash
+export LLM_PROVIDER=ollama # or "anthropic"
+export API_BASE_URL=http://localhost:11434/v1 # optional, for custom endpoints
+```
+
+
+
+## Supported LLM Providers
+
+
+| Provider | Example Models | API Key Env Var | Notes |
+|----------|---------------|-----------------|-------|
+| **OpenAI** (default) | `gpt-4o-2024-11-20`, `gpt-4o-mini` | `CHATGPT_API_KEY` | Full support, recommended |
+| **Anthropic** | `claude-sonnet-4-20250514`, `claude-haiku-4-5-20251001` | `ANTHROPIC_API_KEY` | Full support |
+| **Ollama** | `llama3`, `mistral`, `qwen2.5` | _(none needed)_ | Requires Ollama running locally. Uses OpenAI-compatible API at `http://localhost:11434/v1` |
+
+**Note:** PageIndex relies on structured JSON output from the LLM. For best results, use capable models (GPT-4o, Claude Sonnet/Opus, or large Ollama models like Llama 3 70B+). Smaller local models may produce lower-quality tree structures.
diff --git a/pageindex/config.yaml b/pageindex/config.yaml
index fd73e3a2c..d9e86b730 100644
--- a/pageindex/config.yaml
+++ b/pageindex/config.yaml
@@ -1,4 +1,6 @@
model: "gpt-4o-2024-11-20"
+provider: "openai" # "openai", "anthropic", or "ollama"
+api_base_url: null # Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)
toc_check_page_num: 20
max_page_num_each_node: 10
max_token_num_each_node: 20000
diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 39018c4df..2f4165304 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -1057,16 +1057,27 @@ async def tree_parser(page_list, opt, doc=None, logger=None):
def page_index_main(doc, opt=None):
logger = JsonLogger(doc)
-
+
+ # Set provider config from opt so all downstream API calls pick it up
+ if hasattr(opt, 'provider') and opt.provider:
+ os.environ['LLM_PROVIDER'] = opt.provider
+        # Update the cached module's global so functions that read it at call time see the new value
+ from pageindex import utils
+ utils.LLM_PROVIDER = opt.provider
+ if hasattr(opt, 'api_base_url') and opt.api_base_url:
+ os.environ['API_BASE_URL'] = opt.api_base_url
+ from pageindex import utils
+ utils.API_BASE_URL = opt.api_base_url
+
is_valid_pdf = (
- (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or
+ (isinstance(doc, str) and os.path.isfile(doc) and doc.lower().endswith(".pdf")) or
isinstance(doc, BytesIO)
)
if not is_valid_pdf:
raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
print('Parsing PDF...')
- page_list = get_page_tokens(doc)
+ page_list = get_page_tokens(doc, model=opt.model)
logger.info({'total_page_number': len(page_list)})
logger.info({'total_token': sum([page[1] for page in page_list])})
@@ -1100,7 +1111,8 @@ async def page_index_builder():
return asyncio.run(page_index_builder())
-def page_index(doc, model=None, toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
+def page_index(doc, model=None, provider=None, api_base_url=None,
+ toc_check_page_num=None, max_page_num_each_node=None, max_token_num_each_node=None,
if_add_node_id=None, if_add_node_summary=None, if_add_doc_description=None, if_add_node_text=None):
user_opt = {
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..12d2de02d 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -18,17 +18,87 @@
from types import SimpleNamespace as config
CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai") # "openai", "anthropic", or "ollama"
+API_BASE_URL = os.getenv("API_BASE_URL") # e.g. http://localhost:11434/v1 for Ollama
+
+
+def _get_provider_config(provider=None, api_key=None, base_url=None):
+ """Resolve provider, api_key, and base_url from args or environment."""
+ provider = provider or LLM_PROVIDER
+ if provider == "ollama":
+ return {
+ "provider": "ollama",
+ "api_key": api_key or "ollama",
+ "base_url": base_url or API_BASE_URL or "http://localhost:11434/v1",
+ }
+ elif provider == "anthropic":
+ return {
+ "provider": "anthropic",
+ "api_key": api_key or ANTHROPIC_API_KEY,
+ }
+ else: # openai (default)
+ cfg = {
+ "provider": "openai",
+ "api_key": api_key or CHATGPT_API_KEY,
+ }
+ if base_url or API_BASE_URL:
+ cfg["base_url"] = base_url or API_BASE_URL
+ return cfg
+
+
+def _call_anthropic(model, messages, api_key):
+ """Synchronous Anthropic API call."""
+ import anthropic
+ client = anthropic.Anthropic(api_key=api_key)
+ # Convert openai-style messages: extract system if present
+ system_msg = None
+ user_messages = []
+ for m in messages:
+ if m["role"] == "system":
+ system_msg = m["content"]
+ else:
+ user_messages.append(m)
+ kwargs = {"model": model, "max_tokens": 8192, "temperature": 0, "messages": user_messages}
+ if system_msg:
+ kwargs["system"] = system_msg
+ response = client.messages.create(**kwargs)
+ return response.content[0].text
+
+
+async def _call_anthropic_async(model, messages, api_key):
+ """Asynchronous Anthropic API call."""
+ import anthropic
+ client = anthropic.AsyncAnthropic(api_key=api_key)
+ system_msg = None
+ user_messages = []
+ for m in messages:
+ if m["role"] == "system":
+ system_msg = m["content"]
+ else:
+ user_messages.append(m)
+ kwargs = {"model": model, "max_tokens": 8192, "temperature": 0, "messages": user_messages}
+ if system_msg:
+ kwargs["system"] = system_msg
+ response = await client.messages.create(**kwargs)
+ return response.content[0].text
+
def count_tokens(text, model=None):
if not text:
return 0
- enc = tiktoken.encoding_for_model(model)
+ try:
+ enc = tiktoken.encoding_for_model(model)
+ except KeyError:
+ # Fallback for non-OpenAI models (Anthropic, Ollama, etc.)
+ enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode(text)
return len(tokens)
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
max_retries = 10
- client = openai.OpenAI(api_key=api_key)
+ pcfg = _get_provider_config(provider, api_key, base_url)
+
for i in range(max_retries):
try:
if chat_history:
@@ -36,31 +106,40 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
messages.append({"role": "user", "content": prompt})
else:
messages = [{"role": "user", "content": prompt}]
-
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- temperature=0,
- )
- if response.choices[0].finish_reason == "length":
- return response.choices[0].message.content, "max_output_reached"
+
+ if pcfg["provider"] == "anthropic":
+ content = _call_anthropic(model, messages, pcfg["api_key"])
+                return content, "finished"  # NOTE(review): Anthropic stop_reason is not checked — truncation ("max_tokens") is misreported as "finished"
else:
- return response.choices[0].message.content, "finished"
+ client_kwargs = {"api_key": pcfg["api_key"]}
+ if "base_url" in pcfg:
+ client_kwargs["base_url"] = pcfg["base_url"]
+ client = openai.OpenAI(**client_kwargs)
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0,
+ )
+ if response.choices[0].finish_reason == "length":
+ return response.choices[0].message.content, "max_output_reached"
+ else:
+ return response.choices[0].message.content, "finished"
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
- time.sleep(1) # Wait for 1秒 before retrying
+ time.sleep(1)
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=None, chat_history=None, provider=None, base_url=None):
max_retries = 10
- client = openai.OpenAI(api_key=api_key)
+ pcfg = _get_provider_config(provider, api_key, base_url)
+
for i in range(max_retries):
try:
if chat_history:
@@ -68,41 +147,56 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
messages.append({"role": "user", "content": prompt})
else:
messages = [{"role": "user", "content": prompt}]
-
- response = client.chat.completions.create(
- model=model,
- messages=messages,
- temperature=0,
- )
-
- return response.choices[0].message.content
+
+ if pcfg["provider"] == "anthropic":
+ return _call_anthropic(model, messages, pcfg["api_key"])
+ else:
+ client_kwargs = {"api_key": pcfg["api_key"]}
+ if "base_url" in pcfg:
+ client_kwargs["base_url"] = pcfg["base_url"]
+ client = openai.OpenAI(**client_kwargs)
+ response = client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0,
+ )
+ return response.choices[0].message.content
+
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
- time.sleep(1) # Wait for 1秒 before retrying
+ time.sleep(1)
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=None, provider=None, base_url=None):
max_retries = 10
messages = [{"role": "user", "content": prompt}]
+ pcfg = _get_provider_config(provider, api_key, base_url)
+
for i in range(max_retries):
try:
- async with openai.AsyncOpenAI(api_key=api_key) as client:
- response = await client.chat.completions.create(
- model=model,
- messages=messages,
- temperature=0,
- )
- return response.choices[0].message.content
+ if pcfg["provider"] == "anthropic":
+ return await _call_anthropic_async(model, messages, pcfg["api_key"])
+ else:
+ client_kwargs = {"api_key": pcfg["api_key"]}
+ if "base_url" in pcfg:
+ client_kwargs["base_url"] = pcfg["base_url"]
+ async with openai.AsyncOpenAI(**client_kwargs) as client:
+ response = await client.chat.completions.create(
+ model=model,
+ messages=messages,
+ temperature=0,
+ )
+ return response.choices[0].message.content
except Exception as e:
print('************* Retrying *************')
logging.error(f"Error: {e}")
if i < max_retries - 1:
- await asyncio.sleep(1) # Wait for 1s before retrying
+ await asyncio.sleep(1)
else:
logging.error('Max retries reached for prompt: ' + prompt)
return "Error"
@@ -411,7 +505,10 @@ def add_preface_if_needed(data):
def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
- enc = tiktoken.encoding_for_model(model)
+ try:
+ enc = tiktoken.encoding_for_model(model)
+ except KeyError:
+ enc = tiktoken.get_encoding("cl100k_base")
if pdf_parser == "PyPDF2":
pdf_reader = PyPDF2.PdfReader(pdf_path)
page_list = []
diff --git a/requirements.txt b/requirements.txt
index 463db58f1..cac9713a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
openai==1.101.0
+anthropic>=0.34.0
pymupdf==1.26.4
PyPDF2==3.0.1
python-dotenv==1.1.0
diff --git a/run_pageindex.py b/run_pageindex.py
index 107024505..38ad1b5dd 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -10,9 +10,14 @@
parser.add_argument('--pdf_path', type=str, help='Path to the PDF file')
parser.add_argument('--md_path', type=str, help='Path to the Markdown file')
- parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use')
+ parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use (e.g. gpt-4o-2024-11-20, claude-sonnet-4-20250514, llama3)')
+ parser.add_argument('--provider', type=str, default='openai',
+ choices=['openai', 'anthropic', 'ollama'],
+ help='LLM provider: openai, anthropic, or ollama (default: openai)')
+ parser.add_argument('--api-base-url', type=str, default=None,
+ help='Custom API base URL (e.g. http://localhost:11434/v1 for Ollama)')
- parser.add_argument('--toc-check-pages', type=int, default=20,
+ parser.add_argument('--toc-check-pages', type=int, default=20,
help='Number of pages to check for table of contents (PDF only)')
parser.add_argument('--max-pages-per-node', type=int, default=10,
help='Maximum number of pages per node (PDF only)')
@@ -54,6 +59,8 @@
# Configure options
opt = config(
model=args.model,
+ provider=args.provider,
+ api_base_url=args.api_base_url,
toc_check_page_num=args.toc_check_pages,
max_page_num_each_node=args.max_pages_per_node,
max_token_num_each_node=args.max_tokens_per_node,
@@ -98,6 +105,8 @@
# Create options dict with user args
user_opt = {
'model': args.model,
+ 'provider': args.provider,
+ 'api_base_url': args.api_base_url,
'if_add_node_summary': args.if_add_node_summary,
'if_add_doc_description': args.if_add_doc_description,
'if_add_node_text': args.if_add_node_text,