Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions pageindex/page_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

model = os.getenv("MODEL")

################### check title in page #########################################################
async def check_title_appearance(item, page_list, start_index=1, model=None):
Expand Down Expand Up @@ -36,7 +37,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None):
}}
Directly return the final JSON structure. Do not output anything else."""

response = await ChatGPT_API_async(model=model, prompt=prompt)
response = await LLM_API_async(model=model, prompt=prompt)
response = extract_json(response)
if 'answer' in response:
answer = response['answer']
Expand Down Expand Up @@ -64,7 +65,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N
}}
Directly return the final JSON structure. Do not output anything else."""

response = await ChatGPT_API_async(model=model, prompt=prompt)
response = await LLM_API_async(model=model, prompt=prompt)
response = extract_json(response)
if logger:
logger.info(f"Response: {response}")
Expand Down Expand Up @@ -116,7 +117,7 @@ def toc_detector_single_page(content, model=None):
Directly return the final JSON structure. Do not output anything else.
Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents."""

response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
# print('response', response)
json_content = extract_json(response)
return json_content['toc_detected']
Expand All @@ -135,7 +136,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc
response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
json_content = extract_json(response)
return json_content['completed']

Expand All @@ -153,7 +154,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
json_content = extract_json(response)
return json_content['completed']

Expand All @@ -165,7 +166,7 @@ def extract_toc_content(content, model=None):

Directly return the full table of contents content. Do not output anything else."""

response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
response, finish_reason = LLM_API_with_finish_reason(prompt=prompt)

if_complete = check_if_toc_transformation_is_complete(content, response, model)
if if_complete == "yes" and finish_reason == "finished":
Expand All @@ -176,7 +177,7 @@ def extract_toc_content(content, model=None):
{"role": "assistant", "content": response},
]
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
new_response, finish_reason = LLM_API_with_finish_reason(prompt=prompt, chat_history=chat_history)
response = response + new_response
if_complete = check_if_toc_transformation_is_complete(content, response, model)

Expand All @@ -186,7 +187,7 @@ def extract_toc_content(content, model=None):
{"role": "assistant", "content": response},
]
prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
new_response, finish_reason = LLM_API_with_finish_reason( prompt=prompt, chat_history=chat_history)
response = response + new_response
if_complete = check_if_toc_transformation_is_complete(content, response, model)

Expand All @@ -212,7 +213,7 @@ def detect_page_index(toc_content, model=None):
}}
Directly return the final JSON structure. Do not output anything else."""

response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
json_content = extract_json(response)
return json_content['page_index_given_in_toc']

Expand Down Expand Up @@ -261,7 +262,7 @@ def toc_index_extractor(toc, content, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content
response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
json_content = extract_json(response)
return json_content

Expand Down Expand Up @@ -289,7 +290,7 @@ def toc_transformer(toc_content, model=None):
Directly return the final JSON structure, do not output anything else. """

prompt = init_prompt + '\n Given table of contents\n:' + toc_content
last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
last_complete, finish_reason = LLM_API_with_finish_reason( prompt=prompt)
if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
if if_complete == "yes" and finish_reason == "finished":
last_complete = extract_json(last_complete)
Expand All @@ -313,7 +314,7 @@ def toc_transformer(toc_content, model=None):

Please continue the json structure, directly output the remaining part of the json structure."""

new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
new_complete, finish_reason = LLM_API_with_finish_reason( prompt=prompt)

if new_complete.startswith('```json'):
new_complete = get_json_content(new_complete)
Expand Down Expand Up @@ -474,7 +475,7 @@ def add_page_number_to_toc(part, structure, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n"
current_json_raw = ChatGPT_API(model=model, prompt=prompt)
current_json_raw = LLM_API(model=model, prompt=prompt)
json_result = extract_json(current_json_raw)

for item in json_result:
Expand Down Expand Up @@ -524,7 +525,7 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
Directly return the additional part of the final JSON structure. Do not output anything else."""

prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2)
response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
response, finish_reason = LLM_API_with_finish_reason( prompt=prompt)
if finish_reason == 'finished':
return extract_json(response)
else:
Expand Down Expand Up @@ -558,7 +559,7 @@ def generate_toc_init(part, model=None):
Directly return the final JSON structure. Do not output anything else."""

prompt = prompt + '\nGiven text\n:' + part
response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
response, finish_reason = LLM_API_with_finish_reason( prompt=prompt)

if finish_reason == 'finished':
return extract_json(response)
Expand Down Expand Up @@ -729,7 +730,7 @@ def check_toc(page_list, opt=None):


################### fix incorrect toc #########################################################
def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"):
def single_toc_item_index_fixer(section_title, content, model=model):
toc_extractor_prompt = """
You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document.

Expand All @@ -743,7 +744,7 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20
Directly return the final JSON structure. Do not output anything else."""

prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
response = ChatGPT_API(model=model, prompt=prompt)
response = LLM_API(model=model, prompt=prompt)
json_content = extract_json(response)
return convert_physical_index_to_int(json_content['physical_index'])

Expand Down
Empty file added pageindex/providers/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions pageindex/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# import anthropic
# from pageindex.providers.base_llm import BaseLLM
# from pageindex.response_schema import LLMResponse

# class AnthropicLLM(BaseLLM):

# def __init__(self, api_key: str, model: str):
# self.client = anthropic.Anthropic(api_key=api_key)
# self.model = model

# def generate(self, messages, **kwargs):
# response = self.client.messages.create(
# model=self.model,
# messages=messages,
# max_tokens=kwargs.get("max_tokens", 1024)
# )

# return LLMResponse(
# content=response.content[0].text,
# finish_reason=response.stop_reason,
# raw=response
# )

# async def generate_async(self, messages, **kwargs):
# response = await self.client.messages.create(
# model=self.model,
# messages=messages,
# max_tokens=kwargs.get("max_tokens", 1024)
# )

# return LLMResponse(
# content=response.content[0].text,
# finish_reason=response.stop_reason,
# raw=response
# )

from anthropic import Anthropic, AsyncAnthropic
from pageindex.providers.base_llm import BaseLLM


class AnthropicProvider(BaseLLM):
    """LLM provider backed by the Anthropic Messages API.

    Exposes a synchronous ``generate`` and an asynchronous ``agenerate``;
    both return a plain dict with ``content`` (all text blocks concatenated),
    ``finish_reason`` (Anthropic's ``stop_reason``) and ``raw`` (the SDK
    response object).
    """

    def __init__(self, api_key: str):
        # Keep the key ourselves instead of reading it back off the client:
        # the original ``self.client.api_key`` access relied on an SDK
        # implementation detail.
        self._api_key = api_key
        self.client = Anthropic(api_key=api_key)

    @staticmethod
    def _extract_text(content_blocks) -> str:
        """Concatenate the text of all ``text``-type content blocks.

        Claude returns a list of typed blocks; non-text blocks (e.g. tool
        use) are skipped, matching the behavior of both public methods.
        """
        return "".join(
            block.text for block in content_blocks if block.type == "text"
        )

    def generate(self, model, messages, **kwargs):
        """Run one synchronous chat completion.

        Args:
            model: Anthropic model id.
            messages: chat messages in Anthropic's message format.
            **kwargs: optional ``max_tokens`` (default 1024) and
                ``temperature`` (default 0).

        Returns:
            dict with ``content``, ``finish_reason`` and ``raw``.
        """
        response = self.client.messages.create(
            model=model,
            messages=messages,
            max_tokens=kwargs.get("max_tokens", 1024),
            temperature=kwargs.get("temperature", 0),
        )
        return {
            "content": self._extract_text(response.content),
            "finish_reason": response.stop_reason,
            "raw": response,
        }

    async def agenerate(self, model, messages, **kwargs):
        """Async variant of :meth:`generate`.

        A fresh ``AsyncAnthropic`` client is opened per call and closed by
        the async context manager, so no async client outlives the call.
        The response is fully materialized (non-streaming), so returning
        after the ``async with`` block is safe.
        """
        async with AsyncAnthropic(api_key=self._api_key) as client:
            response = await client.messages.create(
                model=model,
                messages=messages,
                max_tokens=kwargs.get("max_tokens", 1024),
                temperature=kwargs.get("temperature", 0),
            )
        return {
            "content": self._extract_text(response.content),
            "finish_reason": response.stop_reason,
            "raw": response,
        }
43 changes: 43 additions & 0 deletions pageindex/providers/badrock_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import json
import boto3
from pageindex.providers.base_llm import BaseLLM


class BedrockProvider(BaseLLM):
    """Synchronous LLM provider for Anthropic models hosted on AWS Bedrock."""

    def __init__(self, aws_access_key, aws_secret_key, region):
        # Bedrock runtime client scoped to the supplied credentials/region.
        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region,
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key
        )

    def generate(self, model, messages, **kwargs):
        """Invoke *model* once with *messages*.

        Optional kwargs: ``max_tokens`` (default 1024), ``temperature``
        (default 0). Returns a dict with ``content`` (concatenated text
        blocks), ``finish_reason`` (Bedrock ``stop_reason``) and ``raw``
        (the decoded response body).
        """
        payload = {
            "anthropic_version": "bedrock-2023-05-31",
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 1024),
            "temperature": kwargs.get("temperature", 0)
        }

        raw_response = self.client.invoke_model(
            modelId=model,
            body=json.dumps(payload)
        )
        parsed = json.loads(raw_response["body"].read())

        # Bedrock mirrors Anthropic's typed content blocks; keep text only.
        text = "".join(
            piece["text"] for piece in parsed["content"] if piece["type"] == "text"
        )

        return {
            "content": text,
            "finish_reason": parsed.get("stop_reason"),
            "raw": parsed
        }

    async def agenerate(self, *args, **kwargs):
        """Async is unsupported here: boto3 offers no async client."""
        raise NotImplementedError("Async not implemented for Bedrock (boto3 is sync)")
12 changes: 12 additions & 0 deletions pageindex/providers/base_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

from abc import ABC, abstractmethod

class BaseLLM(ABC):
    """Abstract interface every LLM provider must implement.

    Concrete providers supply a synchronous ``generate`` and an
    asynchronous ``agenerate``; both take a model id plus a list of chat
    messages and provider-specific keyword options.
    """

    @abstractmethod
    def generate(self, model: str, messages: list, **kwargs):
        """Synchronously produce a completion for *messages* using *model*."""

    @abstractmethod
    async def agenerate(self, model: str, messages: list, **kwargs):
        """Asynchronously produce a completion for *messages* using *model*."""
32 changes: 32 additions & 0 deletions pageindex/providers/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# llm/factory.py

from pageindex.providers.open_router_provider import OpenRouterProvider
from pageindex.providers.badrock_provider import BedrockProvider
from pageindex.providers.groq_provider import GroqProvider
from pageindex.providers.anthropic_provider import AnthropicProvider
from pageindex.providers.gemini_provider import GeminiProvider
from pageindex.providers.open_ai_provider import OpenAIProvider


class LLMFactory:
    """Maps a provider name to a concrete ``BaseLLM`` implementation."""

    @staticmethod
    def create(provider: str, api_key: str):
        """Instantiate the LLM provider identified by *provider*.

        Args:
            provider: one of ``"openai"``, ``"gemini"``, ``"anthropic"``,
                ``"groq"``, ``"aws-badrock"`` or ``"open-router"``
                (matched case-insensitively, surrounding whitespace ignored).
            api_key: credential forwarded to the provider constructor.

        Returns:
            A ready-to-use provider instance.

        Raises:
            ValueError: if *provider* is not a supported name.
        """
        name = provider.strip().lower()

        if name == "openai":
            return OpenAIProvider(api_key)
        elif name == "gemini":
            return GeminiProvider(api_key)
        elif name == "anthropic":
            return AnthropicProvider(api_key)
        elif name == "groq":
            return GroqProvider(api_key)
        elif name == "aws-badrock":
            # NOTE(review): BedrockProvider.__init__ takes
            # (aws_access_key, aws_secret_key, region); passing a single
            # api_key will raise TypeError at call time — confirm the
            # intended Bedrock credential flow.
            return BedrockProvider(api_key)
        elif name == "open-router":
            return OpenRouterProvider(api_key)

        # Name the offender so misconfiguration is diagnosable from the log.
        raise ValueError(f"Unsupported provider: {provider!r}")


Loading