-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathsql.py
More file actions
1620 lines (1376 loc) · 52 KB
/
sql.py
File metadata and controls
1620 lines (1376 loc) · 52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to generate SQL scripts for Metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from collections import abc
import copy
import functools
import re
from typing import Any, Iterable, List, Optional, Text, Union
DEFAULT_DIALECT = 'GoogleSQL'
# The active SQL dialect; populated by set_dialect() below.
DIALECT = None
# If need to use CREATE TEMP TABLE. It's only needed when the engine doesn't
# evaluate RAND() only once in the WITH clause. Namely,
# run_only_once_in_with_clause() returns False.
VOLATILE_RAND_IN_WITH_CLAUSE = None
# All *_FN / SUPPORT_* globals below are dialect-dependent hooks. They are
# initialized by set_dialect() from the corresponding *_OPTIONS tables that
# are defined further down in this module.
CREATE_TEMP_TABLE_FN = None
SUPPORT_FULL_JOIN = None
SUPPORT_JOIN_WITH_USING = None
# ORDER BY is required for ROW_NUMBER() in some dialects.
ROW_NUMBER_REQUIRE_ORDER_BY = None
GROUP_BY_FN = None
RAND_FN = None
CEIL_FN = None
SAFE_DIVIDE_FN = None
QUANTILE_FN = None
ARRAY_AGG_FN = None
ARRAY_INDEX_FN = None
NTH_VALUE_FN = None
COUNTIF_FN = None
FLOAT_CAST_FN = None
STRING_CAST_FN = None
UNIFORM_MAPPING_FN = None
UNNEST_ARRAY_FN = None
UNNEST_ARRAY_LITERAL_FN = None
GENERATE_ARRAY_FN = None
DUPLICATE_DATA_N_TIMES_FN = None
STDDEV_POP_FN = None
STDDEV_SAMP_FN = None
VARIANCE_POP_FN = None
VARIANCE_SAMP_FN = None
CORR_FN = None
COVAR_POP_FN = None
COVAR_SAMP_FN = None
def drop_table_if_exists(alias: str):
  """Returns a statement dropping table `alias` if it exists."""
  return f'DROP TABLE IF EXISTS {alias};'


def drop_temp_table_if_exists(alias: str):
  """Returns a statement dropping temporary table `alias` if it exists."""
  return f'DROP TEMPORARY TABLE IF EXISTS {alias};'


def drop_table_if_exists_then_create_temp_table(alias: str, query: str):
  """Drops a table if it exists then creates a temporary table from query."""
  drop = drop_table_if_exists(alias)
  create = f'CREATE TEMPORARY TABLE {alias} AS {query}'
  return f'{drop}\n{create}'


def drop_temp_table_if_exists_then_create_temp_table(alias: str, query: str):
  """Drops a temp table if it exists then creates a temporary table."""
  drop = drop_temp_table_if_exists(alias)
  create = f'CREATE TEMPORARY TABLE {alias} AS {query}'
  return f'{drop}\n{create}'
def create_temp_table_fn_not_implemented(alias: str, query: str):
  """Placeholder for dialects without CREATE TEMP TABLE support."""
  del alias, query  # Unused
  raise NotImplementedError('CREATE TEMP TABLE is not implemented.')


def sql_server_rand_fn_not_implemented():
  # Seedless RAND() in SQL Server is constant within one SELECT, so it cannot
  # generate a distinct random value per row.
  raise NotImplementedError(
      "SQL Server's RAND() without a seed parameter will return the same value"
      " for every row within the same SELECT statement, which doesn't work"
      ' for us.'
  )
def safe_divide_fn_default(numer: str, denom: str):
  """NULL-safe division: returns NULL instead of erroring when denom is 0."""
  # Build the template first (FLOAT_CAST_FN wraps the literal placeholders),
  # then substitute the actual expressions in a single .format call.
  template = (
      'CASE WHEN {denom} = 0 THEN NULL ELSE '
      + FLOAT_CAST_FN('{numer}')
      + ' / '
      + FLOAT_CAST_FN('{denom}')
      + ' END'
  )
  return template.format(numer=numer, denom=denom)
def approx_quantiles_fn(percentile):
  """Quantile via GoogleSQL APPROX_QUANTILES over 100 buckets."""
  bucket = int(100 * percentile)
  return 'APPROX_QUANTILES({}, 100)[SAFE_OFFSET(%d)]' % bucket


def percentile_cont_fn(percentile):
  """Quantile via the standard PERCENTILE_CONT aggregate."""
  return f'PERCENTILE_CONT({percentile}) WITHIN GROUP (ORDER BY {{}})'


def approx_percentile_fn(percentile):
  """Quantile via APPROX_PERCENTILE (e.g. Trino)."""
  return f'APPROX_PERCENTILE({{}}, {percentile})'


def quantile_fn_not_implemented(percentile):
  """Placeholder for dialects without quantile support."""
  del percentile  # Unused
  raise NotImplementedError('Quantile is not implemented.')
