1- from typing import Optional
1+ from typing import Optional , Generic , TypeVar , Union , List
2+ from dataclasses import dataclass
3+
4+ TTable = TypeVar ("TTable" , bound = "Table" )
25
36import polars as pl
47import pyarrow as pa
8+ import pyarrow .compute as pc
59
6- from pyiceberg .catalog import load_catalog
710from pyiceberg .table import Table as IcebergTable
11+ from pyiceberg .catalog import (
12+ Catalog ,
13+ load_catalog ,
14+ )
815
916from ._context import TowerContext
10- from .utils .pyarrow import convert_pyarrow_schema
17+ from .utils .pyarrow import (
18+ convert_pyarrow_schema ,
19+ convert_pyarrow_expressions ,
20+ )
1121from .utils .tables import (
1222 make_table_name ,
1323 namespace_or_default ,
1424)
1525
@dataclass
class RowsAffectedInformation:
    """
    Running totals of the rows a `Table` has written.

    Both counters are plain integers that callers are free to read; the
    owning `Table` adds to them after each successful write.
    """

    # Number of rows added to the table.
    inserts: int

    # Number of existing rows that were overwritten.
    updates: int
30+
31+
1632class Table :
1733 """
1834 `Table` is a wrapper around an Iceberg table. It provides methods to read and
1935 write data to the table.
2036 """
2137
def __init__(self, context: TowerContext, table: IcebergTable):
    """
    Wraps an already-loaded Iceberg table.

    Args:
        context (TowerContext): The Tower context this table belongs to.
        table (IcebergTable): The underlying PyIceberg table.
    """
    # Every wrapper starts with zeroed counters; they only grow as this
    # particular instance performs writes.
    self._stats = RowsAffectedInformation(inserts=0, updates=0)
    self._context, self._table = context, table
