Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .formatter.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Used by "mix format"
[
import_deps: [:ecto, :phoenix, :phoenix_live_view],
plugins: [Phoenix.LiveView.HTMLFormatter],
inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}", "pages/cookbook/**/*.{ex,exs}"]
import_deps: [:flint],
inputs: [
"{mix,.formatter}.exs",
"{config,lib,test}/**/*.{ex,exs}",
"pages/cookbook/**/*.{ex,exs}"
]
]
105 changes: 77 additions & 28 deletions lib/instructor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ defmodule Instructor do
* `:mode` - The mode to use when parsing the response, :tools, :json, :md_json (defaults to `:tools`), generally speaking you don't need to change this unless you are not using OpenAI.
* `:max_retries` - The maximum number of times to retry the LLM call if it fails, or does not pass validations.
(defaults to `0`)
* `:after_request` - A callback function that will take in the current `%Ecto.Changeset{}` along with the raw `%Req.Response{}` returned
from the adapter. This can be helpful with tracking metadata or usage information, such as rate-limit headers and token usage.

## Examples

Expand Down Expand Up @@ -130,22 +132,38 @@ defmodule Instructor do

case {response_model, is_stream} do
{{:partial, {:array, response_model}}, true} ->
do_streaming_partial_array_chat_completion(response_model, params, config)
do_streaming_partial_array_chat_completion(
response_model.__struct__() |> Ecto.Changeset.change(),
params,
config
)

{{:partial, response_model}, true} ->
do_streaming_partial_chat_completion(response_model, params, config)
do_streaming_partial_chat_completion(
response_model.__struct__() |> Ecto.Changeset.change(),
params,
config
)

{{:array, response_model}, true} ->
do_streaming_array_chat_completion(response_model, params, config)
do_streaming_array_chat_completion(
response_model.__struct__() |> Ecto.Changeset.change(),
params,
config
)

{{:array, response_model}, false} ->
params = Keyword.put(params, :stream, true)

do_streaming_array_chat_completion(response_model, params, config)
do_streaming_array_chat_completion(
response_model.__struct__() |> Ecto.Changeset.change(),
params,
config
)
|> Enum.to_list()

{response_model, false} ->
do_chat_completion(response_model, params, config)
do_chat_completion(response_model.__struct__() |> Ecto.Changeset.change(), params, config)

{_, true} ->
raise """
Expand Down Expand Up @@ -234,6 +252,38 @@ defmodule Instructor do
|> Ecto.Changeset.validate_required(fields)
end

def cast_all(%Ecto.Changeset{data: data} = changeset, params) do
response_model = data.__struct__
fields = response_model.__schema__(:fields) |> MapSet.new()
embedded_fields = response_model.__schema__(:embeds) |> MapSet.new()
associated_fields = response_model.__schema__(:associations) |> MapSet.new()

fields =
fields
|> MapSet.difference(embedded_fields)
|> MapSet.difference(associated_fields)

changeset =
changeset
|> Ecto.Changeset.cast(params, fields |> MapSet.to_list())

changeset =
for field <- embedded_fields, reduce: changeset do
changeset ->
changeset
|> Ecto.Changeset.cast_embed(field, with: &cast_all/2)
end

changeset =
for field <- associated_fields, reduce: changeset do
changeset ->
changeset
|> Ecto.Changeset.cast_assoc(field, with: &cast_all/2)
end

changeset
end

def cast_all(schema, params) do
response_model = schema.__struct__
fields = response_model.__schema__(:fields) |> MapSet.new()
Expand Down Expand Up @@ -415,27 +465,29 @@ defmodule Instructor do
end)
end

defp do_chat_completion(response_model, params, config) do
defp do_chat_completion(
%Ecto.Changeset{data: %{__struct__: response_model}} = changeset,
params,
config
) do
after_request = Keyword.get(params, :after_request, fn c, _r -> c end)
validation_context = Keyword.get(params, :validation_context, %{})
max_retries = Keyword.get(params, :max_retries)
mode = Keyword.get(params, :mode, :tools)
params = params_for_mode(mode, response_model, params)

model =
if is_ecto_schema(response_model) do
response_model.__struct__()
else
{%{}, response_model}
end

with {:ok, raw_response, params} <- do_adapter_chat_completion(params, config),
{%Ecto.Changeset{valid?: true} = changeset, raw_response} <-
{cast_all(model, params), raw_response},
{%Ecto.Changeset{valid?: true} = changeset, _raw_response} <-
{call_validate(response_model, changeset, validation_context), raw_response} do
with {:ok, {_raw_request, _raw_response} = result, params} <-
do_adapter_chat_completion(Keyword.drop(params, [:after_request]), config),
changeset = after_request.(changeset, result),
{%Ecto.Changeset{valid?: true} = changeset, result} <-
{cast_all(changeset, params), result},
{%Ecto.Changeset{valid?: true} = changeset, _result} <-
{call_validate(response_model, changeset, validation_context), result} do
{:ok, changeset |> Ecto.Changeset.apply_changes()}
else
{%Ecto.Changeset{} = changeset, raw_response} ->
{%Ecto.Changeset{} = changeset, {_raw_request, raw_response} = result} ->
changeset = after_request.(changeset, result)

if max_retries > 0 do
errors = Instructor.ErrorFormatter.format_errors(changeset)

Expand All @@ -461,7 +513,7 @@ defmodule Instructor do
]
end)

do_chat_completion(response_model, params, config)
do_chat_completion(changeset, params, config)
else
{:error, changeset}
end
Expand Down Expand Up @@ -535,10 +587,7 @@ defmodule Instructor do
:json ->
[sys_message | messages]

:json_schema ->
messages

:tools ->
m when m in [:json_schema, :tools, :structured_output] ->
messages
end
end)
Expand All @@ -553,7 +602,7 @@ defmodule Instructor do
type: "json_object"
})

:json_schema ->
m when m in [:json_schema, :structured_output] ->
params
|> Keyword.put(:response_format, %{
type: "json_schema",
Expand Down Expand Up @@ -589,12 +638,12 @@ defmodule Instructor do
not is_ecto_schema(response_model) ->
changeset

function_exported?(response_model, :validate_changeset, 1) ->
response_model.validate_changeset(changeset)

function_exported?(response_model, :validate_changeset, 2) ->
response_model.validate_changeset(changeset, context)

function_exported?(response_model, :validate_changeset, 1) ->
response_model.validate_changeset(changeset)

true ->
changeset
end
Expand Down
15 changes: 7 additions & 8 deletions lib/instructor/adapters/anthropic.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ defmodule Instructor.Adapters.Anthropic do
Anthropic adapter for Instructor.
"""
@behaviour Instructor.Adapter
@default_config [
api_url: "https://api.anthropic.com/",
http_options: [receive_timeout: 60_000]
]