def array_agg_fn_googlesql(
    sort_by: Optional[str],
    ascending: Optional[bool],
    dropna: Optional[bool],
    limit: Optional[int],
):
  """Uses GoogleSQL's ARRAY_AGG to aggregate arrays.

  Args:
    sort_by: Column to ORDER BY inside the aggregate, if any.
    ascending: Sort direction; falsy appends DESC.
    dropna: If true, adds IGNORE NULLS.
    limit: Optional LIMIT inside the aggregate.

  Returns:
    An ARRAY_AGG template with a '{}' placeholder for the column.
  """
  dropna = ' IGNORE NULLS' if dropna else ''
  order_by = f' ORDER BY {sort_by}' if sort_by else ''
  # BUGFIX: the original checked `order_by is not None`, which was always true
  # (order_by is always a str), so a stray ' DESC' was appended even when there
  # was no ORDER BY clause. Only add the direction when an ORDER BY exists.
  if order_by:
    order_by += '' if ascending else ' DESC'
  limit = f' LIMIT {limit}' if limit else ''
  return f'ARRAY_AGG({{}}{dropna}{order_by}{limit})'
def array_agg_fn_no_use_filter_no_limit(
    sort_by: Optional[str],
    ascending: Optional[bool],
    dropna: Optional[bool],
    limit: Optional[int],
):
  """Uses ARRAY_AGG to aggregate arrays. Use FILTER to filter out NULLs.

  Returns a template with '{}' placeholders for the column (the FILTER clause
  contains a second '{}' when dropna is set).
  """
  del limit  # LIMIT is not supported in PostgreSQL so just skip.
  dropna = ' FILTER (WHERE {} IS NOT NULL)' if dropna else ''
  order_by = f' ORDER BY {sort_by}' if sort_by else ''
  # BUGFIX: `order_by is not None` was always true (order_by is always a str),
  # so ' DESC' leaked into the output even with no ORDER BY clause.
  if order_by:
    order_by += '' if ascending else ' DESC'
  return f'ARRAY_AGG({{}}{order_by}){dropna}'
def json_array_agg_fn(
    sort_by: Optional[str],
    ascending: Optional[bool],
    dropna: Optional[bool],
    limit: Optional[int],
):
  """Uses JSON_ARRAYAGG to aggregate arrays.

  Raises:
    NotImplementedError: If dropna is falsy; JSON_ARRAYAGG here always drops
      NULLs, so respecting them is not supported.
  """
  del limit  # LIMIT is not supported in PostgreSQL so just skip.
  if not dropna:
    raise NotImplementedError('Respecting NULLS is not supported.')
  order_by = f' ORDER BY {sort_by}' if sort_by else ''
  # BUGFIX: `order_by is not None` was always true (order_by is always a str),
  # so ' DESC' leaked into the output even with no ORDER BY clause.
  if order_by:
    order_by += '' if ascending else ' DESC'
  return f'JSON_ARRAYAGG({{}}{order_by})'
def array_agg_fn_not_implemented(
    sort_by: Optional[str],
    ascending: Optional[bool],
    dropna: Optional[bool],
    limit: Optional[int],
):
  """Placeholder for dialects without ARRAY_AGG support."""
  del sort_by, ascending, dropna, limit  # Unused
  raise NotImplementedError('ARRAY_AGG is not implemented.')
def array_index_safe_offset_fn(array: str, zero_based_idx: int):
  """0-based array indexing with GoogleSQL's SAFE_OFFSET."""
  return f'{array}[SAFE_OFFSET({zero_based_idx})]'


def array_index_fn_postgres_style(array, zero_based_idx):
  # Not used; kept local helpers separate per dialect below.
  raise NotImplementedError


def array_subscript_fn(array: str, zero_based_idx: int):
  """1-based subscript indexing (e.g. PostgreSQL arrays)."""
  one_based = zero_based_idx + 1
  return f'({array})[{one_based}]'


def element_at_index_fn(array: str, zero_based_idx: int):
  """1-based element_at indexing (e.g. Trino)."""
  one_based = zero_based_idx + 1
  return f'element_at({array}, {one_based})'


def json_extract_fn(array: str, zero_based_idx: int):
  """0-based JSON path indexing via JSON_EXTRACT."""
  return f"JSON_EXTRACT({array}, '$[{zero_based_idx}]')"


def json_value_fn(array: str, zero_based_idx: int):
  """0-based JSON path indexing via JSON_VALUE."""
  return f"JSON_VALUE({array}, '$[{zero_based_idx}]')"


def array_index_fn_not_implemented(array: str, zero_based_idx: int):
  """Placeholder for dialects without array indexing support."""
  del array, zero_based_idx  # Unused
  raise NotImplementedError('ARRAY_INDEX is not implemented.')
def nth_fn_default(
    zero_based_idx: int,
    sort_by: Optional[str],
    ascending: Optional[bool],
    dropna: Optional[bool],
    limit: Optional[int],
):
  """Gets the n-th value by indexing into an ARRAY_AGG of the column.

  Raises:
    NotImplementedError: If the dialect supports neither ARRAY_AGG nor
      array indexing.
  """
  try:
    aggregated = ARRAY_AGG_FN(sort_by, ascending, dropna, limit)
    return ARRAY_INDEX_FN(aggregated, zero_based_idx)
  except NotImplementedError as err:
    raise NotImplementedError('Nth value is not implemented.') from err
def uniform_mapping_fn_not_implemented(_):
  """Placeholder for dialects without a uniform hash-mapping function."""
  raise NotImplementedError('Uniform mapping is not implemented.')
