diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 39018c4d..87af148e 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -8,6 +8,7 @@
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
+model = os.getenv("MODEL")
 
 ################### check title in page #########################################################
 async def check_title_appearance(item, page_list, start_index=1, model=None):
@@ -36,7 +37,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None):
     }}
     Directly return the final JSON structure. Do not output anything else."""
 
-    response = await ChatGPT_API_async(model=model, prompt=prompt)
+    response = await LLM_API_async(model=model, prompt=prompt)
     response = extract_json(response)
     if 'answer' in response:
         answer = response['answer']
@@ -64,7 +65,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=None):
     }}
     Directly return the final JSON structure. Do not output anything else."""
 
-    response = await ChatGPT_API_async(model=model, prompt=prompt)
+    response = await LLM_API_async(model=model, prompt=prompt)
     response = extract_json(response)
     if logger:
         logger.info(f"Response: {response}")
@@ -116,7 +117,7 @@ def toc_detector_single_page(content, model=None):
     Directly return the final JSON structure. Do not output anything else.
     Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents."""
 
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     # print('response', response)
     json_content = extract_json(response)
     return json_content['toc_detected']
@@ -135,7 +136,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None):
     Directly return the final JSON structure. Do not output anything else."""
 
     prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     json_content = extract_json(response)
     return json_content['completed']
@@ -153,7 +154,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None):
     Directly return the final JSON structure. Do not output anything else."""
 
     prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     json_content = extract_json(response)
     return json_content['completed']
@@ -165,7 +166,7 @@ def extract_toc_content(content, model=None):
     Directly return the full table of contents content.
     Do not output anything else."""
 
-    response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+    response, finish_reason = LLM_API_with_finish_reason(prompt=prompt)
     if_complete = check_if_toc_transformation_is_complete(content, response, model)
     if if_complete == "yes" and finish_reason == "finished":
@@ -176,7 +177,7 @@ def extract_toc_content(content, model=None):
             {"role": "assistant", "content": response},
         ]
         prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
-        new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
+        new_response, finish_reason = LLM_API_with_finish_reason(prompt=prompt, chat_history=chat_history)
         response = response + new_response
         if_complete = check_if_toc_transformation_is_complete(content, response, model)
@@ -186,7 +187,7 @@ def extract_toc_content(content, model=None):
             {"role": "assistant", "content": response},
         ]
         prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
-        new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
+        new_response, finish_reason = LLM_API_with_finish_reason(prompt=prompt, chat_history=chat_history)
         response = response + new_response
         if_complete = check_if_toc_transformation_is_complete(content, response, model)
@@ -212,7 +213,7 @@ def detect_page_index(toc_content, model=None):
     }}
     Directly return the final JSON structure. Do not output anything else."""
 
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     json_content = extract_json(response)
     return json_content['page_index_given_in_toc']
@@ -261,7 +262,7 @@ def toc_index_extractor(toc, content, model=None):
     Directly return the final JSON structure. Do not output anything else."""
 
     prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     json_content = extract_json(response)
     return json_content
@@ -289,7 +290,7 @@ def toc_transformer(toc_content, model=None):
     Directly return the final JSON structure, do not output anything else.
     """
     prompt = init_prompt + '\n Given table of contents\n:' + toc_content
-    last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+    last_complete, finish_reason = LLM_API_with_finish_reason(prompt=prompt)
     if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model)
     if if_complete == "yes" and finish_reason == "finished":
         last_complete = extract_json(last_complete)
@@ -313,7 +314,7 @@ def toc_transformer(toc_content, model=None):
         Please continue the json structure, directly output the remaining part of the json structure."""
 
-        new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+        new_complete, finish_reason = LLM_API_with_finish_reason(prompt=prompt)
         if new_complete.startswith('```json'):
             new_complete = get_json_content(new_complete)
@@ -474,7 +475,7 @@ def add_page_number_to_toc(part, structure, model=None):
     Directly return the final JSON structure.
     Do not output anything else."""
 
     prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n"
-    current_json_raw = ChatGPT_API(model=model, prompt=prompt)
+    current_json_raw = LLM_API(model=model, prompt=prompt)
     json_result = extract_json(current_json_raw)
 
     for item in json_result:
@@ -524,7 +525,7 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"):
     Directly return the additional part of the final JSON structure. Do not output anything else."""
 
     prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2)
-    response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+    response, finish_reason = LLM_API_with_finish_reason(prompt=prompt)
     if finish_reason == 'finished':
         return extract_json(response)
     else:
