|
28 | 28 | from bigframes.operations import ai_ops, output_schemas |
29 | 29 |
|
30 | 30 | PROMPT_TYPE = Union[ |
| 31 | + str, |
31 | 32 | series.Series, |
32 | 33 | pd.Series, |
33 | 34 | List[Union[str, series.Series, pd.Series]], |
@@ -80,7 +81,7 @@ def generate( |
80 | 81 | dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow] |
81 | 82 |
|
82 | 83 | Args: |
83 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 84 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
84 | 85 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
85 | 86 | or pandas Series. |
86 | 87 | connection_id (str, optional): |
@@ -179,7 +180,7 @@ def generate_bool( |
179 | 180 | Name: result, dtype: boolean |
180 | 181 |
|
181 | 182 | Args: |
182 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 183 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
183 | 184 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
184 | 185 | or pandas Series. |
185 | 186 | connection_id (str, optional): |
@@ -261,7 +262,7 @@ def generate_int( |
261 | 262 | Name: result, dtype: Int64 |
262 | 263 |
|
263 | 264 | Args: |
264 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 265 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
265 | 266 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
266 | 267 | or pandas Series. |
267 | 268 | connection_id (str, optional): |
@@ -343,7 +344,7 @@ def generate_double( |
343 | 344 | Name: result, dtype: Float64 |
344 | 345 |
|
345 | 346 | Args: |
346 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 347 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
347 | 348 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
348 | 349 | or pandas Series. |
349 | 350 | connection_id (str, optional): |
@@ -421,7 +422,7 @@ def if_( |
421 | 422 | dtype: string |
422 | 423 |
|
423 | 424 | Args: |
424 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 425 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
425 | 426 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
426 | 427 | or pandas Series. |
427 | 428 | connection_id (str, optional): |
@@ -475,7 +476,7 @@ def classify( |
475 | 476 | [2 rows x 2 columns] |
476 | 477 |
|
477 | 478 | Args: |
478 | | - input (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 479 | + input (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
479 | 480 | A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series |
480 | 481 | or pandas Series. |
481 | 482 | categories (tuple[str, ...] | list[str]): |
@@ -531,7 +532,7 @@ def score( |
531 | 532 | dtype: Float64 |
532 | 533 |
|
533 | 534 | Args: |
534 | | - prompt (Series | List[str|Series] | Tuple[str|Series, ...]): |
| 535 | + prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]): |
535 | 536 | A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series |
536 | 537 | or pandas Series. |
537 | 538 | connection_id (str, optional): |
@@ -563,9 +564,12 @@ def _separate_context_and_series( |
563 | 564 | Input: ("str1", series1, "str2", "str3", series2) |
564 | 565 | Output: ["str1", None, "str2", "str3", None], [series1, series2] |
565 | 566 | """ |
566 | | - if not isinstance(prompt, (list, tuple, series.Series)): |
| 567 | + if not isinstance(prompt, (str, list, tuple, series.Series)): |
567 | 568 | raise ValueError(f"Unsupported prompt type: {type(prompt)}") |
568 | 569 |
|
| 570 | + if isinstance(prompt, str): |
| 571 | + return [None], [series.Series([prompt])] |
| 572 | + |
569 | 573 | if isinstance(prompt, series.Series): |
570 | 574 | if prompt.dtype == dtypes.OBJ_REF_DTYPE: |
571 | 575 | # Multi-model support |
|
0 commit comments