def unnest_array_with_offset_fn(
    array: str,
    alias: Optional[str] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
):
  """Unnests an array in GoogleSQL, optionally with an OFFSET column."""
  if alias is None:
    return f'UNNEST({array})'
  if not offset:
    return f'UNNEST({array}) AS {alias}'
  clause = f'UNNEST({array}) {alias} WITH OFFSET AS {offset}'
  if limit:
    clause += f' WHERE {offset} < {limit}'
  return clause
def unnest_array_with_ordinality_fn(
    array: str,
    alias: Optional[str] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
):
  """Unnests an array in PostgreSQL, optionally WITH ORDINALITY."""
  if alias is None:
    return f'UNNEST({array})'
  if not offset:
    return f'UNNEST({array}) unnested({alias})'
  base = f'UNNEST({array}) WITH ORDINALITY AS unnested({alias}, {offset})'
  # ORDINALITY is 1-based, hence the limit + 1 bound.
  where = f' WHERE {offset} < {limit + 1}' if limit else ''
  return base + where
def unnest_json_array_fn(
    array: str,
    alias: Optional[str] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
):
  """Unnests a JSON_ARRAY in Oracle SQL via JSON_TABLE."""
  # FOR ORDINALITY is 1-based, hence the limit + 1 bound.
  where = f' WHERE {offset} < {limit + 1}' if limit else ''
  return f'''JSON_TABLE({array}, '$[*]'
    COLUMNS (
      {alias} FLOAT PATH '$',
      {offset} FOR ORDINALITY
    )
  ) AS foobar{where}'''
def unnest_array_fn_not_implemented(
array: str,
alias: Optional[str] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
):
del array, alias, offset, limit # Unused
raise NotImplementedError('UNNEST is not implemented.')
def unnest_array_literal_fn_googlesql(array: List[Any], alias: str = ''):
return f'UNNEST({array}) {alias}'.strip()
def unnest_array_literal_fn_postgresql(array: List[Any], alias: str = ''):
return f'UNNEST(ARRAY{array}) {alias}'.strip()
def unnest_array_literal_fn_not_implemented(array, alias=''):
del array, alias # Unused
raise NotImplementedError('UNNEST with literal array is not implemented.')
def generate_array_fn(n):
  """Generates an array of n elements using GENERATE_ARRAY."""
  return 'GENERATE_ARRAY(1, {})'.format(n)


def generate_series_fn(n):
  """Generates an array of n elements using GENERATE_SERIES."""
  return 'GENERATE_SERIES(1, {})'.format(n)
def generate_sequence_fn_mariadb(n):
  """Generates an array of n elements using sequence in MariaDB.

  Raises:
    NotImplementedError: If n is not an integer strictly between 1 and
      2^63 - 1.
  """
  try:
    n = int(n)
    if not 1 < n < 9223372036854775807:
      raise ValueError(
          'Only support generating sequence for an integer between 1 and'
          f' 2^63 - 1. Got: {n}'
      )
  except ValueError as err:
    raise NotImplementedError(
        f'Only support generating sequence for an integer. Got: {n}'
    ) from err
  # MariaDB's sequence engine exposes virtual tables named seq_0_to_N.
  return f'seq_0_to_{n - 1}'
def generate_array_fn_oracle(n, alias: str = '_'):
  """Generates an array of n elements using sequence in Oracle.

  Raises:
    NotImplementedError: If n cannot be converted to an integer.
  """
  try:
    size = int(n)
  except ValueError as err:
    raise NotImplementedError(
        f'Only support generating sequence for an integer. Got: {n}'
    ) from err
  return f'SELECT LEVEL AS {alias} FROM DUAL CONNECT BY LEVEL <= {size}'


def generate_sequence_fn_trino(n):
  """Generates an array of n elements using sequence in Trino."""
  return 'SEQUENCE(1, {})'.format(n)
def generate_array_fn_not_implemented(n):
  """Placeholder for dialects without sequence-generation support."""
  del n  # Unused
  raise NotImplementedError(
      'GENERATE_ARRAY/GENERATE_SERIES is not implemented.'
  )
def unnest_generated_array(n, alias: Optional[str] = None):
  """Unnest a generated array, used to duplicate data."""
  return UNNEST_ARRAY_FN(GENERATE_ARRAY_FN(n), alias)


def implicitly_unnest_generated_array(n, alias: Optional[str] = None):
  """Unnest a generated series, used to duplicate data."""
  series = GENERATE_ARRAY_FN(n)
  return f'{series} {alias}' if alias else series


def implicitly_unnest_generated_sequence(n, alias: Optional[str] = None):
  """Unnest a generated series, used to duplicate data."""
  sequence = GENERATE_ARRAY_FN(n)
  if not alias:
    return sequence
  return f'(SELECT seq AS {alias} FROM {sequence}) unnested'


def duplicate_data_n_times_oracle(n, alias: Optional[str] = None):
  """Duplicates data n times in Oracle via a CONNECT BY row generator."""
  if alias:
    return generate_array_fn_oracle(n, alias)
  return generate_array_fn_oracle(n)


def duplicate_data_n_times_not_implemented(n, alias: Optional[str] = None):
  """Placeholder for dialects without data-duplication support."""
  del n, alias  # Unused
  raise NotImplementedError(
      'Duplicate data n times is not implemented.'
  )
def _raise_stat_not_implemented(fn_name: str):
  """Raises the standard error for an unsupported statistical aggregate."""
  raise NotImplementedError(f'{fn_name} is not implemented.')


