Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 83b042d

Browse files
authored
Merge branch 'main' into shuowei-anywidget-nested-strcut-array
2 parents d2710c2 + 83b83ea commit 83b042d

22 files changed

Lines changed: 372 additions & 167 deletions

File tree

bigframes/bigquery/_operations/ml.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import cast, Mapping, Optional, Union
17+
from typing import cast, List, Mapping, Optional, Union
1818

1919
import bigframes_vendored.constants
2020
import google.cloud.bigquery
@@ -431,3 +431,92 @@ def transform(
431431
return bpd.read_gbq_query(sql)
432432
else:
433433
return session.read_gbq_query(sql)
434+
435+
436+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_text(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    temperature: Optional[float] = None,
    max_output_tokens: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    flatten_json_output: Optional[bool] = None,
    stop_sequences: Optional[List[str]] = None,
    ground_with_google_search: Optional[bool] = None,
    request_type: Optional[str] = None,
) -> dataframe.DataFrame:
    """Generates text using a BigQuery ML model.

    See the `BigQuery ML GENERATE_TEXT function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-text>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator or str):
            The model to use for text generation.
        input_ (Union[bigframes.pandas.DataFrame, str]):
            The DataFrame or query to use for text generation.
        temperature (float, optional):
            A FLOAT64 value in the range ``[0.0, 1.0]`` that controls the
            degree of randomness in token selection. A lower temperature
            works well for prompts that expect a more deterministic and less
            open-ended or creative response, while a higher temperature can
            lead to more diverse or creative results. A temperature of ``0``
            is deterministic, meaning that the highest probability response
            is always selected.
        max_output_tokens (int, optional):
            An INT64 value that sets the maximum number of tokens in the
            generated text.
        top_k (int, optional):
            An INT64 value that changes how the model selects tokens for
            output. A ``top_k`` of ``1`` means the next selected token is
            the most probable among all tokens in the model's vocabulary. A
            ``top_k`` of ``3`` means that the next token is selected from
            among the three most probable tokens by using temperature. The
            default value is ``40``.
        top_p (float, optional):
            A FLOAT64 value that changes how the model selects tokens for
            output. Tokens are selected from most probable to least probable
            until the sum of their probabilities equals the ``top_p`` value.
            For example, if tokens A, B, and C have a probability of 0.3,
            0.2, and 0.1 and the ``top_p`` value is ``0.5``, then the model
            will select either A or B as the next token by using temperature.
            The default value is ``0.95``.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON
            column.
        stop_sequences (List[str], optional):
            An ARRAY<STRING> value that contains the stop sequences for the
            model.
        ground_with_google_search (bool, optional):
            A BOOL value that determines whether to ground the model with
            Google Search.
        request_type (str, optional):
            A STRING value that contains the request type for the model.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text.
    """
    import bigframes.pandas as bpd

    # Resolve the fully-qualified model name and the session (if any)
    # associated with the input, then render the input as a SQL table
    # expression for the ML.GENERATE_TEXT call.
    model_name, session = _get_model_name_and_session(model, input_)
    sql = bigframes.core.sql.ml.generate_text(
        model_name=model_name,
        table=_to_sql(input_),
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        top_k=top_k,
        top_p=top_p,
        flatten_json_output=flatten_json_output,
        stop_sequences=stop_sequences,
        ground_with_google_search=ground_with_google_search,
        request_type=request_type,
    )

    # Run on the input's own session when available; otherwise fall back to
    # the global pandas-compatible session.
    run_query = bpd.read_gbq_query if session is None else session.read_gbq_query
    return run_query(sql)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
create_model,
2424
evaluate,
2525
explain_predict,
26+
generate_text,
2627
global_explain,
2728
predict,
2829
transform,
@@ -35,4 +36,5 @@
3536
"explain_predict",
3637
"global_explain",
3738
"transform",
39+
"generate_text",
3840
]

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def _(
527527
else:
528528
result = apply_window_if_present(result, window)
529529

530-
if op.should_floor_result:
530+
if op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE:
531531
result = sge.Cast(this=sge.func("FLOOR", result), to="INT64")
532532
return result
533533

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141
def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
4242
"""Compiles a BigFrameNode according to the request into SQL using SQLGlot."""
4343

44-
# Generator for unique identifiers.
45-
uid_gen = guid.SequentialUIDGenerator()
4644
output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids)
4745
result_node = nodes.ResultNode(
4846
request.node,
@@ -61,22 +59,16 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
6159
)
6260
if request.sort_rows:
6361
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
64-
result_node = _remap_variables(result_node, uid_gen)
65-
result_node = typing.cast(
66-
nodes.ResultNode, rewrite.defer_selection(result_node)
67-
)
68-
sql = _compile_result_node(result_node, uid_gen)
62+
sql = _compile_result_node(result_node)
6963
return configs.CompileResult(
7064
sql, result_node.schema.to_bigquery(), result_node.order_by
7165
)
7266

