12 changes: 9 additions & 3 deletions README.md
@@ -147,14 +147,20 @@ You can follow these steps to generate a PageIndex tree from a PDF document.
pip3 install --upgrade -r requirements.txt
```

-### 2. Set your OpenAI API key
+### 2. Set your API key

-Create a `.env` file in the root directory and add your API key:
+Create a `.env` file in the root directory and add at least one API key:

```bash
# For OpenAI models (e.g. gpt-4o, gpt-4o-2024-11-20)
CHATGPT_API_KEY=your_openai_key_here

# For Gemini models (e.g. gemini-1.5-pro, gemini-1.5-flash)
GEMINI_API_KEY=your_gemini_key_here
```

Use `--model` to choose the model; if the model name starts with `gemini-`, `GEMINI_API_KEY` is used; otherwise `CHATGPT_API_KEY` is used.
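
For example, to build the tree with a Gemini model instead of the default OpenAI one (assuming `GEMINI_API_KEY` is set in your `.env`):

```bash
python3 run_pageindex.py --pdf_path /path/to/your/document.pdf --model gemini-1.5-pro
```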

### 3. Run PageIndex on your PDF

```bash
@@ -167,7 +173,7 @@ python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
You can customize the processing with additional optional arguments:

```
---model                 OpenAI model to use (default: gpt-4o-2024-11-20)
+--model                 LLM model: OpenAI (e.g. gpt-4o-2024-11-20) or Gemini (e.g. gemini-1.5-pro) (default: gpt-4o-2024-11-20)
--toc-check-pages        Pages to check for table of contents (default: 20)
--max-pages-per-node     Max pages per node (default: 10)
--max-tokens-per-node    Max tokens per node (default: 20000)
112 changes: 104 additions & 8 deletions pageindex/utils.py
@@ -1,3 +1,4 @@
import re
import tiktoken
import openai
import logging
@@ -18,17 +19,80 @@
from types import SimpleNamespace as config

CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

def _use_gemini(model, api_key):
    """Return (use_gemini: bool, key to use). Uses Gemini when the model starts with 'gemini-' and a key is available."""
    if not model or not str(model).startswith("gemini-"):
        return False, api_key or CHATGPT_API_KEY
    gemini_key = api_key or GEMINI_API_KEY
    if gemini_key:
        return True, gemini_key
    return False, api_key or CHATGPT_API_KEY
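
# Illustrative routing (hypothetical calls, assuming both env keys are set):
#   _use_gemini("gemini-1.5-pro", None)    -> (True, GEMINI_API_KEY)
#   _use_gemini("gpt-4o-2024-11-20", None) -> (False, CHATGPT_API_KEY)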

def _gemini_generate_sync(model, prompt, api_key, chat_history=None):
    """Sync Gemini generate_content via google.genai. Returns (text, finish_reason)."""
    from google import genai
    from google.genai import types
    client = genai.Client(api_key=api_key)
    config = types.GenerateContentConfig(temperature=0)
    if chat_history:
        contents = []
        for m in chat_history:
            role = "user" if m.get("role") == "user" else "model"
            contents.append(types.Content(role=role, parts=[types.Part.from_text(text=m.get("content", ""))]))
        contents.append(types.Content(role="user", parts=[types.Part.from_text(text=prompt)]))
        response = client.models.generate_content(model=model, contents=contents, config=config)
    else:
        response = client.models.generate_content(model=model, contents=prompt, config=config)
    text = response.text if response.text else "Error"
    finish_reason = "finished"
    if response.candidates:
        fr = getattr(response.candidates[0], "finish_reason", None)
        if fr is not None and "MAX" in str(fr).upper():
            finish_reason = "max_output_reached"
    return text, finish_reason

async def _gemini_generate_async(model, prompt, api_key):
    """Async Gemini generate_content via google.genai. Returns text."""
    from google import genai
    from google.genai import types
    client = genai.Client(api_key=api_key)
    config = types.GenerateContentConfig(temperature=0)
    # The google.genai call is blocking; run it in a thread pool so the event loop isn't blocked.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: client.models.generate_content(model=model, contents=prompt, config=config),
    )
    return response.text if response.text else "Error"

def count_tokens(text, model=None):
    if not text:
        return 0
-    enc = tiktoken.encoding_for_model(model)
+    if model and str(model).startswith("gemini-"):
+        # No local tokenizer for Gemini here; approximate at ~4 characters per token.
+        return max(1, len(text) // 4)
+    enc = tiktoken.encoding_for_model(model or "gpt-4o")
    tokens = enc.encode(text)
    return len(tokens)
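
# Worked examples of the Gemini fallback heuristic (hypothetical values):
#   count_tokens("x" * 400, model="gemini-1.5-pro") -> 100  (400 // 4)
#   count_tokens("hi", model="gemini-1.5-pro")      -> 1    (max(1, 2 // 4) clamps to 1)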

def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    use_gemini, key = _use_gemini(model, api_key)
    if use_gemini:
        max_retries = 10
        for i in range(max_retries):
            try:
                return _gemini_generate_sync(model, prompt, key, chat_history=chat_history)
            except Exception as e:
                print('************* Retrying *************')
                logging.error(f"Error: {e}")
                if i < max_retries - 1:
                    time.sleep(1)
                else:
                    logging.error('Max retries reached for prompt: ' + prompt)
                    return "Error", "error"
        return "Error", "error"
    max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = openai.OpenAI(api_key=key)
    for i in range(max_retries):
        try:
            if chat_history:
@@ -54,13 +118,29 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                time.sleep(1)  # Wait for 1 second before retrying
            else:
                logging.error('Max retries reached for prompt: ' + prompt)
-                return "Error"
+                return "Error", "error"



def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
    use_gemini, key = _use_gemini(model, api_key)
    if use_gemini:
        max_retries = 10
        for i in range(max_retries):
            try:
                text, _ = _gemini_generate_sync(model, prompt, key, chat_history=chat_history)
                return text
            except Exception as e:
                print('************* Retrying *************')
                logging.error(f"Error: {e}")
                if i < max_retries - 1:
                    time.sleep(1)
                else:
                    logging.error('Max retries reached for prompt: ' + prompt)
                    return "Error"
        return "Error"
    max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = openai.OpenAI(api_key=key)
    for i in range(max_retries):
        try:
            if chat_history:
@@ -87,11 +167,26 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):


async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
    use_gemini, key = _use_gemini(model, api_key)
    if use_gemini:
        max_retries = 10
        for i in range(max_retries):
            try:
                return await _gemini_generate_async(model, prompt, key)
            except Exception as e:
                print('************* Retrying *************')
                logging.error(f"Error: {e}")
                if i < max_retries - 1:
                    await asyncio.sleep(1)
                else:
                    logging.error('Max retries reached for prompt: ' + prompt)
                    return "Error"
        return "Error"
    max_retries = 10
    messages = [{"role": "user", "content": prompt}]
    for i in range(max_retries):
        try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
+            async with openai.AsyncOpenAI(api_key=key) as client:
                response = await client.chat.completions.create(
                    model=model,
                    messages=messages,
@@ -411,14 +506,15 @@ def add_preface_if_needed(data):


def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
-    enc = tiktoken.encoding_for_model(model)
+    use_tiktoken = not (model and str(model).startswith("gemini-"))
+    enc = tiktoken.encoding_for_model(model or "gpt-4o") if use_tiktoken else None
    if pdf_parser == "PyPDF2":
        pdf_reader = PyPDF2.PdfReader(pdf_path)
        page_list = []
        for page_num in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_num]
            page_text = page.extract_text()
-            token_length = len(enc.encode(page_text))
+            token_length = len(enc.encode(page_text)) if enc else max(1, len(page_text) // 4)
            page_list.append((page_text, token_length))
        return page_list
    elif pdf_parser == "PyMuPDF":
elif pdf_parser == "PyMuPDF":
@@ -430,7 +526,7 @@ def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
        page_list = []
        for page in doc:
            page_text = page.get_text()
-            token_length = len(enc.encode(page_text))
+            token_length = len(enc.encode(page_text)) if enc else max(1, len(page_text) // 4)
            page_list.append((page_text, token_length))
        return page_list
    else:
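
Taken together, the changed helpers dispatch on the model name, so callers don't change. A minimal usage sketch (assuming the `pageindex.utils` module path and keys loaded from `.env`):

```python
from pageindex.utils import ChatGPT_API, count_tokens

# Starts with "gemini-", so this routes to google.genai using GEMINI_API_KEY.
summary = ChatGPT_API(model="gemini-1.5-flash", prompt="Summarize this section.")

# Gemini models fall back to the ~4 characters-per-token heuristic.
tokens = count_tokens(summary, model="gemini-1.5-flash")

# Any other model name routes to the OpenAI client using CHATGPT_API_KEY.
answer = ChatGPT_API(model="gpt-4o-2024-11-20", prompt="What is PageIndex?")
```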
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
openai==1.101.0
google-genai>=1.0.0
pymupdf==1.26.4
PyPDF2==3.0.1
python-dotenv==1.1.0