def stddev_pop_not_implemented():
  _raise_stat_not_implemented('STDDEV_POP')


def stddev_samp_not_implemented():
  _raise_stat_not_implemented('STDDEV_SAMP')


def variance_pop_not_implemented():
  _raise_stat_not_implemented('VARIANCE_POP')


def variance_samp_not_implemented():
  _raise_stat_not_implemented('VARIANCE_SAMP')


def corr_not_implemented():
  _raise_stat_not_implemented('CORR')


def covar_pop_not_implemented():
  _raise_stat_not_implemented('COVAR_POP')


def covar_samp_not_implemented():
  _raise_stat_not_implemented('COVAR_SAMP')
# Dialects with dedicated entries in the *_OPTIONS tables below. Other dialect
# names fall back to the 'Default' options (set_dialect prints a warning).
BUILTIN_DIALECTS = (
    'GoogleSQL',
    'MariaDB',
    'Oracle',
    'SQL Server',
    'Calcite',
    'PostgreSQL',
    'Trino',
    'SQLite',
)
# Per-dialect option tables. set_dialect() copies the entry for the active
# dialect (falling back to 'Default') into the corresponding module global.
CREATE_TEMP_TABLE_OPTIONS = {
    'Default': drop_table_if_exists_then_create_temp_table,
    'GoogleSQL': 'CREATE OR REPLACE TEMP TABLE {alias} AS {query};'.format,
    'MariaDB': drop_temp_table_if_exists_then_create_temp_table,
    'Oracle': create_temp_table_fn_not_implemented,
    'SQL Server': 'SELECT * INTO #{alias} FROM ({query});'.format,
    'Calcite': 'CREATE OR REPLACE TEMPORARY TABLE {alias} AS {query};'.format,
}
SUPPORT_FULL_JOIN_OPTIONS = {
    'Default': True,
    'MariaDB': False,
    'SQLite': False,
}
SUPPORT_JOIN_WITH_USING_OPTIONS = {
    'Default': True,
    'SQL Server': False,
}
ROW_NUMBER_REQUIRE_ORDER_BY_OPTIONS = {
    'Default': False,
    'Oracle': True,
    'SQL Server': True,
}
# In 'SELECT x + 1 AS foo, COUNT(*) FROM T GROUP BY ...', we can
# GROUP BY foo, GROUP BY 1, or GROUP BY x + 1. Most dialects support all three.
# But Oracle doesn't support GROUP BY 1 and SQL Server only supports
# GROUP BY x + 1. We prefer to use GROUP BY foo to GROUP BY 1 to GROUP BY x + 1
# from the readability perspective.
GROUP_BY_OPTIONS = {
    'Default': lambda columns: ', '.join(columns.aliases),
    'SQL Server': lambda columns: ', '.join(columns.expressions),
    'Trino': lambda columns: ', '.join(map(str, range(1, len(columns) + 1))),
    'Calcite': lambda columns: ', '.join(columns.expressions),
}
SAFE_DIVIDE_OPTIONS = {
    'Default': safe_divide_fn_default,
    'GoogleSQL': 'SAFE_DIVIDE({numer}, {denom})'.format,
}
# When make changes, manually evaluate the run_only_once_in_with_clause and
# update the VOLATILE_RAND_IN_WITH_CLAUSE_OPTIONS.
RAND_OPTIONS = {
    'Default': 'RANDOM()'.format,
    'GoogleSQL': 'RAND()'.format,
    'MariaDB': 'RAND()'.format,
    'Oracle': 'DBMS_RANDOM.VALUE'.format,
    'SQL Server': sql_server_rand_fn_not_implemented,
    'SQLite': '0.5 - RANDOM() / CAST(-9223372036854775808 AS REAL) / 2'.format,
    'Calcite': 'RAND()'.format,
}
# Manually evalueated run_only_once_in_with_clause for each dialect.
VOLATILE_RAND_IN_WITH_CLAUSE_OPTIONS = {
    'Default': True,
    'PostgreSQL': False,
    'MariaDB': False,
    'Oracle': True,
    'SQL Server': True,
    'Trino': True,
    'SQLite': False,
    'Calcite': True,
}
CEIL_OPTIONS = {
    'Default': 'CEIL({})'.format,
    'SQL Server': 'CEILING({})'.format,
}
QUANTILE_OPTIONS = {
    'Default': quantile_fn_not_implemented,
    'GoogleSQL': approx_quantiles_fn,
    'PostgreSQL': percentile_cont_fn,
    'Oracle': percentile_cont_fn,
    'Trino': approx_percentile_fn,
    'Calcite': percentile_cont_fn,
}
ARRAY_AGG_OPTIONS = {
    'Default': array_agg_fn_not_implemented,
    'GoogleSQL': array_agg_fn_googlesql,
    'PostgreSQL': array_agg_fn_no_use_filter_no_limit,
    'MariaDB': json_array_agg_fn,
    'Oracle': json_array_agg_fn,
    # JSON_ARRAYAGG has been added in SQL Server 2025. Will update later.
    'SQL Server': array_agg_fn_not_implemented,
    'Trino': array_agg_fn_no_use_filter_no_limit,
    'Calcite': array_agg_fn_no_use_filter_no_limit,
}
ARRAY_INDEX_OPTIONS = {
    'Default': array_index_fn_not_implemented,
    'GoogleSQL': array_index_safe_offset_fn,
    'PostgreSQL': array_subscript_fn,
    'MariaDB': json_extract_fn,
    'Oracle': json_value_fn,
    'Trino': element_at_index_fn,
    'Calcite': array_index_safe_offset_fn,
}
NTH_OPTIONS = {
    'Default': nth_fn_default,
}
COUNTIF_OPTIONS = {
    'Default': 'COUNT(CASE WHEN {} THEN 1 END)'.format,
    'GoogleSQL': 'COUNTIF({})'.format,
}
FLOAT_CAST_OPTIONS = {
    'Default': 'CAST({} AS FLOAT)'.format,
    'Trino': 'CAST({} AS DOUBLE)'.format,
}
STRING_CAST_OPTIONS = {
    'Default': 'CAST({} AS TEXT)'.format,
    'GoogleSQL': 'CAST({} AS STRING)'.format,
    'MariaDB': 'CAST({} AS NCHAR)'.format,
    'Oracle': 'TO_CHAR({})'.format,
    'SQL Server': 'CAST({} AS VARCHAR(MAX))'.format,
    'Trino': 'CAST({} AS VARCHAR)'.format,
    'Calcite': 'CAST({} AS VARCHAR)'.format,
}
UNIFORM_MAPPING_OPTIONS = {
    'Default': uniform_mapping_fn_not_implemented,
    'GoogleSQL': lambda c: f'FARM_FINGERPRINT({c}) / 0xFFFFFFFFFFFFFFFF + 0.5',
    # These queries are verified in
    # https://colab.research.google.com/drive/1C1klaXsus0fWnOAT_vWzNHOT3Q21LZi7#scrollTo=O4--SViiuAv9&line=4&uniqifier=1.
    'PostgreSQL': lambda c: f'ABS(HASHTEXT({c})::BIGINT) / 2147483647.',
    'MariaDB': lambda c: (
        f'CAST(CONV(SUBSTRING(MD5({c}), 1, 16), 16, 10) AS DECIMAL(38, 0)) /'
        ' POW(2, 64)'
    ),
    'Trino': lambda c: (
        f'CAST(from_big_endian_64(xxhash64(CAST({c} AS varbinary))) AS DOUBLE)'
        ' / POWER(2, 64) + 0.5'
    ),
}
UNNEST_ARRAY_OPTIONS = {
    'Default': unnest_array_fn_not_implemented,
    'GoogleSQL': unnest_array_with_offset_fn,
    'PostgreSQL': unnest_array_with_ordinality_fn,
    'MariaDB': unnest_json_array_fn,
    'Oracle': unnest_json_array_fn,
    'Trino': unnest_array_with_ordinality_fn,
    'Calcite': unnest_array_with_ordinality_fn,
}
UNNEST_ARRAY_LITERAL_OPTIONS = {
    'Default': unnest_array_literal_fn_not_implemented,
    'GoogleSQL': unnest_array_literal_fn_googlesql,
    'PostgreSQL': unnest_array_literal_fn_postgresql,
    'Trino': unnest_array_literal_fn_postgresql,
    'Calcite': unnest_array_literal_fn_postgresql,
}
GENERATE_ARRAY_OPTIONS = {
    'Default': generate_array_fn_not_implemented,
    'GoogleSQL': generate_array_fn,
    'PostgreSQL': generate_series_fn,
    'MariaDB': generate_sequence_fn_mariadb,
    'Oracle': generate_array_fn_oracle,
    'SQL Server': generate_series_fn,
    'Trino': generate_sequence_fn_trino,
}
DUPLICATE_DATA_N_TIMES_OPTIONS = {
    'Default': duplicate_data_n_times_not_implemented,
    'GoogleSQL': unnest_generated_array,
    'PostgreSQL': implicitly_unnest_generated_array,
    'MariaDB': implicitly_unnest_generated_sequence,
    'Oracle': duplicate_data_n_times_oracle,
    'SQL Server': implicitly_unnest_generated_array,
    'Trino': unnest_generated_array,
}
# NOTE(review): unlike CEIL_OPTIONS/COUNTIF_OPTIONS etc., the templates below
# are plain strings rather than bound '...'.format methods, so the *_FN globals
# derived from them are not callable. Confirm callers .format() these
# themselves, or whether a trailing `.format` was dropped accidentally.
STDDEV_POP_OPTIONS = {
    'Default': 'STDDEV_POP({})',
    'SQL Server': 'STDEVP({})'
}
STDDEV_SAMP_OPTIONS = {
    'Default': 'STDDEV_SAMP({})',
    'SQL Server': 'STDEV({})'
}
VARIANCE_POP_OPTIONS = {
    'Default': 'VAR_POP({})',
    'SQL Server': 'VARP({})'
}
VARIANCE_SAMP_OPTIONS = {
    'Default': 'VAR_SAMP({})',
    'SQL Server': 'VAR({})'
}
CORR_OPTIONS = {
    'Default': 'CORR({}, {})',
}
COVAR_POP_OPTIONS = {
    'Default': 'COVAR_POP({}, {})',
}
COVAR_SAMP_OPTIONS = {
    'Default': 'COVAR_SAMP({}, {})',
}
def set_dialect(dialect: Optional[str]):
  """Sets the dialect of the SQL query.

  Updates DIALECT and every dialect-dependent module global from its
  corresponding *_OPTIONS table. A falsy dialect is a no-op. An unknown
  dialect is accepted with a warning and falls back to the 'Default' options.
  You can manually test dialects in
  https://colab.research.google.com/drive/1y3UigzEby1anMM3-vXocBx7V8LVblIAp?usp=sharing.
  """
  if not dialect:
    return
  if dialect not in BUILTIN_DIALECTS:
    print(
        f'WARNING: Dialect {dialect} is not natively supported. Falling back to'
        ' the default options, which works with most variations. Built-in'
        f' dialects are {BUILTIN_DIALECTS}'
    )
  global DIALECT
  # DIALECT must be assigned before _get_dialect_option is consulted.
  DIALECT = dialect
  option_tables = {
      'VOLATILE_RAND_IN_WITH_CLAUSE': VOLATILE_RAND_IN_WITH_CLAUSE_OPTIONS,
      'CREATE_TEMP_TABLE_FN': CREATE_TEMP_TABLE_OPTIONS,
      'SUPPORT_FULL_JOIN': SUPPORT_FULL_JOIN_OPTIONS,
      'SUPPORT_JOIN_WITH_USING': SUPPORT_JOIN_WITH_USING_OPTIONS,
      'ROW_NUMBER_REQUIRE_ORDER_BY': ROW_NUMBER_REQUIRE_ORDER_BY_OPTIONS,
      'GROUP_BY_FN': GROUP_BY_OPTIONS,
      'RAND_FN': RAND_OPTIONS,
      'CEIL_FN': CEIL_OPTIONS,
      'SAFE_DIVIDE_FN': SAFE_DIVIDE_OPTIONS,
      'QUANTILE_FN': QUANTILE_OPTIONS,
      'ARRAY_AGG_FN': ARRAY_AGG_OPTIONS,
      'ARRAY_INDEX_FN': ARRAY_INDEX_OPTIONS,
      'NTH_VALUE_FN': NTH_OPTIONS,
      'COUNTIF_FN': COUNTIF_OPTIONS,
      'STRING_CAST_FN': STRING_CAST_OPTIONS,
      'FLOAT_CAST_FN': FLOAT_CAST_OPTIONS,
      'UNIFORM_MAPPING_FN': UNIFORM_MAPPING_OPTIONS,
      'UNNEST_ARRAY_FN': UNNEST_ARRAY_OPTIONS,
      'UNNEST_ARRAY_LITERAL_FN': UNNEST_ARRAY_LITERAL_OPTIONS,
      'GENERATE_ARRAY_FN': GENERATE_ARRAY_OPTIONS,
      'DUPLICATE_DATA_N_TIMES_FN': DUPLICATE_DATA_N_TIMES_OPTIONS,
      'STDDEV_POP_FN': STDDEV_POP_OPTIONS,
      'STDDEV_SAMP_FN': STDDEV_SAMP_OPTIONS,
      'VARIANCE_POP_FN': VARIANCE_POP_OPTIONS,
      'VARIANCE_SAMP_FN': VARIANCE_SAMP_OPTIONS,
      'CORR_FN': CORR_OPTIONS,
      'COVAR_POP_FN': COVAR_POP_OPTIONS,
      'COVAR_SAMP_FN': COVAR_SAMP_OPTIONS,
  }
  for name, options in option_tables.items():
    globals()[name] = _get_dialect_option(options)
