Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 9667d5a

Browse files
authored
Merge branch 'main' into sycai_ai_doc_fix
2 parents f2abe78 + 7600001 commit 9667d5a

2 files changed

Lines changed: 35 additions & 8 deletions

File tree

bigframes/bigquery/_operations/ai.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
31+
str,
3132
series.Series,
3233
pd.Series,
3334
List[Union[str, series.Series, pd.Series]],
@@ -80,7 +81,7 @@ def generate(
8081
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
8182
8283
Args:
83-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
84+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
8485
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
8586
or pandas Series.
8687
connection_id (str, optional):
@@ -179,7 +180,7 @@ def generate_bool(
179180
Name: result, dtype: boolean
180181
181182
Args:
182-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
183+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
183184
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
184185
or pandas Series.
185186
connection_id (str, optional):
@@ -261,7 +262,7 @@ def generate_int(
261262
Name: result, dtype: Int64
262263
263264
Args:
264-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
265+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
265266
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
266267
or pandas Series.
267268
connection_id (str, optional):
@@ -343,7 +344,7 @@ def generate_double(
343344
Name: result, dtype: Float64
344345
345346
Args:
346-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
347+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
347348
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
348349
or pandas Series.
349350
connection_id (str, optional):
@@ -421,7 +422,7 @@ def if_(
421422
dtype: string
422423
423424
Args:
424-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
425+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
425426
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
426427
or pandas Series.
427428
connection_id (str, optional):
@@ -475,7 +476,7 @@ def classify(
475476
[2 rows x 2 columns]
476477
477478
Args:
478-
input (Series | List[str|Series] | Tuple[str|Series, ...]):
479+
input (str | Series | List[str|Series] | Tuple[str|Series, ...]):
479480
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
480481
or pandas Series.
481482
categories (tuple[str, ...] | list[str]):
@@ -531,7 +532,7 @@ def score(
531532
dtype: Float64
532533
533534
Args:
534-
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
535+
prompt (str | Series | List[str|Series] | Tuple[str|Series, ...]):
535536
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
536537
or pandas Series.
537538
connection_id (str, optional):
@@ -563,9 +564,12 @@ def _separate_context_and_series(
563564
Input: ("str1", series1, "str2", "str3", series2)
564565
Output: ["str1", None, "str2", "str3", None], [series1, series2]
565566
"""
566-
if not isinstance(prompt, (list, tuple, series.Series)):
567+
if not isinstance(prompt, (str, list, tuple, series.Series)):
567568
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
568569

570+
if isinstance(prompt, str):
571+
return [None], [series.Series([prompt])]
572+
569573
if isinstance(prompt, series.Series):
570574
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
571575
# Multi-model support

tests/system/small/bigquery/test_ai.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
1517
from packaging import version
1618
import pandas as pd
1719
import pyarrow as pa
@@ -42,6 +44,27 @@ def test_ai_function_pandas_input(session):
4244
)
4345

4446

47+
def test_ai_function_string_input(session):
48+
with mock.patch(
49+
"bigframes.core.global_session.get_global_session"
50+
) as mock_get_session:
51+
mock_get_session.return_value = session
52+
prompt = "Is apple a fruit?"
53+
54+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
55+
56+
assert _contains_no_nulls(result)
57+
assert result.dtype == pd.ArrowDtype(
58+
pa.struct(
59+
(
60+
pa.field("result", pa.bool_()),
61+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
62+
pa.field("status", pa.string()),
63+
)
64+
)
65+
)
66+
67+
4568
def test_ai_function_compile_model_params(session):
4669
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
4770
pytest.skip(

0 commit comments

Comments
 (0)