@@ -84,10 +84,6 @@ def test_register_and_query(self):
8484 assert isinstance (result , pl .DataFrame )
8585 assert result .shape == (2 , 2 )
8686
87- def test_supports_register (self ):
88- reader = ggsql .DuckDBReader ("duckdb://memory" )
89- assert reader .supports_register () is True
90-
9187 def test_invalid_connection_string (self ):
9288 with pytest .raises (ValueError ):
9389 ggsql .DuckDBReader ("invalid://connection" )
@@ -396,38 +392,24 @@ class TestCustomReader:
396392 """Tests for custom Python reader support."""
397393
398394 def test_simple_custom_reader (self ):
399- """Custom reader with execute_sql() method works."""
395+ """Custom reader works."""
400396
401397 class SimpleReader :
402- def execute_sql (self , sql : str ) -> pl .DataFrame :
403- return pl .DataFrame ({"x" : [1 , 2 , 3 ], "y" : [10 , 20 , 30 ]})
404-
405- reader = SimpleReader ()
406- spec = ggsql .execute ("SELECT * FROM data VISUALISE x, y DRAW point" , reader )
407- assert spec .metadata ()["rows" ] == 3
408-
409- def test_custom_reader_with_register (self ):
410- """Custom reader with register() support."""
411-
412- class RegisterReader :
413398 def __init__ (self ):
414- self .tables = {}
399+ self .ctx = pl . SQLContext ()
415400
416401 def execute_sql (self , sql : str ) -> pl .DataFrame :
417- # Simple: just return the first registered table
418- if self .tables :
419- return next (iter (self .tables .values ()))
420- return pl .DataFrame ({"x" : [1 ], "y" : [2 ]})
421-
422- def supports_register (self ) -> bool :
423- return True
402+ return self .ctx .execute (sql ).collect ()
424403
425- def register (self , name : str , df : pl .DataFrame ) -> None :
426- self .tables [name ] = df
404+ def register (
405+ self , name : str , df : pl .DataFrame , replace : bool = False
406+ ) -> None :
407+ self .ctx .register (name , df )
427408
428- reader = RegisterReader ()
429- spec = ggsql .execute ("SELECT 1 AS x, 2 AS y VISUALISE x, y DRAW point" , reader )
430- assert spec is not None
409+ reader = SimpleReader ()
410+ reader .register ("data" , pl .DataFrame ({"x" : [1 , 2 , 3 ], "y" : [10 , 20 , 30 ]}))
411+ spec = ggsql .execute ("SELECT * FROM data VISUALISE x, y DRAW point" , reader )
412+ assert spec .metadata ()["rows" ] == 3
431413
432414 def test_custom_reader_error_handling (self ):
433415 """Custom reader errors are propagated."""
@@ -436,6 +418,11 @@ class ErrorReader:
436418 def execute_sql (self , sql : str ) -> pl .DataFrame :
437419 raise ValueError ("Custom reader error" )
438420
421+ def register (
422+ self , name : str , df : pl .DataFrame , replace : bool = False
423+ ) -> None :
424+ raise ValueError ("Custom reader error" )
425+
439426 reader = ErrorReader ()
440427 with pytest .raises (ValueError , match = "Custom reader error" ):
441428 ggsql .execute ("SELECT 1 VISUALISE x, y DRAW point" , reader )
@@ -444,9 +431,17 @@ def test_custom_reader_wrong_return_type(self):
444431 """Custom reader returning wrong type raises TypeError."""
445432
446433 class WrongTypeReader :
434+ def __init__ (self ):
435+ self .ctx = pl .SQLContext ()
436+
447437 def execute_sql (self , sql : str ):
448438 return {"x" : [1 , 2 , 3 ]} # dict, not DataFrame
449439
440+ def register (
441+ self , name : str , df : pl .DataFrame , replace : bool = False
442+ ) -> None :
443+ self .ctx .register (name , df )
444+
450445 reader = WrongTypeReader ()
451446 with pytest .raises ((ValueError , TypeError )):
452447 ggsql .execute ("SELECT 1 VISUALISE x, y DRAW point" , reader )
@@ -461,16 +456,28 @@ def test_custom_reader_can_render(self):
461456 """Custom reader result can be rendered to Vega-Lite."""
462457
463458 class StaticReader :
459+ def __init__ (self ):
460+ self .ctx = pl .SQLContext ()
461+
464462 def execute_sql (self , sql : str ) -> pl .DataFrame :
465- return pl .DataFrame (
466- {
467- "x" : [1 , 2 , 3 , 4 , 5 ],
468- "y" : [10 , 40 , 20 , 50 , 30 ],
469- "category" : ["A" , "B" , "A" , "B" , "A" ],
470- }
471- )
463+ return self .ctx .execute (sql ).collect ()
464+
465+ def register (
466+ self , name : str , df : pl .DataFrame , replace : bool = False
467+ ) -> None :
468+ self .ctx .register (name , df )
472469
473470 reader = StaticReader ()
471+ reader .register (
472+ "data" ,
473+ pl .DataFrame (
474+ {
475+ "x" : [1 , 2 , 3 , 4 , 5 ],
476+ "y" : [10 , 40 , 20 , 50 , 30 ],
477+ "category" : ["A" , "B" , "A" , "B" , "A" ],
478+ }
479+ ),
480+ )
474481 spec = ggsql .execute (
475482 "SELECT * FROM data VISUALISE x, y, category AS color DRAW point" ,
476483 reader ,
@@ -488,13 +495,20 @@ def test_custom_reader_execute_sql_called(self):
488495
489496 class RecordingReader :
490497 def __init__ (self ):
498+ self .ctx = pl .SQLContext ()
491499 self .execute_calls = []
492500
493501 def execute_sql (self , sql : str ) -> pl .DataFrame :
494502 self .execute_calls .append (sql )
495- return pl .DataFrame ({"x" : [1 ], "y" : [2 ]})
503+ return self .ctx .execute (sql ).collect ()
504+
505+ def register (
506+ self , name : str , df : pl .DataFrame , replace : bool = False
507+ ) -> None :
508+ self .ctx .register (name , df )
496509
497510 reader = RecordingReader ()
511+ reader .register ("data" , pl .DataFrame ({"x" : [1 ], "y" : [2 ]}))
498512 ggsql .execute (
499513 "SELECT * FROM data VISUALISE x, y DRAW point" ,
500514 reader ,
@@ -516,10 +530,7 @@ def __init__(self):
516530 def execute_sql (self , sql : str ) -> pl .DataFrame :
517531 return self .con .con .execute (sql ).pl ()
518532
519- def supports_register (self ) -> bool :
520- return True
521-
522- def register (self , name : str , df : pl .DataFrame ) -> None :
533+ def register (self , name : str , df : pl .DataFrame , replace : bool = False ) -> None :
523534 self .con .create_table (name , df .to_arrow (), overwrite = True )
524535
525536 def unregister (self , name : str ) -> None :
0 commit comments