def _get_dialect_option(options: dict[str, Any]):
  """Looks up the current DIALECT in options, falling back to 'Default'."""
  fallback = options['Default']
  return options.get(DIALECT, fallback)
# Initialize all dialect-dependent globals at import time.
set_dialect(DEFAULT_DIALECT)
def is_compatible(sql0, sql1):
  """Checks if two datasources are compatible so their columns can be merged.

  Being compatible means datasources
  1. SELECT FROM the same data source
  2. have same GROUP BY clauses
  3. have the same WITH clause (usually None)
  4. do not SELECT DISTINCT.

  Args:
    sql0: A Sql instance.
    sql1: A Sql instance.

  Returns:
    If sql0 and sql1 are compatible.

  Raises:
    ValueError: If either input is not a Sql instance.
  """
  if not (isinstance(sql0, Sql) and isinstance(sql1, Sql)):
    raise ValueError('Both inputs must be a Sql instance!')
  return (
      sql0.from_data == sql1.from_data
      and sql0.where == sql1.where
      and sql0.groupby == sql1.groupby
      and sql0.with_data == sql1.with_data
      and not sql0.columns.distinct
      and not sql1.columns.distinct
  )
def add_suffix(alias):
  """Adds an int suffix to alias (or increments an existing one)."""
  stripped = alias.strip('`')
  match = re.search(r'([0-9]+)$', stripped)
  if not match:
    return stripped + '_1'
  digits = match.group(1)
  return stripped[: -len(digits)] + str(int(digits) + 1)