@@ -558,7 +559,7 @@ def generate_toc_init(part, model=None):
     Directly return the final JSON structure. Do not output anything else."""
 
     prompt = prompt + '\nGiven text\n:' + part
-    response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
+    response, finish_reason = LLM_API_with_finish_reason(prompt=prompt)
 
     if finish_reason == 'finished':
         return extract_json(response)
@@ -729,7 +730,7 @@ def check_toc(page_list, opt=None):
 
 
 ################### fix incorrect toc #########################################################
-def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"):
+def single_toc_item_index_fixer(section_title, content, model=model):
     toc_extractor_prompt = """
     You are given a section title and several pages of a document, your job is to find the physical index of the start page of the section in the partial document.
@@ -743,7 +744,7 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20"):
     Directly return the final JSON structure.
     Do not output anything else."""
 
     prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content
-    response = ChatGPT_API(model=model, prompt=prompt)
+    response = LLM_API(model=model, prompt=prompt)
     json_content = extract_json(response)
     return convert_physical_index_to_int(json_content['physical_index'])
diff --git a/pageindex/providers/__init__.py b/pageindex/providers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pageindex/providers/anthropic_provider.py b/pageindex/providers/anthropic_provider.py
new file mode 100644
index 00000000..51e307fb
--- /dev/null
+++ b/pageindex/providers/anthropic_provider.py
@@ -0,0 +1,84 @@
+# import anthropic
+# from pageindex.providers.base_llm import BaseLLM
+# from pageindex.response_schema import LLMResponse
+
+# class AnthropicLLM(BaseLLM):
+
+#     def __init__(self, api_key: str, model: str):
+#         self.client = anthropic.Anthropic(api_key=api_key)
+#         self.model = model
+
+#     def generate(self, messages, **kwargs):
+#         response = self.client.messages.create(
+#             model=self.model,
+#             messages=messages,
+#             max_tokens=kwargs.get("max_tokens", 1024)
+#         )
+
+#         return LLMResponse(
+#             content=response.content[0].text,
+#             finish_reason=response.stop_reason,
+#             raw=response
+#         )
+
+#     async def generate_async(self, messages, **kwargs):
+#         response = await self.client.messages.create(
+#             model=self.model,
+#             messages=messages,
+#             max_tokens=kwargs.get("max_tokens", 1024)
+#         )
+
+#         return LLMResponse(
+#             content=response.content[0].text,
+#             finish_reason=response.stop_reason,
+#             raw=response
+#         )
+
+from anthropic import Anthropic, AsyncAnthropic
+from pageindex.providers.base_llm import BaseLLM
+
+
+class AnthropicProvider(BaseLLM):
+
+    def __init__(self, api_key: str):
+        self.client = Anthropic(api_key=api_key)
+
+    def generate(self, model, messages, **kwargs):
+        response = self.client.messages.create(
+            model=model,
+            messages=messages,
+            max_tokens=kwargs.get("max_tokens", 1024),
+            temperature=kwargs.get("temperature", 0)
+        )
+
+        # Claude returns a list of content blocks
+        content = ""
+        for block in response.content:
+            if block.type == "text":
+                content += block.text
+
+        return {
+            "content": content,
+            "finish_reason": response.stop_reason,
+            "raw": response
+        }
+
+    async def agenerate(self, model, messages, **kwargs):
+        async with AsyncAnthropic(api_key=self.client.api_key) as client:
+            response = await client.messages.create(
+                model=model,
+                messages=messages,
+                max_tokens=kwargs.get("max_tokens", 1024),
+                temperature=kwargs.get("temperature", 0)
+            )
+
+            content = ""
+            for block in response.content:
+                if block.type == "text":
+                    content += block.text
+
+            return {
+                "content": content,
+                "finish_reason": response.stop_reason,
+                "raw": response
+            }
\ No newline at end of file
diff --git a/pageindex/providers/badrock_provider.py b/pageindex/providers/badrock_provider.py
new file mode 100644
index 00000000..7ff577b3
--- /dev/null
+++ b/pageindex/providers/badrock_provider.py
@@ -0,0 +1,43 @@
+import json
+import boto3
+from pageindex.providers.base_llm import BaseLLM
+
+
+class BedrockProvider(BaseLLM):
+
+    def __init__(self, aws_access_key, aws_secret_key, region):
+        self.client = boto3.client(
+            "bedrock-runtime",
+            region_name=region,
+            aws_access_key_id=aws_access_key,
+            aws_secret_access_key=aws_secret_key
+        )
+
+    def generate(self, model, messages, **kwargs):
+        body = json.dumps({
+            "anthropic_version": "bedrock-2023-05-31",
+            "messages": messages,
+            "max_tokens": kwargs.get("max_tokens", 1024),
+            "temperature": kwargs.get("temperature", 0)
+        })
+
+        response = self.client.invoke_model(
+            modelId=model,
+            body=body
+        )
+
+        response_body = json.loads(response["body"].read())
+
+        content = ""
+        for block in response_body["content"]:
+            if block["type"] == "text":
+                content += block["text"]
+
+        return {
+            "content": content,
+            "finish_reason": response_body.get("stop_reason"),
+            "raw": response_body
+        }
+
+    async def agenerate(self, *args, **kwargs):
+        raise NotImplementedError("Async not implemented for Bedrock (boto3 is sync)")
\ No newline at end of file
diff --git a/pageindex/providers/base_llm.py b/pageindex/providers/base_llm.py
new file mode 100644
index 00000000..7253ffed
--- /dev/null
+++ b/pageindex/providers/base_llm.py
@@ -0,0 +1,12 @@
+
+from abc import ABC, abstractmethod
+
+class BaseLLM(ABC):
+
+    @abstractmethod
+    def generate(self, model: str, messages: list, **kwargs):
+        pass
+
+    @abstractmethod
+    async def agenerate(self, model: str, messages: list, **kwargs):
+        pass
\ No newline at end of file
diff --git a/pageindex/providers/factory.py b/pageindex/providers/factory.py
new file mode 100644
index 00000000..06827b60
--- /dev/null
+++ b/pageindex/providers/factory.py
@@ -0,0 +1,32 @@
+# llm/factory.py
+
+from pageindex.providers.open_router_provider import OpenRouterProvider
+from pageindex.providers.badrock_provider import BedrockProvider
+from pageindex.providers.groq_provider import GroqProvider
+from pageindex.providers.anthropic_provider import AnthropicProvider
+from pageindex.providers.gemini_provider import GeminiProvider
+from pageindex.providers.open_ai_provider import OpenAIProvider
+
+
+class LLMFactory:
+
+    @staticmethod
+    def create(provider: str, api_key: str):
+
+        if provider == "openai":
+            return OpenAIProvider(api_key)
+        elif provider == "gemini":
+            return GeminiProvider(api_key)
+        elif provider == "anthropic":
+            return AnthropicProvider(api_key)
+        elif provider == 'groq':
+            return GroqProvider(api_key)
+        elif provider == 'aws-badrock':
+            return BedrockProvider(api_key)
+        elif provider == 'open-router':
+            return OpenRouterProvider(api_key)
+        else:
+            raise ValueError("Unsupported provider")
+
+
diff --git a/pageindex/providers/gemini_provider.py b/pageindex/providers/gemini_provider.py
new file mode 100644
index 00000000..a9ceda7c
--- /dev/null
+++ b/pageindex/providers/gemini_provider.py
@@ -0,0 +1,60 @@
+from google import genai
+from pageindex.providers.base_llm import BaseLLM
+from google.genai import types
+
+
+class GeminiProvider(BaseLLM):
+
+    def __init__(self, api_key: str):
+        self.client = genai.Client(api_key=api_key)
+
+    def _convert_messages(self, messages):
+        """
+        Convert OpenAI-style messages into Gemini contents format.
+        """
+        contents = []
+
+        for msg in messages:
+            contents.append({
+                "role": msg["role"],
+                "parts": [{"text": msg["content"]}]
+            })
+
+        return contents
+
+    def generate(self, model, messages, **kwargs):
+        contents = self._convert_messages(messages)
+
+        response = self.client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=types.GenerateContentConfig(
+                temperature=0,
+                max_output_tokens=2512
+            )
+
+        )
+
+        return {
+            "content": response.text,
+            "finish_reason": getattr(response, "finish_reason", None),
+            "raw": response
+        }
+
+    async def agenerate(self, model, messages, **kwargs):
+        contents = self._convert_messages(messages)
+
+        response = await self.client.aio.models.generate_content(
+            model=model,
+            contents=contents,
+            config=types.GenerateContentConfig(
+                temperature=0,
+                max_output_tokens=2500
+            )
+        )
+
+        return {
+            "content": response.text,
+            "finish_reason": getattr(response, "finish_reason", None),
+            "raw": response
+        }
\ No newline at end of file
diff --git a/pageindex/providers/groq_provider.py b/pageindex/providers/groq_provider.py
new file mode 100644
index 00000000..688db9e5
--- /dev/null
+++ b/pageindex/providers/groq_provider.py
@@ -0,0 +1,70 @@
+# from groq import Groq, AsyncGroq
+# from pageindex.providers.base_llm import BaseLLM
+# from pageindex.response_schema import LLMResponse
+
+# class GroqLLM(BaseLLM):
+
+#     def __init__(self, api_key: str, model: str):
+#         self.client = Groq(api_key=api_key)
+#         self.async_client = AsyncGroq(api_key=api_key)
+#         self.model = model
+
+#     def generate(self, messages, **kwargs):
+#         response = self.client.chat.completions.create(
+#             model=self.model,
+#             messages=messages,
+#             **kwargs
+#         )
+#         return LLMResponse(
+#             content=response.choices[0].message.content,
+#             finish_reason=response.choices[0].finish_reason,
+#             raw=response
+#         )
+
+#     async def agenerate(self, model, messages, **kwargs):
+#         response = await self.async_client.chat.completions.create(
+#             model=self.model,
+#             messages=messages,
+#             **kwargs
+#         )
+#         return LLMResponse(
+#             content=response.choices[0].message.content,
+#             finish_reason=response.choices[0].finish_reason,
+#             raw=response
+#         )
+
+from groq import Groq, AsyncGroq
+from pageindex.providers.base_llm import BaseLLM
+
+
+class GroqProvider(BaseLLM):
+
+    def __init__(self, api_key: str):
+        self.client = Groq(api_key=api_key)
+
+    def generate(self, model, messages, **kwargs):
+        response = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=kwargs.get("temperature", 0)
+        )
+
+        return {
+            "content": response.choices[0].message.content,
+            "finish_reason": response.choices[0].finish_reason,
+            "raw": response
+        }
+
+    async def agenerate(self, model, messages, **kwargs):
+        async with AsyncGroq(api_key=self.client.api_key) as client:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=kwargs.get("temperature", 0)
+            )
+
+            return {
+                "content": response.choices[0].message.content,
+                "finish_reason": response.choices[0].finish_reason,
+                "raw": response
+            }
\ No newline at end of file
diff --git a/pageindex/providers/open_ai_provider.py b/pageindex/providers/open_ai_provider.py
new file mode 100644
index 00000000..f3598567
--- /dev/null
+++ b/pageindex/providers/open_ai_provider.py
@@ -0,0 +1,34 @@
+import openai
+from pageindex.providers.base_llm import BaseLLM
+
+class OpenAIProvider(BaseLLM):
+
+    def __init__(self, api_key):
+        self.client = openai.OpenAI(api_key=api_key)
+
+    def generate(self, model, messages, **kwargs):
+        response = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=kwargs.get("temperature", 0)
+        )
+
+        return {
+            "content": response.choices[0].message.content,
+            "finish_reason": response.choices[0].finish_reason,
+            "raw": response
+        }
+
+    async def agenerate(self, model, messages, **kwargs):
+        async with openai.AsyncOpenAI(api_key=self.client.api_key) as client:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=kwargs.get("temperature", 0)
+            )
+
+            return {
+                "content": response.choices[0].message.content,
+                "finish_reason": response.choices[0].finish_reason,
+                "raw": response
+            }
\ No newline at end of file
diff --git a/pageindex/providers/open_router_provider.py b/pageindex/providers/open_router_provider.py
new file mode 100644
index 00000000..57d11f32
--- /dev/null
+++ b/pageindex/providers/open_router_provider.py
@@ -0,0 +1,42 @@
+import openai
+from pageindex.providers.base_llm import BaseLLM
+
+
+class OpenRouterProvider(BaseLLM):
+
+    def __init__(self, api_key: str):
+        self.client = openai.OpenAI(
+            api_key=api_key,
+            base_url="https://openrouter.ai/api/v1"
+        )
+
+    def generate(self, model, messages, **kwargs):
+        response = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=kwargs.get("temperature", 0)
+        )
+
+        return {
+            "content": response.choices[0].message.content,
+            "finish_reason": response.choices[0].finish_reason,
+            "raw": response
+        }
+
+    async def agenerate(self, model, messages, **kwargs):
+        async with openai.AsyncOpenAI(
+            api_key=self.client.api_key,
+            base_url="https://openrouter.ai/api/v1"
+        ) as client:
+
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=kwargs.get("temperature", 0)
+            )
+
+            return {
+                "content": response.choices[0].message.content,
+                "finish_reason": response.choices[0].finish_reason,
+                "raw": response
+            }
\ No newline at end of file
diff --git a/pageindex/response_schema.py b/pageindex/response_schema.py
new file mode 100644
index 00000000..12e3648b
--- /dev/null
+++ b/pageindex/response_schema.py
@@ -0,0 +1,8 @@
+from dataclasses import dataclass
+from typing import Optional
+
+@dataclass
+class LLMResponse:
+    content: str
+    finish_reason: Optional[str] = None
+    raw: Optional[dict] = None
\ No newline at end of file
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd88..fc7e17e2 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -1,5 +1,6 @@
 import tiktoken
 import openai
+import re
 import logging
 import os
 from datetime import datetime
@@ -16,19 +17,24 @@
 import yaml
 from pathlib import Path
 from types import SimpleNamespace as config
 
+from pageindex.providers.factory import LLMFactory
+API_KEY = os.getenv("API_KEY")
+PLATFORM = os.getenv("PROVIDER")
+model = os.getenv("MODEL")
+embedding_model = os.getenv("EMBEDDING_MODEL")
 
-CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
-
-def count_tokens(text, model=None):
+def count_tokens(text, model=model):
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
+    # TODO: pick the tiktoken encoding dynamically per provider/model
+
+    enc = tiktoken.get_encoding(embedding_model)
     tokens = enc.encode(text)
     return len(tokens)
 
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def LLM_API_with_finish_reason(prompt, model=model, api_key=API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    llm = LLMFactory.create(PLATFORM, API_KEY)
     for i in range(max_retries):
        try:
            if chat_history:
@@ -37,15 +43,16 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
             else:
                 messages = [{"role": "user", "content": prompt}]
 
-            response = client.chat.completions.create(
+
+            response = llm.generate(
                 model=model,
                 messages=messages,
                 temperature=0,
             )
-            if response.choices[0].finish_reason == "length":
-                return response.choices[0].message.content, "max_output_reached"
+            if response['finish_reason'] == "length":
+                return response['content'], "max_output_reached"
             else:
-                return response.choices[0].message.content, "finished"
+                return response['content'], "finished"
 
         except Exception as e:
             print('************* Retrying *************')
@@ -58,9 +65,9 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def LLM_API(prompt, model=model, api_key=API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    llm = LLMFactory.create(PLATFORM, API_KEY)
     for i in range(max_retries):
         try:
             if chat_history:
@@ -69,35 +76,35 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
             else:
                 messages = [{"role": "user", "content": prompt}]
 
-            response = client.chat.completions.create(
+            response = llm.generate(
                 model=model,
                 messages=messages,
                 temperature=0,
             )
-            return response.choices[0].message.content
+            return response['content']
 
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def LLM_API_async(prompt, model=model, api_key=API_KEY):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    llm = LLMFactory.create(PLATFORM, API_KEY)
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            response = await llm.agenerate(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+            return response['content']
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
@@ -107,7 +114,7 @@ async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
             logging.error('Max retries reached for prompt: ' + prompt)
             return "Error"
 
-    
+
 def get_json_content(response):
     start_idx = response.find("```json")
     if start_idx != -1:
@@ -410,8 +417,10 @@ def add_preface_if_needed(data):
 
 
-def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"):
-    enc = tiktoken.encoding_for_model(model)
+def get_page_tokens(pdf_path, model=model, pdf_parser="PyPDF2"):
+    # TODO: pick the tiktoken encoding dynamically per provider/model
+
+    enc = tiktoken.get_encoding(embedding_model)
     if pdf_parser == "PyPDF2":
         pdf_reader = PyPDF2.PdfReader(pdf_path)
         page_list = []
@@ -609,7 +618,7 @@ async def generate_node_summary(node, model=None):
     Directly return the description, do not include any other text.
     """
 
