11# TODO
22# - [ ] Support text streaming
33# - [ ] Support file streaming
4+ import copy
45import hashlib
56import os
67import tempfile
7- from dataclasses import dataclass
88from functools import cached_property
99from pathlib import Path
1010from typing import (
2424 cast ,
2525 overload ,
2626)
27- from urllib .parse import urlparse
2827
2928import httpx
3029
@@ -61,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6160 return True
6261
6362
64- def _has_iterator_output_type (openapi_schema : dict ) -> bool :
65- """
66- Returns true if the model output type is an iterator (non-concatenate).
67- """
68- output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
69- return (
70- output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
71- )
72-
73-
74- def _download_file (url : str ) -> Path :
75- """
76- Download a file from URL to a temporary location and return the Path.
77- """
78- parsed_url = urlparse (url )
79- filename = os .path .basename (parsed_url .path )
80-
81- if not filename or "." not in filename :
82- filename = "download"
83-
84- _ , ext = os .path .splitext (filename )
85- with tempfile .NamedTemporaryFile (delete = False , suffix = ext ) as temp_file :
86- with httpx .stream ("GET" , url ) as response :
87- response .raise_for_status ()
88- for chunk in response .iter_bytes ():
89- temp_file .write (chunk )
90-
91- return Path (temp_file .name )
92-
93-
9463def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
9564 """
9665 Process a single item from an iterator output based on schema.
@@ -177,6 +146,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # py
177146 return output
178147
179148
149+ def _dereference_schema (schema : dict [str , Any ]) -> dict [str , Any ]:
150+ """
151+ Performs basic dereferencing on an OpenAPI schema based on the current schemas generated
152+ by Replicate. This code assumes that:
153+
154+ 1) References will always point to a field within #/components/schemas and will error
155+ if the reference is more deeply nested.
156+ 2) That the references when used can be discarded.
157+
158+ Should something more in-depth be required we could consider using the jsonref package.
159+ """
160+ dereferenced = copy .deepcopy (schema )
161+ schemas = dereferenced .get ("components" , {}).get ("schemas" , {})
162+ dereferenced_refs = set ()
163+
164+ def _resolve_ref (obj : Any ) -> Any :
165+ if isinstance (obj , dict ):
166+ if "$ref" in obj :
167+ ref_path = obj ["$ref" ]
168+ if ref_path .startswith ("#/components/schemas/" ):
169+ parts = ref_path .replace ("#/components/schemas/" , "" ).split ("/" , 2 )
170+
171+ if len (parts ) > 1 :
172+ raise NotImplementedError (
173+ f"Unexpected nested $ref found in schema: { ref_path } "
174+ )
175+
176+ (schema_name ,) = parts
177+ if schema_name in schemas :
178+ dereferenced_refs .add (schema_name )
179+ return _resolve_ref (schemas [schema_name ])
180+ else :
181+ return obj
182+ else :
183+ return obj
184+ else :
185+ return {key : _resolve_ref (value ) for key , value in obj .items ()}
186+ elif isinstance (obj , list ):
187+ return [_resolve_ref (item ) for item in obj ]
188+ else :
189+ return obj
190+
191+ result = _resolve_ref (dereferenced )
192+
193+ # Filter out any references that have now been referenced.
194+ result ["components" ]["schemas" ] = {
195+ k : v
196+ for k , v in result ["components" ]["schemas" ].items ()
197+ if k not in dereferenced_refs
198+ }
199+
200+ return result
201+
202+
180203T = TypeVar ("T" )
181204
182205
@@ -302,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
302325 __call__ : Callable [Input , Output ]
303326
304327
305- @dataclass
306328class Run [O ]:
307329 """
308330 Represents a running prediction with access to the underlying schema.
@@ -361,13 +383,13 @@ def logs(self) -> Optional[str]:
361383 return self ._prediction .logs
362384
363385
364- @dataclass
365386class Function (Generic [Input , Output ]):
366387 """
367388 A wrapper for a Replicate model that can be called as a function.
368389 """
369390
370391 _ref : str
392+ _streaming : bool
371393
372394 def __init__ (self , ref : str , * , streaming : bool ) -> None :
373395 self ._ref = ref
@@ -405,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
405427 )
406428
407429 return Run (
408- prediction = prediction , schema = self .openapi_schema , streaming = self ._streaming
430+ prediction = prediction ,
431+ schema = self .openapi_schema (),
432+ streaming = self ._streaming ,
409433 )
410434
411435 @property
@@ -415,20 +439,28 @@ def default_example(self) -> Optional[dict[str, Any]]:
415439 """
416440 raise NotImplementedError ("This property has not yet been implemented" )
417441
418- @cached_property
419442 def openapi_schema (self ) -> dict [str , Any ]:
420443 """
421444 Get the OpenAPI schema for this model version.
422445 """
423- latest_version = self ._model .latest_version
424- if latest_version is None :
425- msg = f"Model { self ._model .owner } /{ self ._model .name } has no latest version"
446+ return self ._openapi_schema
447+
448+ @cached_property
449+ def _openapi_schema (self ) -> dict [str , Any ]:
450+ _ , _ , model_version = self ._parsed_ref
451+ model = self ._model
452+
453+ version = (
454+ model .versions .get (model_version ) if model_version else model .latest_version
455+ )
456+ if version is None :
457+ msg = f"Model { self ._model .owner } /{ self ._model .name } has no version"
426458 raise ValueError (msg )
427459
428- schema = latest_version .openapi_schema
429- if cog_version := latest_version .cog_version :
460+ schema = version .openapi_schema
461+ if cog_version := version .cog_version :
430462 schema = make_schema_backwards_compatible (schema , cog_version )
431- return schema
463+ return _dereference_schema ( schema )
432464
433465 def _client (self ) -> Client :
434466 return Client ()
@@ -469,7 +501,6 @@ def _version(self) -> Version | None:
469501 return version
470502
471503
472- @dataclass
473504class AsyncRun [O ]:
474505 """
475506 Represents a running prediction with access to its version (async version).
@@ -528,21 +559,25 @@ async def logs(self) -> Optional[str]:
528559 return self ._prediction .logs
529560
530561
531- @dataclass
532562class AsyncFunction (Generic [Input , Output ]):
533563 """
534564 An async wrapper for a Replicate model that can be called as a function.
535565 """
536566
537- function_ref : str
538- streaming : bool
567+ _ref : str
568+ _streaming : bool
569+ _openapi_schema : dict [str , Any ] | None = None
570+
571+ def __init__ (self , ref : str , * , streaming : bool ) -> None :
572+ self ._ref = ref
573+ self ._streaming = streaming
539574
540575 def _client (self ) -> Client :
541576 return Client ()
542577
543578 @cached_property
544579 def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
545- return ModelVersionIdentifier .parse (self .function_ref )
580+ return ModelVersionIdentifier .parse (self ._ref )
546581
547582 async def _model (self ) -> Model :
548583 client = self ._client ()
@@ -607,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
607642 return AsyncRun (
608643 prediction = prediction ,
609644 schema = await self .openapi_schema (),
610- streaming = self .streaming ,
645+ streaming = self ._streaming ,
611646 )
612647
613648 @property
@@ -621,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
621656 """
622657 Get the OpenAPI schema for this model version asynchronously.
623658 """
624- model = await self ._model ()
625- latest_version = model .latest_version
626- if latest_version is None :
627- msg = f"Model { model .owner } /{ model .name } has no latest version"
628- raise ValueError (msg )
659+ if not self ._openapi_schema :
660+ _ , _ , model_version = self ._parsed_ref
629661
630- schema = latest_version .openapi_schema
631- if cog_version := latest_version .cog_version :
632- schema = make_schema_backwards_compatible (schema , cog_version )
633- return schema
662+ model = await self ._model ()
663+ if model_version :
664+ version = await model .versions .async_get (model_version )
665+ else :
666+ version = model .latest_version
667+
668+ if version is None :
669+ msg = f"Model { model .owner } /{ model .name } has no version"
670+ raise ValueError (msg )
671+
672+ schema = version .openapi_schema
673+ if cog_version := version .cog_version :
674+ schema = make_schema_backwards_compatible (schema , cog_version )
675+
676+ self ._openapi_schema = _dereference_schema (schema )
677+
678+ return self ._openapi_schema
634679
635680
636681@overload
0 commit comments