def rand_run_only_once_in_with_clause(execute):
  """Check if the RAND() is only evaluated once in the WITH clause.

  Args:
    execute: A callable that runs a SQL string and returns a DataFrame-like
      result (read via .iloc).

  Returns:
    True when the engine materializes the WITH subquery once (difference is 0).
  """
  probe = f'''WITH T AS (SELECT {RAND_FN()} AS r)
SELECT t1.r - t2.r AS d
FROM T t1 CROSS JOIN T t2'''
  result = execute(probe)
  return bool(result.iloc[0, 0] == 0)
def dep_on_rand_table(query, rand_tables):
  """Returns if a SQL query depends on any stochastic table in rand_tables.

  Args:
    query: A SQL query (anything str()-able).
    rand_tables: An iterable of table names/aliases of stochastic tables.

  Returns:
    True if any table name appears as a whole word in the query.
  """
  query_str = str(query)
  # re.escape guards against table names containing regex metacharacters,
  # which previously could corrupt the pattern or raise re.error.
  return any(
      re.search(r'\b%s\b' % re.escape(str(rand_table)), query_str)
      for rand_table in rand_tables
  )
def get_temp_tables(with_data: 'Datasources'):
  """Gets all the subquery tables that need to be materialized.

  When generating the query, we assume that volatile functions like RAND() in
  the WITH clause behave as if they are evaluated only once. Unfortunately, not
  all engines behave like that. In those cases, we need to CREATE TEMP TABLE to
  materialize the subqueries that have volatile functions, so that the same
  result is used in all places. An example is
    WITH T AS (SELECT RAND() AS r)
    SELECT t1.r - t2.r AS d
    FROM T t1 CROSS JOIN T t2.
  If it doesn't always evaluates to 0, we need to create a temp table for T.
  A subquery needs to be materialized if
    1. it depends on any stochastic table (e.g. RAND()) and
    2. the random column is referenced in the same query multiple times.
  #2 is hard to check so we check if the stochastic table is referenced in the
  same query multiple times instead. An exception is the
  BootstrapRandomChoices table, which refers to a stochastic table twice but
  only one refers to the stochasic column, so we don't need to materialize it.
  This function finds all the subquery tables in the WITH clause that need to
  be materialized by
    1. finding all the stochastic tables,
    2. finding all the tables that depend, even indirectly, on a stochastic
       table,
    3. finding all the tables in #2 that are referenced in the same query
       multiple times.

  Args:
    with_data: The with clause.

  Returns:
    A set of table names that need to be materialized.
  """
  tmp_tables = set()
  for rand_table in with_data:
    if RAND_FN() not in str(with_data[rand_table]):
      continue  # Not a stochastic table.
    # Tables that depend (possibly indirectly) on rand_table. A single ordered
    # pass matches the original behavior, relying on definition order.
    dep_on_rand = {rand_table}
    for alias in with_data:
      if dep_on_rand_table(with_data[alias].from_data, dep_on_rand):
        dep_on_rand.add(alias)
    for table in dep_on_rand:
      from_data = with_data[table].from_data
      if not isinstance(from_data, Join):
        continue
      if table.startswith('BootstrapRandomChoices'):
        continue  # Known single-stochastic-column self-join; safe to skip.
      both_sides_stochastic = dep_on_rand_table(
          from_data.ds1, dep_on_rand
      ) and dep_on_rand_table(from_data.ds2, dep_on_rand)
      if both_sides_stochastic:
        tmp_tables.add(rand_table)
        break
  return tmp_tables
