diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..f02eabf59 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -18,6 +18,44 @@
 from types import SimpleNamespace as config
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+
+# Novita AI - OpenAI-compatible API support
+NOVITA_API_KEY = os.getenv("NOVITA_API_KEY")
+# The endpoint may be overridden through the environment; empty means default.
+NOVITA_BASE_URL = os.getenv("NOVITA_BASE_URL") or "https://api.novita.ai/openai"
+NOVITA_MODEL = os.getenv("NOVITA_MODEL")
+NOVITA_DEFAULT_MODEL = "deepseek/deepseek-r1"
+DEFAULT_OPENAI_MODEL = "gpt-4o-2024-11-20"
+
+def _use_novita(api_key):
+    """Return True when a call should be routed to Novita AI.
+
+    Novita is used only when NOVITA_API_KEY is set and the caller did not
+    supply a key other than the module default (CHATGPT_API_KEY).
+    """
+    return bool(NOVITA_API_KEY) and api_key == CHATGPT_API_KEY
+
+def get_openai_client(api_key=CHATGPT_API_KEY, async_client=False):
+    """Build an OpenAI SDK client for OpenAI or Novita AI (OpenAI-compatible).
+
+    Args:
+        api_key: Key for OpenAI; not used when the call is routed to Novita.
+        async_client: Return openai.AsyncOpenAI instead of openai.OpenAI.
+    """
+    client_cls = openai.AsyncOpenAI if async_client else openai.OpenAI
+    if _use_novita(api_key):
+        return client_cls(api_key=NOVITA_API_KEY, base_url=NOVITA_BASE_URL)
+    return client_cls(api_key=api_key)
+
+def resolve_chat_model(model, api_key=CHATGPT_API_KEY):
+    """Map the default OpenAI model to a Novita model when Novita is active.
+
+    Any other model name is passed through unchanged, so callers routed
+    to Novita must supply a model name that Novita serves.
+    """
+    if _use_novita(api_key) and model == DEFAULT_OPENAI_MODEL:
+        return NOVITA_MODEL or NOVITA_DEFAULT_MODEL
+    return model
 
 def count_tokens(text, model=None):
     if not text:
@@ -28,7 +66,8 @@ def count_tokens(text, model=None):
 
 def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = get_openai_client(api_key=api_key)
+    resolved_model = resolve_chat_model(model, api_key=api_key)
     for i in range(max_retries):
         try:
             if chat_history:
@@ -38,7 +77,7 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 messages = [{"role": "user", "content": prompt}]
 
             response = client.chat.completions.create(
-                model=model,
+                model=resolved_model,
                 messages=messages,
                 temperature=0,
             )
@@ -60,7 +99,8 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
 
 def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    client = get_openai_client(api_key=api_key)
+    resolved_model = resolve_chat_model(model, api_key=api_key)
     for i in range(max_retries):
         try:
             if chat_history:
@@ -70,7 +110,7 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages = [{"role": "user", "content": prompt}]
 
             response = client.chat.completions.create(
-                model=model,
+                model=resolved_model,
                 messages=messages,
                 temperature=0,
             )
@@ -88,12 +128,13 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 
 async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
     max_retries = 10
+    resolved_model = resolve_chat_model(model, api_key=api_key)
     messages = [{"role": "user", "content": prompt}]
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
+            async with get_openai_client(api_key=api_key, async_client=True) as client:
                 response = await client.chat.completions.create(
-                    model=model,
+                    model=resolved_model,
                     messages=messages,
                     temperature=0,
                 )
@@ -709,4 +750,4 @@ def load(self, user_opt=None) -> config:
 
         self._validate_keys(user_dict)
         merged = {**self._default_dict, **user_dict}
-        return config(**merged)
\ No newline at end of file
+        return config(**merged)