diff --git a/superpipe/clients.py b/superpipe/clients.py index daa23b1..af2c23e 100644 --- a/superpipe/clients.py +++ b/superpipe/clients.py @@ -1,6 +1,7 @@ import requests import os from openai import OpenAI +from openai import AzureOpenAI from anthropic import Anthropic from superpipe.models import * @@ -10,10 +11,17 @@ openrouter_models = [] -def init_openai(api_key, base_url=None): - openai_client = OpenAI(api_key=api_key, base_url=base_url) - client_for_model[gpt35] = openai_client - client_for_model[gpt4] = openai_client +def init_openai(api_key, base_url=None, api_version=None): + if base_url and api_version: + openai_client = AzureOpenAI(api_key=api_key, + azure_endpoint=base_url, + api_version=api_version) + client_for_model[gpt35] = openai_client + client_for_model[gpt4] = openai_client + else: + openai_client = OpenAI(api_key=api_key, base_url=base_url) + client_for_model[gpt35] = openai_client + client_for_model[gpt4] = openai_client def init_anthropic(api_key): @@ -45,8 +53,13 @@ def get_client(model): if client_for_model.get(gpt35) is None or \ client_for_model.get(gpt4) is None: api_key = os.getenv("OPENAI_API_KEY") + base_url, api_version = None, None + if client_for_model.get("OPENAI_API_BASE") is None: + base_url = os.getenv("OPENAI_API_BASE") + if client_for_model.get("OPENAI_API_VERSION") is None: + api_version = os.getenv("OPENAI_API_VERSION") if api_key is not None: - init_openai(api_key) + init_openai(api_key, base_url, api_version) if client_for_model.get(claude3_haiku) is None or \ client_for_model.get(claude3_sonnet) is None or \ client_for_model.get(claude3_opus) is None: diff --git a/superpipe/llm.py b/superpipe/llm.py index 5c61bc9..651ea61 100644 --- a/superpipe/llm.py +++ b/superpipe/llm.py @@ -193,7 +193,11 @@ def get_structured_llm_response_openai( args: CompletionCreateParamsNonStreaming = {}) -> StructuredLLMResponse: system = "You are a helpful assistant designed to output JSON." 
updated_args = {**args, "response_format": {"type": "json_object"}} + response = get_llm_response_openai(prompt, model, updated_args, system) + if response.error: # models before 1106 do not support response_format param + response = get_llm_response_openai(prompt, model, args, system) + return StructuredLLMResponse( input_tokens=response.input_tokens, output_tokens=response.output_tokens,