def get_alias(c):
  """Returns c.alias_raw when present, otherwise c itself."""
  try:
    return c.alias_raw
  except AttributeError:
    return c
def escape_alias(alias):
  """Replaces special characters in SQL column name alias.

  Characters with a natural reading become words ('%' -> '_pct_'); other
  special characters become '_'. Leading/trailing underscores and spaces are
  stripped and doubled underscores collapsed once. A result starting with a
  digit is prefixed with 'col_'.

  Args:
    alias: The raw alias string (may be falsy).

  Returns:
    The sanitized alias; the input unchanged when it has no special chars.
  """
  special = set(r""" `~!@#$%^&*()-=+[]{}\|;:'",.<>/?""")
  if not alias or not special.intersection(alias):
    return alias
  escape = {c: '_' for c in special}
  escape.update({
      '!': '_not_',
      '$': '_macro_',
      '@': '_at_',
      '%': '_pct_',
      '^': '_to_the_power_',
      '*': '_times_',
      ')': '',
      '-': '_minus_',
      '=': '_equals_',
      '+': '_plus_',
      '.': '_point_',
      '/': '_divides_',
      '>': '_greater_than_',
      '<': '_smaller_than_',
  })
  res = (
      ''.join(escape.get(c, c) for c in alias)
      .strip('_')
      .strip(' ')
      .replace('__', '_')
  )
  # BUGFIX: an alias made entirely of stripped characters (e.g. '(') reduces
  # to '', which previously raised IndexError on res[0].
  if not res:
    return res
  return 'col_' + res if res[0].isdigit() else res
@functools.total_ordering
class SqlComponent:
  """Base class for a SQL component like column, table and filter.

  Subclasses provide __str__; equality, ordering, hashing, truthiness and
  string-like operations all delegate to that string form.
  """

  def __eq__(self, other):
    return str(self) == str(other)

  def __lt__(self, other):
    return str(self) < other

  def __repr__(self):
    return str(self)

  def __hash__(self):
    return hash(str(self))

  def __bool__(self):
    return bool(str(self))

  # Python 2 truthiness alias.
  __nonzero__ = __bool__

  def __add__(self, other):
    return str.__add__(str(self), other)

  def __mul__(self, other):
    return str.__mul__(str(self), other)

  def __rmul__(self, other):
    return str.__rmul__(str(self), other)

  def __getitem__(self, idx):
    return str(self)[idx]