alias Instructor.SSEStreamParser

Expand Down Expand Up @@ -131,14 +135,9 @@ defmodule Instructor.Adapters.Anthropic do
defp api_key(config), do: Keyword.fetch!(config, :api_key)
defp http_options(config), do: Keyword.fetch!(config, :http_options)

defp config(nil), do: config(Application.get_env(:instructor, :anthropic, []))

defp config(base_config) do
default_config = [
api_url: "https://api.anthropic.com/",
http_options: [receive_timeout: 60_000]
]

Keyword.merge(default_config, base_config)
@default_config
|> Keyword.merge(Application.get_env(:anthropic, :openai, []))
|> Keyword.merge(base_config || [])
end
end
44 changes: 19 additions & 25 deletions lib/instructor/adapters/openai.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@ defmodule Instructor.Adapters.OpenAI do
Documentation for `Instructor.Adapters.OpenAI`.
"""
@behaviour Instructor.Adapter
@supported_modes [:tools, :json, :md_json, :json_schema]
@supported_modes [:tools, :json, :md_json, :json_schema, :structured_output]

alias Instructor.JSONSchema
alias Instructor.SSEStreamParser

@default_config [
api_url: "https://api.openai.com",
api_path: "/v1/chat/completions",
auth_mode: :bearer,
http_options: [receive_timeout: 60_000]
]

@impl true
def chat_completion(params, user_config \\ nil) do
config = config(user_config)
Expand All @@ -24,10 +31,11 @@ defmodule Instructor.Adapters.OpenAI do
raise "Unsupported OpenAI mode #{mode}. Supported modes: #{inspect(@supported_modes)}"
end

# TODO: Only do this when `strict: true`
params =
case params do
# OpenAI's json_schema mode doesn't support format or pattern attributes
%{"response_format" => %{"json_schema" => %{"schema" => _schema}}} ->
%{response_format: %{json_schema: %{schema: _schema}}} ->
update_in(params, [:response_format, :json_schema, :schema], fn schema ->
JSONSchema.traverse_and_update(schema, fn
%{"type" => _} = x when is_map_key(x, "format") or is_map_key(x, "pattern") ->
Expand Down Expand Up @@ -125,12 +133,12 @@ defmodule Instructor.Adapters.OpenAI do
defp do_chat_completion(mode, params, config) do
options = Keyword.merge(http_options(config), [auth_header(config), json: params])

with {:ok, %Req.Response{status: 200, body: body} = response} <-
Req.post(url(config), options),
with {%Req.Request{}, %Req.Response{status: 200, body: body}} = result <-
Req.run(url(config), [{:method, :post} | options]),
{:ok, content} <- parse_response_for_mode(mode, body) do
{:ok, response, content}
{:ok, result, content}
else
{:ok, %Req.Response{status: status, body: body}} ->
{%Req.Request{}, %Req.Response{status: status, body: body}} ->
{:error, "Unexpected HTTP response code: #{status}\n#{inspect(body)}"}

e ->
Expand All @@ -145,15 +153,8 @@ defmodule Instructor.Adapters.OpenAI do
}),
do: Jason.decode(args)

defp parse_response_for_mode(:md_json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:json, %{"choices" => [%{"message" => %{"content" => content}}]}),
do: Jason.decode(content)

defp parse_response_for_mode(:json_schema, %{
"choices" => [%{"message" => %{"content" => content}}]
}),
defp parse_response_for_mode(mode, %{"choices" => [%{"message" => %{"content" => content}}]})
when mode in [:md_json, :json, :json_schema, :structured_output],
do: Jason.decode(content)

defp parse_response_for_mode(mode, response) do
Expand Down Expand Up @@ -213,16 +214,9 @@ defmodule Instructor.Adapters.OpenAI do

defp http_options(config), do: Keyword.fetch!(config, :http_options)

defp config(nil), do: config(Application.get_env(:instructor, :openai, []))

defp config(base_config) do
default_config = [
api_url: "https://api.openai.com",
api_path: "/v1/chat/completions",
auth_mode: :bearer,
http_options: [receive_timeout: 60_000]
]

Keyword.merge(default_config, base_config)
@default_config
|> Keyword.merge(Application.get_env(:instructor, :openai, []))
|> Keyword.merge(base_config || [])
end
end
Loading