7367
ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by
7468
result_node = dataclasses.replace(result_node, order_by=None)
7569
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
70+
sql = _compile_result_node(result_node)
7671

77-
result_node = _remap_variables(result_node, uid_gen)
78-
result_node = typing.cast(nodes.ResultNode, rewrite.defer_selection(result_node))
79-
sql = _compile_result_node(result_node, uid_gen)
8072
# Return the ordering iff no extra columns are needed to define the row order
8173
if ordering is not None:
8274
output_order = (
@@ -97,11 +89,16 @@ def _remap_variables(
9789
return typing.cast(nodes.ResultNode, result_node)
9890

9991

100-
def _compile_result_node(
101-
root: nodes.ResultNode, uid_gen: guid.SequentialUIDGenerator
102-
) -> str:
92+
def _compile_result_node(root: nodes.ResultNode) -> str:
93+
# Create UIDs to standardize variable names and ensure consistent compilation
94+
# of nodes using the same generator.
95+
uid_gen = guid.SequentialUIDGenerator()
96+
root = _remap_variables(root, uid_gen)
97+
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
98+
10399
# Have to bind schema as the final step before compilation.
104100
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
101+
105102
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
106103
(name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
107104
for ref, name in root.output_cols
@@ -127,7 +124,6 @@ def _compile_result_node(
127124
return sqlglot_ir.sql
128125

129126

130-
@functools.lru_cache(maxsize=5000)
131127
def compile_node(
132128
node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator
133129
) -> ir.SQLGlotIR:
@@ -266,10 +262,16 @@ def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlo
266262
assert len(children) >= 1
267263
uid_gen = children[0].uid_gen
268264

269-
output_ids = [id.sql for id in node.output_ids]
265+
# BigQuery `UNION` query takes the column names from the first `SELECT` clause.
266+
default_output_ids = [field.id.sql for field in node.child_nodes[0].fields]
267+
output_aliases = [
268+
(default_output_id, output_id.sql)
269+
for default_output_id, output_id in zip(default_output_ids, node.output_ids)
270+
]
271+
270272
return ir.SQLGlotIR.from_union(
271273
[child.expr for child in children],
272-
output_ids=output_ids,
274+
output_aliases=output_aliases,
273275
uid_gen=uid_gen,
274276
)
275277

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
252252
sg_expr = expr.expr
253253

254254
if from_type == dtypes.STRING_DTYPE:
255-
func_name = "PARSE_JSON_IN_SAFE" if op.safe else "PARSE_JSON"
255+
func_name = "SAFE.PARSE_JSON" if op.safe else "PARSE_JSON"
256256
return sge.func(func_name, sg_expr)
257257
if from_type in (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE):
258258
sg_expr = sge.Cast(this=sg_expr, to="STRING")

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def from_query_string(
170170
cls,
171171
query_string: str,
172172
) -> SQLGlotIR:
173-
"""Builds a SQLGlot expression from a query string"""
173+
"""Builds a SQLGlot expression from a query string. Wrapping the query
174+
in a CTE can avoid the query parsing issue for unsupported syntax in
175+
SQLGlot."""
174176
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
175177
cte_name = sge.to_identifier(
176178
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
@@ -187,7 +189,7 @@ def from_query_string(
187189
def from_union(
188190
cls,
189191
selects: typing.Sequence[sge.Select],
190-
output_ids: typing.Sequence[str],
192+
output_aliases: typing.Sequence[typing.Tuple[str, str]],
191193
uid_gen: guid.SequentialUIDGenerator,
192194
) -> SQLGlotIR:
193195
"""Builds a SQLGlot expression by unioning of multiple select expressions."""
@@ -196,46 +198,36 @@ def from_union(
196198
), f"At least two select expressions must be provided, but got {selects}."
197199

198200
existing_ctes: list[sge.CTE] = []
199-
union_selects: list[sge.Expression] = []
201+
union_selects: list[sge.Select] = []
200202
for select in selects:
201203
assert isinstance(
202204
select, sge.Select
203205
), f"All provided expressions must be of type sge.Select, but got {type(select)}"
204206

205207
select_expr = select.copy()
206208
select_expr, select_ctes = _pop_query_ctes(select_expr)
207-
existing_ctes = [*existing_ctes, *select_ctes]
208-
209-
new_cte_name = sge.to_identifier(
210-
next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
211-
)
212-
new_cte = sge.CTE(
213-
this=select_expr,
214-
alias=new_cte_name,
209+
existing_ctes = _merge_ctes(existing_ctes, select_ctes)
210+
union_selects.append(select_expr)
211+
212+
union_expr: sge.Query = union_selects[0].subquery()
213+
for select in union_selects[1:]:
214+
union_expr = sge.Union(
215+
this=union_expr,
216+
expression=select.subquery(),
217+
distinct=False,
218+
copy=False,
215219
)
216-
existing_ctes = [*existing_ctes, new_cte]
217220

218-
selections = [
219-
sge.Alias(
220-
this=sge.to_identifier(expr.alias_or_name, quoted=cls.quoted),
221-
alias=sge.to_identifier(output_id, quoted=cls.quoted),
222-
)
223-
for expr, output_id in zip(select_expr.expressions, output_ids)
224-
]
225-
union_selects.append(
226-
sge.Select().select(*selections).from_(sge.Table(this=new_cte_name))
221+
selections = [
222+
sge.Alias(
223+
this=sge.to_identifier(old_name, quoted=cls.quoted),
224+
alias=sge.to_identifier(new_name, quoted=cls.quoted),
227225
)
228-
229-
union_expr = typing.cast(
230-
sge.Select,
231-
functools.reduce(
232-
lambda x, y: sge.Union(
233-
this=x, expression=y, distinct=False, copy=False
234-
),
235-
union_selects,
236-
),
226+
for old_name, new_name in output_aliases
227+
]
228+
final_select_expr = (
229+
sge.Select().select(*selections).from_(union_expr.subquery())
237230
)
238-
final_select_expr = sge.Select().select(sge.Star()).from_(union_expr.subquery())
239231
final_select_expr = _set_query_ctes(final_select_expr, existing_ctes)
240232
return cls(expr=final_select_expr, uid_gen=uid_gen)
241233

@@ -345,7 +337,7 @@ def join(
345337

346338
left_select, left_ctes = _pop_query_ctes(left_select)
347339
right_select, right_ctes = _pop_query_ctes(right_select)
348-
merged_ctes = [*left_ctes, *right_ctes]
340+
merged_ctes = _merge_ctes(left_ctes, right_ctes)
349341

350342
join_on = _and(
351343
tuple(
@@ -382,7 +374,7 @@ def isin_join(
382374

383375
left_select, left_ctes = _pop_query_ctes(left_select)
384376
right_select, right_ctes = _pop_query_ctes(right_select)
385-
merged_ctes = [*left_ctes, *right_ctes]
377+
merged_ctes = _merge_ctes(left_ctes, right_ctes)
386378

387379
left_condition = typed_expr.TypedExpr(
388380
sge.Column(this=conditions[0].expr, table=left_cte_name),
@@ -835,6 +827,15 @@ def _set_query_ctes(
835827
return new_expr
836828

837829

830+
def _merge_ctes(ctes1: list[sge.CTE], ctes2: list[sge.CTE]) -> list[sge.CTE]:
    """Concatenates two CTE lists, keeping only the first CTE for each alias.

    CTEs from ``ctes1`` take precedence over same-aliased CTEs in ``ctes2``,
    and the original encounter order is preserved.
    """
    merged: dict = {}
    for cte in (*ctes1, *ctes2):
        # setdefault keeps the first occurrence, matching "ctes1 wins" and
        # de-duplicating repeated aliases within either list.
        merged.setdefault(cte.alias, cte)
    return list(merged.values())
837+
838+
838839
def _pop_query_ctes(
839840
expr: sge.Select,
840841
) -> tuple[sge.Select, list[sge.CTE]]:

0 commit comments

Comments
 (0)