class SqlComponents(SqlComponent):
  """Base class for a bunch of SQL components like columns and filters."""

  def __init__(self, children=None):
    super(SqlComponents, self).__init__()
    self.children = []
    self.add(children)

  def add(self, children):
    """Adds children, flattening iterables, skipping falsy and duplicates."""
    is_flattenable = not isinstance(children, str) and isinstance(
        children, abc.Iterable
    )
    if is_flattenable:
      for child in list(children):
        self.add(child)
    elif children and children not in self.children:
      self.children.append(children)
    return self

  def __iter__(self):
    return iter(self.children)

  def __len__(self):
    return len(self.children)

  def __getitem__(self, key):
    return self.children[key]

  def __setitem__(self, key, value):
    self.children[key] = value


class Filter(SqlComponent):
  """Represents single condition in SQL WHERE clause."""

  def __init__(self, cond: Optional[Text]):
    super(Filter, self).__init__()
    if isinstance(cond, Filter):
      self.cond = cond.cond
    elif cond:
      # '==' is a common user typo for SQL equality.
      self.cond = cond.replace('==', '=') or ''
    else:
      self.cond = ''

  def __str__(self):
    if not self.cond:
      return ''
    # Parenthesize OR conditions so they AND-combine correctly in Filters.
    needs_parens = ' OR ' in self.cond.upper()
    return '(%s)' % self.cond if needs_parens else self.cond


class Filters(SqlComponents):
  """Represents a bunch of SQL conditions ANDed together."""

  @property
  def where(self):
    return sorted(str(Filter(f)) for f in self.children)

  def remove(self, filters):
    if filters:
      to_drop = Filters(filters)
      self.children = [c for c in self.children if c not in to_drop]
    return self

  def __str__(self):
    return ' AND '.join(self.where)
class Column(SqlComponent):
"""Represents a SQL column.
Generates a single row in the SELECT clause in SQL. Here are some examples of
the input and representation.
Input => Representation
Column('click', 'SUM({})') => SUM(click) AS `sum(click)`
Column('click * weight', 'SUM({})', 'foo') => SUM(click * weight) AS foo
Column('click', 'SUM({})', auto_alias=False) => SUM(click)
Column('click', 'SUM({})', filters='region = "US"') =>
SUM(IF(region = "US", click, NULL)) AS `sum(click)`
Column('region') => region # No alias because it's same as the column.
Column('* EXCEPT (click)', auto_alias=False) => * EXCEPT (click)
Column(('click', 'impression'), 'SAFE_DIVIDE({}, {})', 'ctr') =>
SAFE_DIVIDE(click, impression) AS ctr.
Column(('click', 'impr'), 'SAFE_DIVIDE({}, {})', 'ctr', 'click > 5') =>
SAFE_DIVIDE(IF(click > 5, click, NULL), IF(click > 5, impr, NULL)) AS ctr.
Column('click', 'SUM({})', partition='region', auto_alias=False) =>
SUM(click) OVER (PARTITION BY region)
The representation is generated by applying the self.fn to self.column, then
adding optional OVER clause and renaming. The advantange of using Column
instead of raw string is
1. It handles filters nicely.
2. Even you don't need filters you can still pass the raw string, for exmaple,
'* EXCEPT (click)', in and it'd equivalent to a string, but can be used
with other Columns.
3. It supports arithmetic operations.
Column('click') * 2 is same as Column('click * 2') and
Column('click') Column('impression') is same as
Column(('click', 'impression'), 'SAFE_DIVIDE({}, {})') except for the
auto-generated aliases. This makes constructing complex SQL column easy.
4. Alias will be sanitized and auto-added if necessary.
"""
def __init__(
self,
column,
fn: Text = '{}',
alias: Optional[Text] = None,
filters=None,
partition=None,
order=None,
window_frame=None,
auto_alias=True,
):
super(Column, self).__init__()
self.column = [column] if isinstance(column, str) else column or []
self.fn = fn
# For a single column, we apply the function to the column repeatedly.
if len(self.column) == 1 and fn.count('{}') > 1:
self.column *= fn.count('{}')
self.filters = Filters(filters)
self.alias_raw = alias.strip('`') if alias else None
if not alias and auto_alias:
self.alias_raw = fn.lower().format(*self.column)
self.partition = partition
self.order = order
self.window_frame = window_frame
self.auto_alias = auto_alias
self.suffix = 0
@property
def alias(self):
a = self.alias_raw
if self.suffix:
a = '%s_%s' % (a, self.suffix)
return escape_alias(a)
@alias.setter
def alias(self, alias):
self.alias_raw = alias.strip('`')
def set_alias(self, alias):
self.alias = alias
return self
def add_suffix(self):
self.suffix += 1
return self.alias
@property
def expression(self):
"""Genereates the representation without the 'AS ...' part."""
over = None
if not (self.partition is None and self.order is None and
self.window_frame is None):
partition_cols_str = [
STRING_CAST_FN(c) for c in Columns(self.partition).expressions
]
partition = 'PARTITION BY %s' % ', '.join(
partition_cols_str) if self.partition else ''
order = 'ORDER BY %s' % ', '.join(Columns(