2542
43+
2644 def read (self ) -> pl .DataFrame :
2745 """
2846 Reads from the Iceberg tables. Returns the results as a Polars DataFrame.
@@ -31,28 +49,135 @@ def read(self) -> pl.DataFrame:
3149 # the result as a DataFrame.
3250 return pl .scan_iceberg (self ._table ).collect ()
3351
34- def insert (self , data : pa .Table ):
52+
def to_polars(self) -> pl.LazyFrame:
    """
    Exposes the table as a Polars LazyFrame.

    Reach for this when you are comfortable with Polars and need more
    than the plain `read` gives you.

    Returns:
        pl.LazyFrame: The result of `pl.scan_iceberg` on the wrapped table.
    """
    lazy_frame = pl.scan_iceberg(self._table)
    return lazy_frame
59+
60+
def rows_affected(self) -> RowsAffectedInformation:
    """
    Returns the write counters accumulated by this table wrapper.

    Returns:
        RowsAffectedInformation: The insert and update totals. This is the
            wrapper's own counter object, not a copy.
    """
    stats = self._stats
    return stats
67+
68+
def insert(self: TTable, data: pa.Table) -> TTable:
    """
    Appends the rows of a PyArrow table to the Iceberg table.

    Args:
        data (pa.Table): The rows to append.

    Returns:
        TTable: This same table wrapper, so that calls can be chained.
    """
    # `self` is annotated with the type variable so that the return type
    # resolves to the caller's concrete class instead of an unbound TypeVar.
    self._table.append(data)

    # The counter is only touched once the append has returned, so a
    # failed append leaves the stats unchanged.
    self._stats.inserts += data.num_rows
    return self
82+
83+
def upsert(self: TTable, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
    """
    Upserts the rows of a PyArrow table into the Iceberg table.

    Args:
        data (pa.Table): The rows to upsert.
        join_cols (Optional[list[str]]): The columns that form the key to
            match rows on. Passed through unchanged to the underlying
            table's `upsert`.

    Returns:
        TTable: This same table wrapper, so that calls can be chained.
    """
    # `self` is annotated with the type variable so that the return type
    # resolves to the caller's concrete class instead of an unbound TypeVar.
    outcome = self._table.upsert(
        data,
        join_cols=join_cols,
        # Upserts are always case sensitive for now; this may become a
        # parameter in the future.
        case_sensitive=True,
        # These two are spelled out for completeness.
        when_matched_update_all=True,
        when_not_matched_insert_all=True,
    )

    # Fold the result of this upsert into the running totals.
    self._stats.inserts += outcome.rows_inserted
    self._stats.updates += outcome.rows_updated

    return self
113+
114+
def delete(self: TTable, filters: Union[str, List[pc.Expression]]) -> TTable:
    """
    Deletes the rows of the Iceberg table that match the given filters.

    Args:
        filters (Union[str, List[pc.Expression]]): What to delete. A string
            is handed to the underlying table as is; a list of PyArrow
            expressions is first converted with
            `convert_pyarrow_expressions`.

    Returns:
        TTable: This same table wrapper, so that calls can be chained.
    """
    # `self` is annotated with the type variable so that the return type
    # resolves to the caller's concrete class instead of an unbound TypeVar.
    if isinstance(filters, list):
        # PyArrow expressions have to be converted into their PyIceberg
        # form before the table can use them.
        delete_filter = convert_pyarrow_expressions(filters)
    else:
        delete_filter = filters

    self._table.delete(
        delete_filter=delete_filter,
        # Matching is always case sensitive.
        case_sensitive=True,
    )

    # NOTE: There is, unfortunately, no way to get the number of rows
    # deleted besides comparing the two snapshots that were created, so the
    # stats are left untouched here.

    return self
144+
145+
def schema(self) -> pa.Schema:
    """
    Returns the table's schema in PyArrow form.

    Returns:
        pa.Schema: The Iceberg schema of the wrapped table, converted with
            its `as_arrow` method.
    """
    # The wrapped table hands back an Iceberg schema; callers of this
    # module work with PyArrow, so it is converted on the way out.
    return self._table.schema().as_arrow()
150+
151+
def column(self, name: str) -> pa.compute.Expression:
    """
    Returns a PyArrow expression that refers to one of the table's columns.

    This is useful when you want to build filters or other operations on
    the column.

    Args:
        name (str): The name of the column.

    Returns:
        pa.compute.Expression: A field reference for `name`.

    Raises:
        ValueError: If the table schema has no column called `name`.
    """
    # The lookup only serves as an existence check. PyArrow signals an
    # unknown name with a KeyError rather than by returning None, so that
    # is translated into the ValueError this method promises.
    try:
        field = self.schema().field(name)
    except KeyError as err:
        raise ValueError(f"Column {name} not found in table schema") from err

    # Kept in case a schema implementation reports a miss with None.
    if field is None:
        raise ValueError(f"Column {name} not found in table schema")

    # The expression is built from the name alone, not from the field.
    return pa.compute.field(name)
164+
42165
43166class TableReference :
44- def __init__ (self , ctx : TowerContext , catalog_name : str , name : str , namespace : Optional [str ] = None ):
167+ def __init__ (self , ctx : TowerContext , catalog : Catalog , name : str , namespace : Optional [str ] = None ):
45168 self ._context = ctx
46- self ._catalog = load_catalog ( catalog_name )
169+ self ._catalog = catalog
47170 self ._name = name
48171 self ._namespace = namespace
49172
173+
def load(self) -> Table:
    """
    Loads the referenced table from the catalog.

    Returns:
        Table: A wrapper around the table that the catalog returned.
    """
    # Resolve the namespace first, then build the name the catalog is
    # asked for.
    resolved_namespace = namespace_or_default(self._namespace)
    qualified_name = make_table_name(self._name, resolved_namespace)
    return Table(self._context, self._catalog.load_table(qualified_name))
55179
180+
56181 def create (self , schema : pa .Schema ) -> Table :
57182 namespace = namespace_or_default (self ._namespace )
58183 table_name = make_table_name (self ._name , namespace )
@@ -71,6 +196,7 @@ def create(self, schema: pa.Schema) -> Table:
71196
72197 return Table (self ._context , table )
73198
199+
74200 def create_if_not_exists (self , schema : pa .Schema ) -> Table :
75201 namespace = namespace_or_default (self ._namespace )
76202 table_name = make_table_name (self ._name , namespace )
@@ -92,7 +218,7 @@ def create_if_not_exists(self, schema: pa.Schema) -> Table:
92218
93219def tables (
94220 name : str ,
95- catalog : str = "default" ,
221+ catalog : Union [ str , Catalog ] = "default" ,
96222 namespace : Optional [str ] = None
97223) -> TableReference :
98224 """
@@ -101,11 +227,16 @@ def tables(
101227
102228 Args:
103229 `name` (str): The name of the table to load.
104- `catalog` (str): The name of the catalog to use. "default" by default.
230+ `catalog` (Union[str, Catalog]): The name of the catalog or the actual
231+ catalog to use. "default" is the default value. You can pass in an
232+ actual catalog object for testing purposes.
105233 `namespace` (Optional[str]): The namespace in which to load the table.
106234
107235 Returns:
108236 TableReference: A reference to a table to be resolved with `create` or `load`
109237 """
238+ if isinstance (catalog , str ):
239+ catalog = load_catalog (catalog )
240+
110241 ctx = TowerContext .build ()
111242 return TableReference (ctx , catalog , name , namespace )
0 commit comments