Skip to content

Commit 0a2f0fb

Browse files
committed
fix: Address review feedback for embed_stream
Fixes for issues identified by Cursor bugbot:

1. Multiple embedding types IndexError (High):
   - Track text index separately per embedding type
   - Use type_indices dict to correctly map embeddings to texts

2. Image embeddings IndexError (Medium):
   - Remove images parameter from v2 embed_stream (text-only)
   - Document that images should use regular embed()

3. Fallback fails after ijson consumes stream (Medium):
   - Buffer response content before attempting ijson parsing
   - Fallback can now use buffered content if ijson fails

4. OMIT default causes TypeError (Low):
   - Check explicitly for None or OMIT sentinel
   - Handle ellipsis default value correctly

5. Zero/negative batch_size crashes (Low):
   - Add validation: raise ValueError if batch_size < 1
1 parent a5bd20c commit 0a2f0fb

3 files changed

Lines changed: 97 additions & 74 deletions

File tree

src/cohere/base_client.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1193,19 +1193,26 @@ def embed_stream(
11931193
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
11941194
# Process/save embedding immediately
11951195
"""
1196-
if not texts:
1196+
# Validate inputs
1197+
if texts is None or texts is OMIT:
11971198
return
1198-
1199+
if batch_size < 1:
1200+
raise ValueError("batch_size must be at least 1")
1201+
11991202
from .streaming_utils import StreamingEmbedParser
1200-
1203+
12011204
# Process texts in batches
1202-
texts_list = list(texts) if texts else []
1203-
total_embeddings_yielded = 0
1204-
1205+
texts_list = list(texts)
1206+
if not texts_list:
1207+
return
1208+
1209+
# Track text index separately from embedding index (for multiple embedding types)
1210+
global_text_index = 0
1211+
12051212
for batch_start in range(0, len(texts_list), batch_size):
12061213
batch_end = min(batch_start + batch_size, len(texts_list))
12071214
batch_texts = texts_list[batch_start:batch_end]
1208-
1215+
12091216
# Get response for this batch
12101217
response = self._raw_client.embed(
12111218
texts=batch_texts,
@@ -1215,15 +1222,15 @@ def embed_stream(
12151222
truncate=truncate,
12161223
request_options=request_options,
12171224
)
1218-
1225+
12191226
# Parse embeddings from response incrementally
12201227
parser = StreamingEmbedParser(response._response, batch_texts)
1221-
for i, embedding in enumerate(parser.iter_embeddings()):
1222-
# Adjust index for global position
1223-
embedding.index = batch_start + i
1224-
embedding.text = texts_list[embedding.index]
1228+
for embedding in parser.iter_embeddings():
1229+
# The parser tracks text index per embedding type
1230+
# Adjust text reference to use batch_texts mapping
1231+
text_index_in_batch = batch_texts.index(embedding.text) if embedding.text in batch_texts else 0
1232+
embedding.index = batch_start + text_index_in_batch
12251233
yield embedding
1226-
total_embeddings_yielded += len(batch_texts)
12271234

12281235
def rerank(
12291236
self,

src/cohere/streaming_utils.py

Lines changed: 55 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
import io
6+
import json
57
from dataclasses import dataclass
68
from typing import Iterator, List, Optional, Union
79

@@ -21,84 +23,89 @@ class StreamedEmbedding:
2123
embedding: Union[List[float], List[int], str] # float, int8, uint8, binary, ubinary, base64
2224
embedding_type: str
2325
text: Optional[str] = None
24-
26+
2527

2628
class StreamingEmbedParser:
2729
"""
2830
Parses embed responses incrementally using ijson for memory efficiency.
2931
Falls back to regular JSON parsing if ijson is not available.
3032
"""
31-
33+
3234
def __init__(self, response: httpx.Response, batch_texts: Optional[List[str]] = None):
3335
"""
3436
Initialize the streaming parser.
35-
37+
3638
Args:
3739
response: The httpx response object
3840
batch_texts: The original texts for this batch (for correlation)
3941
"""
4042
self.response = response
4143
self.batch_texts = batch_texts or []
4244
self.embeddings_yielded = 0
43-
45+
self._response_content: Optional[bytes] = None
46+
4447
def iter_embeddings(self) -> Iterator[StreamedEmbedding]:
4548
"""
4649
Iterate over embeddings one at a time without loading all into memory.
47-
50+
4851
Yields:
4952
StreamedEmbedding objects as they are parsed from the response
5053
"""
51-
if not IJSON_AVAILABLE:
52-
# Fallback to regular parsing if ijson not available
54+
# Try to buffer the response content first to allow fallback if ijson fails
55+
# This trades some memory for reliability
56+
if self._response_content is None:
57+
try:
58+
content = self.response.content
59+
if isinstance(content, bytes):
60+
self._response_content = content
61+
except Exception:
62+
# Content not available as bytes, will use json() method
63+
pass
64+
65+
if not IJSON_AVAILABLE or self._response_content is None:
66+
# Fallback to regular parsing if ijson not available or no bytes content
5367
yield from self._iter_embeddings_fallback()
5468
return
55-
69+
5670
try:
5771
# Use ijson for memory-efficient parsing
58-
parser = ijson.parse(self.response.iter_bytes(chunk_size=65536))
72+
parser = ijson.parse(io.BytesIO(self._response_content))
5973
yield from self._parse_with_ijson(parser)
6074
except Exception:
6175
# If ijson parsing fails, fallback to regular parsing
6276
yield from self._iter_embeddings_fallback()
6377

6478
def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
6579
"""Parse embeddings using ijson incremental parser."""
66-
current_path: List[str] = []
67-
current_embedding = []
68-
embedding_index = 0
69-
embedding_type = "float"
80+
current_embedding: List[Union[float, int]] = []
7081
response_type = None
71-
in_embeddings = False
72-
82+
# Track index per embedding type to properly map to texts
83+
type_indices: dict[str, int] = {}
84+
7385
for prefix, event, value in parser:
74-
# Track current path
75-
if event == 'map_key':
76-
if current_path and current_path[-1] == 'embeddings':
77-
# This is an embedding type key (float_, int8, etc.)
78-
embedding_type = value.rstrip('_')
79-
8086
# Detect response type
8187
if prefix == 'response_type':
8288
response_type = value
83-
89+
8490
# Handle embeddings based on response type
8591
if response_type == 'embeddings_floats':
8692
# Simple float array format
8793
if prefix.startswith('embeddings.item.item'):
8894
current_embedding.append(value)
8995
elif prefix.startswith('embeddings.item') and event == 'end_array':
9096
# Complete embedding
91-
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
97+
text_index = type_indices.get('float', 0)
98+
text = self.batch_texts[text_index] if text_index < len(self.batch_texts) else None
9299
yield StreamedEmbedding(
93100
index=self.embeddings_yielded,
94-
embedding=current_embedding,
101+
embedding=list(current_embedding),
95102
embedding_type='float',
96103
text=text
97104
)
98105
self.embeddings_yielded += 1
99-
embedding_index += 1
106+
type_indices['float'] = text_index + 1
100107
current_embedding = []
101-
108+
102109
elif response_type == 'embeddings_by_type':
103110
# Complex format with multiple embedding types
104111
# Pattern: embeddings.<type>.item.item
@@ -108,66 +115,73 @@ def _parse_with_ijson(self, parser) -> Iterator[StreamedEmbedding]:
108115
current_embedding.append(value)
109116
elif prefix.startswith(f'embeddings.{emb_type}.item') and event == 'end_array':
110117
# Complete embedding of this type
111-
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
118+
# Use separate index per type to correctly map to texts
119+
text_index = type_indices.get(type_name, 0)
120+
text = self.batch_texts[text_index] if text_index < len(self.batch_texts) else None
112121
yield StreamedEmbedding(
113122
index=self.embeddings_yielded,
114-
embedding=current_embedding,
123+
embedding=list(current_embedding),
115124
embedding_type=type_name,
116125
text=text
117126
)
118127
self.embeddings_yielded += 1
119-
embedding_index += 1
128+
type_indices[type_name] = text_index + 1
120129
current_embedding = []
121-
130+
122131
# Handle base64 embeddings (string format)
123132
if prefix.startswith('embeddings.base64.item') and event == 'string':
124-
text = self.batch_texts[embedding_index] if embedding_index < len(self.batch_texts) else None
133+
text_index = type_indices.get('base64', 0)
134+
text = self.batch_texts[text_index] if text_index < len(self.batch_texts) else None
125135
yield StreamedEmbedding(
126136
index=self.embeddings_yielded,
127137
embedding=value, # base64 string
128138
embedding_type='base64',
129139
text=text
130140
)
131141
self.embeddings_yielded += 1
132-
embedding_index += 1
142+
type_indices['base64'] = text_index + 1
133143

134144
def _iter_embeddings_fallback(self) -> Iterator[StreamedEmbedding]:
135145
"""Fallback method using regular JSON parsing."""
136-
# This still loads the full response but at least provides the same interface
137-
if hasattr(self.response, 'json'):
146+
# Use buffered content if available, otherwise read from response
147+
if self._response_content is not None and isinstance(self._response_content, bytes):
148+
data = json.loads(self._response_content)
149+
elif hasattr(self.response, 'json') and callable(self.response.json):
138150
data = self.response.json()
139151
elif hasattr(self.response, '_response'):
140152
data = self.response._response.json() # type: ignore
141153
else:
142154
raise ValueError("Response object does not have a json() method")
155+
143156
response_type = data.get('response_type', '')
144-
157+
texts = data.get('texts', self.batch_texts)
158+
145159
if response_type == 'embeddings_floats':
146160
embeddings = data.get('embeddings', [])
147-
texts = data.get('texts', [])
148161
for i, embedding in enumerate(embeddings):
149162
yield StreamedEmbedding(
150-
index=i,
163+
index=self.embeddings_yielded,
151164
embedding=embedding,
152165
embedding_type='float',
153166
text=texts[i] if i < len(texts) else None
154167
)
155-
168+
self.embeddings_yielded += 1
169+
156170
elif response_type == 'embeddings_by_type':
157171
embeddings_obj = data.get('embeddings', {})
158-
texts = data.get('texts', [])
159-
172+
160173
# Iterate through each embedding type
161174
for emb_type, embeddings_list in embeddings_obj.items():
162175
type_name = emb_type.rstrip('_')
163176
if isinstance(embeddings_list, list):
164177
for i, embedding in enumerate(embeddings_list):
165178
yield StreamedEmbedding(
166-
index=i,
179+
index=self.embeddings_yielded,
167180
embedding=embedding,
168181
embedding_type=type_name,
169182
text=texts[i] if i < len(texts) else None
170183
)
184+
self.embeddings_yielded += 1
171185

172186

173187
def stream_embed_response(response: httpx.Response, texts: List[str]) -> Iterator[StreamedEmbedding]:

src/cohere/v2/client.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,6 @@ def embed_stream(
495495
model: str,
496496
input_type: EmbedInputType,
497497
texts: typing.Optional[typing.Sequence[str]] = OMIT,
498-
images: typing.Optional[typing.Sequence[str]] = OMIT,
499498
max_tokens: typing.Optional[int] = OMIT,
500499
output_dimension: typing.Optional[int] = OMIT,
501500
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
@@ -505,11 +504,14 @@ def embed_stream(
505504
) -> typing.Iterator[typing.Any]: # Returns Iterator[StreamedEmbedding]
506505
"""
507506
Memory-efficient streaming version of embed that yields embeddings one at a time.
508-
507+
509508
This method processes texts in batches and yields individual embeddings as they are
510509
parsed from the response, without loading all embeddings into memory at once.
511510
Ideal for processing large datasets where memory usage is a concern.
512511
512+
Note: This method only supports text embeddings. For image embeddings, use the
513+
regular embed() method.
514+
513515
Parameters
514516
----------
515517
model : str
@@ -521,9 +523,6 @@ def embed_stream(
521523
texts : typing.Optional[typing.Sequence[str]]
522524
An array of strings for the model to embed. Will be processed in batches.
523525
524-
images : typing.Optional[typing.Sequence[str]]
525-
An array of image data URIs for the model to embed.
526-
527526
max_tokens : typing.Optional[int]
528527
The maximum number of tokens to embed per input.
529528
@@ -556,7 +555,7 @@ def embed_stream(
556555
client_name="YOUR_CLIENT_NAME",
557556
token="YOUR_TOKEN",
558557
)
559-
558+
560559
# Process embeddings one at a time without loading all into memory
561560
for embedding in client.v2.embed_stream(
562561
model="embed-v4.0",
@@ -567,40 +566,43 @@ def embed_stream(
567566
print(f"Embedding {embedding.index}: {embedding.embedding[:5]}...")
568567
# Process/save embedding immediately
569568
"""
570-
if not texts:
569+
# Validate inputs
570+
if texts is None or texts is OMIT:
571571
return
572-
572+
if batch_size < 1:
573+
raise ValueError("batch_size must be at least 1")
574+
573575
from ..streaming_utils import StreamingEmbedParser
574-
576+
575577
# Process texts in batches
576-
texts_list = list(texts) if texts else []
577-
total_embeddings_yielded = 0
578-
578+
texts_list = list(texts)
579+
if not texts_list:
580+
return
581+
579582
for batch_start in range(0, len(texts_list), batch_size):
580583
batch_end = min(batch_start + batch_size, len(texts_list))
581584
batch_texts = texts_list[batch_start:batch_end]
582-
585+
583586
# Get response for this batch
584587
response = self._raw_client.embed(
585588
model=model,
586589
input_type=input_type,
587590
texts=batch_texts,
588-
images=images if batch_start == 0 else None, # Only include images in first batch
589591
max_tokens=max_tokens,
590592
output_dimension=output_dimension,
591593
embedding_types=embedding_types,
592594
truncate=truncate,
593595
request_options=request_options,
594596
)
595-
597+
596598
# Parse embeddings from response incrementally
597599
parser = StreamingEmbedParser(response._response, batch_texts)
598-
for i, embedding in enumerate(parser.iter_embeddings()):
599-
# Adjust index for global position
600-
embedding.index = batch_start + i
601-
embedding.text = texts_list[embedding.index]
600+
for embedding in parser.iter_embeddings():
601+
# The parser tracks text index per embedding type
602+
# Adjust text reference to use batch_texts mapping
603+
text_index_in_batch = batch_texts.index(embedding.text) if embedding.text in batch_texts else 0
604+
embedding.index = batch_start + text_index_in_batch
602605
yield embedding
603-
total_embeddings_yielded += len(batch_texts)
604606

605607
def rerank(
606608
self,

0 commit comments

Comments (0)