-    response = await ChatGPT_API_async(model, prompt)
+    response = await LLM_API_async(prompt, model=model)
     return response
 
 
@@ -654,7 +663,7 @@ def generate_doc_description(structure, model=None):
     Directly return the description, do not include any other text.
""" - response = ChatGPT_API(model, prompt) + response = LLM_API(model, prompt) return response diff --git a/run_pageindex.py b/run_pageindex.py index 10702450..61ce579e 100644 --- a/run_pageindex.py +++ b/run_pageindex.py @@ -3,7 +3,7 @@ import json from pageindex import * from pageindex.page_index_md import md_to_tree - +model = os.getenv("MODEL") if __name__ == "__main__": # Set up argument parser parser = argparse.ArgumentParser(description='Process PDF or Markdown document and generate structure') @@ -53,7 +53,7 @@ # Process PDF file # Configure options opt = config( - model=args.model, + model=model, toc_check_page_num=args.toc_check_pages, max_page_num_each_node=args.max_pages_per_node, max_token_num_each_node=args.max_tokens_per_node, @@ -97,7 +97,7 @@ # Create options dict with user args user_opt = { - 'model': args.model, + 'model': model, 'if_add_node_summary': args.if_add_node_summary, 'if_add_doc_description': args.if_add_doc_description, 'if_add_node_text': args.if_add_node_text,