22
33from __future__ import annotations
44
5+ import io
6+ import json
57from dataclasses import dataclass
68from typing import Iterator , List , Optional , Union
79
@@ -21,84 +23,89 @@ class StreamedEmbedding:
2123 embedding : Union [List [float ], List [int ], str ] # float, int8, uint8, binary, ubinary, base64
2224 embedding_type : str
2325 text : Optional [str ] = None
24-
26+
2527
2628class StreamingEmbedParser :
2729 """
2830 Parses embed responses incrementally using ijson for memory efficiency.
2931 Falls back to regular JSON parsing if ijson is not available.
3032 """
31-
33+
3234 def __init__ (self , response : httpx .Response , batch_texts : Optional [List [str ]] = None ):
3335 """
3436 Initialize the streaming parser.
35-
37+
3638 Args:
3739 response: The httpx response object
3840 batch_texts: The original texts for this batch (for correlation)
3941 """
4042 self .response = response
4143 self .batch_texts = batch_texts or []
4244 self .embeddings_yielded = 0
43-
45+ self ._response_content : Optional [bytes ] = None
46+
4447 def iter_embeddings (self ) -> Iterator [StreamedEmbedding ]:
4548 """
4649 Iterate over embeddings one at a time without loading all into memory.
47-
50+
4851 Yields:
4952 StreamedEmbedding objects as they are parsed from the response
5053 """
51- if not IJSON_AVAILABLE :
52- # Fallback to regular parsing if ijson not available
54+ # Try to buffer the response content first to allow fallback if ijson fails
55+ # This trades some memory for reliability
56+ if self ._response_content is None :
57+ try :
58+ content = self .response .content
59+ if isinstance (content , bytes ):
60+ self ._response_content = content
61+ except Exception :
62+ # Content not available as bytes, will use json() method
63+ pass
64+
65+ if not IJSON_AVAILABLE or self ._response_content is None :
66+ # Fallback to regular parsing if ijson not available or no bytes content
5367 yield from self ._iter_embeddings_fallback ()
5468 return
55-
69+
5670 try :
5771 # Use ijson for memory-efficient parsing
58- parser = ijson .parse (self . response . iter_bytes ( chunk_size = 65536 ))
72+ parser = ijson .parse (io . BytesIO ( self . _response_content ))
5973 yield from self ._parse_with_ijson (parser )
6074 except Exception :
6175 # If ijson parsing fails, fallback to regular parsing
6276 yield from self ._iter_embeddings_fallback ()
6377
6478 def _parse_with_ijson (self , parser ) -> Iterator [StreamedEmbedding ]:
6579 """Parse embeddings using ijson incremental parser."""
66- current_path : List [str ] = []
67- current_embedding = []
68- embedding_index = 0
69- embedding_type = "float"
80+ current_embedding : List [Union [float , int ]] = []
7081 response_type = None
71- in_embeddings = False
72-
82+ # Track index per embedding type to properly map to texts
83+ type_indices : dict [str , int ] = {}
84+
7385 for prefix , event , value in parser :
74- # Track current path
75- if event == 'map_key' :
76- if current_path and current_path [- 1 ] == 'embeddings' :
77- # This is an embedding type key (float_, int8, etc.)
78- embedding_type = value .rstrip ('_' )
79-
8086 # Detect response type
8187 if prefix == 'response_type' :
8288 response_type = value
83-
89+
8490 # Handle embeddings based on response type
8591 if response_type == 'embeddings_floats' :
8692 # Simple float array format
8793 if prefix .startswith ('embeddings.item.item' ):
8894 current_embedding .append (value )
8995 elif prefix .startswith ('embeddings.item' ) and event == 'end_array' :
9096 # Complete embedding
91- text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
97+ text_index = type_indices .get ('float' , 0 )
98+ text = self .batch_texts [text_index ] if text_index < len (self .batch_texts ) else None
9299 yield StreamedEmbedding (
93100 index = self .embeddings_yielded ,
94- embedding = current_embedding ,
101+ embedding = list ( current_embedding ) ,
95102 embedding_type = 'float' ,
96103 text = text
97104 )
98105 self .embeddings_yielded += 1
99- embedding_index += 1
106+ type_indices [ 'float' ] = text_index + 1
100107 current_embedding = []
101-
108+
102109 elif response_type == 'embeddings_by_type' :
103110 # Complex format with multiple embedding types
104111 # Pattern: embeddings.<type>.item.item
@@ -108,66 +115,73 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
108115 current_embedding .append (value )
109116 elif prefix .startswith (f'embeddings.{ emb_type } .item' ) and event == 'end_array' :
110117 # Complete embedding of this type
111- text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
118+ # Use separate index per type to correctly map to texts
119+ text_index = type_indices .get (type_name , 0 )
120+ text = self .batch_texts [text_index ] if text_index < len (self .batch_texts ) else None
112121 yield StreamedEmbedding (
113122 index = self .embeddings_yielded ,
114- embedding = current_embedding ,
123+ embedding = list ( current_embedding ) ,
115124 embedding_type = type_name ,
116125 text = text
117126 )
118127 self .embeddings_yielded += 1
119- embedding_index += 1
128+ type_indices [ type_name ] = text_index + 1
120129 current_embedding = []
121-
130+
122131 # Handle base64 embeddings (string format)
123132 if prefix .startswith ('embeddings.base64.item' ) and event == 'string' :
124- text = self .batch_texts [embedding_index ] if embedding_index < len (self .batch_texts ) else None
133+ text_index = type_indices .get ('base64' , 0 )
134+ text = self .batch_texts [text_index ] if text_index < len (self .batch_texts ) else None
125135 yield StreamedEmbedding (
126136 index = self .embeddings_yielded ,
127137 embedding = value , # base64 string
128138 embedding_type = 'base64' ,
129139 text = text
130140 )
131141 self .embeddings_yielded += 1
132- embedding_index += 1
142+ type_indices [ 'base64' ] = text_index + 1
133143
134144 def _iter_embeddings_fallback (self ) -> Iterator [StreamedEmbedding ]:
135145 """Fallback method using regular JSON parsing."""
136- # This still loads the full response but at least provides the same interface
137- if hasattr (self .response , 'json' ):
146+ # Use buffered content if available, otherwise read from response
147+ if self ._response_content is not None and isinstance (self ._response_content , bytes ):
148+ data = json .loads (self ._response_content )
149+ elif hasattr (self .response , 'json' ) and callable (self .response .json ):
138150 data = self .response .json ()
139151 elif hasattr (self .response , '_response' ):
140152 data = self .response ._response .json () # type: ignore
141153 else :
142154 raise ValueError ("Response object does not have a json() method" )
155+
143156 response_type = data .get ('response_type' , '' )
144-
157+ texts = data .get ('texts' , self .batch_texts )
158+
145159 if response_type == 'embeddings_floats' :
146160 embeddings = data .get ('embeddings' , [])
147- texts = data .get ('texts' , [])
148161 for i , embedding in enumerate (embeddings ):
149162 yield StreamedEmbedding (
150- index = i ,
163+ index = self . embeddings_yielded ,
151164 embedding = embedding ,
152165 embedding_type = 'float' ,
153166 text = texts [i ] if i < len (texts ) else None
154167 )
155-
168+ self .embeddings_yielded += 1
169+
156170 elif response_type == 'embeddings_by_type' :
157171 embeddings_obj = data .get ('embeddings' , {})
158- texts = data .get ('texts' , [])
159-
172+
160173 # Iterate through each embedding type
161174 for emb_type , embeddings_list in embeddings_obj .items ():
162175 type_name = emb_type .rstrip ('_' )
163176 if isinstance (embeddings_list , list ):
164177 for i , embedding in enumerate (embeddings_list ):
165178 yield StreamedEmbedding (
166- index = i ,
179+ index = self . embeddings_yielded ,
167180 embedding = embedding ,
168181 embedding_type = type_name ,
169182 text = texts [i ] if i < len (texts ) else None
170183 )
184+ self .embeddings_yielded += 1
171185
172186
173187def stream_embed_response (response : httpx .Response , texts : List [str ]) -> Iterator [StreamedEmbedding ]:
0 commit comments