from __future__ import annotations

from typing import cast, Iterable, Optional, Tuple

from google.cloud import bigquery
import google.cloud.bigquery.table

import bigframes.core.blocks as blocks
import bigframes.core.guid
import bigframes.core.schema as schemata
import bigframes.enums
import bigframes.session
@@ -53,19 +54,35 @@ def create_dataframe_from_query_job_stats(
5354
5455
def create_dataframe_from_row_iterator(
    rows: google.cloud.bigquery.table.RowIterator,
    *,
    session: bigframes.session.Session,
    index_col: Iterable[str] | str | bigframes.enums.DefaultIndexKind,
    columns: Iterable[str],
) -> dataframe.DataFrame:
    """Convert a RowIterator into a DataFrame wrapping a LocalNode.

    This allows us to create a DataFrame from query results, even in the
    'jobless' case where there's no destination table.

    Args:
        rows: Query results to convert; ``rows.to_arrow()`` is treated as the
            source of truth for the data.
        session: Session the resulting DataFrame is associated with.
        index_col: Column name(s) to use as the index. Falsy values (e.g. an
            empty iterable) or a ``DefaultIndexKind`` select the free
            sequential offsets index instead.
        columns: If non-empty, restrict the result to these columns
            (selection happens after index assignment).

    Returns:
        A DataFrame backed by a local (in-memory) Arrow table.
    """
    pa_table = rows.to_arrow()
    bq_schema = list(rows.schema)

    if not index_col or isinstance(index_col, bigframes.enums.DefaultIndexKind):
        # We get a sequential index for free, so use that if no index is specified.
        # TODO(tswast): Use array_value.promote_offsets() instead once that node is
        # supported by the local engine.
        offsets_col = bigframes.core.guid.generate_guid()
        pa_table = pyarrow_utils.append_offsets(pa_table, offsets_col=offsets_col)
        # Keep the BQ schema in sync with the Arrow table we just extended.
        bq_schema += [bigquery.SchemaField(offsets_col, "INTEGER")]
        index_columns: Tuple[str, ...] = (offsets_col,)
        # The synthetic offsets column has no user-visible label.
        index_labels: Tuple[Optional[str], ...] = (None,)
    elif isinstance(index_col, str):
        index_columns = (index_col,)
        index_labels = (index_col,)
    else:
        index_columns = tuple(index_col)
        index_labels = cast(Tuple[Optional[str], ...], tuple(index_col))

    # We use the ManagedArrowTable constructor directly, because the
    # results of to_arrow() should be the source of truth with regards
    # to the data, NOTE(review): one comment line here fell in a diff gap —
    # like the output of the BQ Storage Read API.
    mat = local_data.ManagedArrowTable(
        pa_table,
        schemata.ArraySchema.from_bq_schema(bq_schema),
    )
    mat.validate()

    # Everything from the original result set that isn't serving as the
    # index becomes a data column (the synthetic offsets column is never in
    # rows.schema, so it is naturally excluded here).
    column_labels = [
        field.name for field in rows.schema if field.name not in index_columns
    ]

    array_value = core.ArrayValue.from_managed(mat, session)
    block = blocks.Block(
        array_value,
        index_columns=index_columns,
        column_labels=column_labels,
        index_labels=index_labels,
    )
    df = dataframe.DataFrame(block)

    # Apply the user's column selection last so index columns are unaffected.
    if columns:
        df = df[list(columns)]

    return df
0 commit comments