forked from xys-syx/polygon-starter-code
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_polygon_extensions.py
More file actions
139 lines (122 loc) · 4.52 KB
/
test_polygon_extensions.py
File metadata and controls
139 lines (122 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import sys
import traceback

from polygon.environment import Environment
from polygon.schemas import TableSchema, ColumnSchema
class TestSchemaBuilder:
    """Build table schema descriptors with auto-incrementing table/column IDs.

    Each builder instance keeps its own counters, so IDs are unique only
    within one builder. Despite the ``Test`` name prefix this is a helper,
    not a pytest test class.
    """

    # Without this, pytest tries to collect the class because of its "Test"
    # prefix (and warns, since it defines __init__).
    __test__ = False

    def __init__(self):
        # Next IDs to hand out; both sequences start at 1.
        self.table_counter = 1
        self.column_counter = 1

    @staticmethod
    def _split_columns(columns, pkeys):
        """Return ``(pkey_specs, other_specs)`` as lists of {"Name", "Type"} dicts.

        Primary-key specs follow the order of *pkeys*; the remaining columns
        keep their order from *columns*.

        Raises:
            KeyError: if a name in *pkeys* is not one of *columns*.
        """
        col_types = dict(columns)
        missing = [pk for pk in pkeys if pk not in col_types]
        if missing:
            raise KeyError(f"primary key column(s) not in columns: {missing}")
        pkey_set = set(pkeys)
        pkey_specs = [{"Name": pk, "Type": col_types[pk]} for pk in pkeys]
        other_specs = [
            {"Name": name, "Type": typ}
            for name, typ in columns
            if name not in pkey_set
        ]
        return pkey_specs, other_specs

    def create_table(self, table_name, columns, pkeys=None, bound=3):
        """Create a TableSchema plus a schema-descriptor dict for it.

        Args:
            table_name: Name of the table.
            columns: Iterable of ``(column_name, column_type)`` pairs.
            pkeys: Names of the primary-key columns (default: none).
            bound: Stored on the TableSchema and echoed under "Bound".

        Returns:
            A dict with "TableName", "TableSchema", "Bound", "PKeys",
            "FKeys" (always empty) and "Others" (every non-key column), using
            the same {"Name", "Type"} entry format as the schema dicts passed
            to ``Environment``.

        Raises:
            KeyError: if a primary key is not among *columns*. This is checked
                before any ID is consumed, so a failed call leaves the
                counters unchanged.
        """
        if pkeys is None:
            pkeys = []
        # Materialise once: columns is walked more than once below.
        columns = list(columns)
        pkey_specs, other_specs = self._split_columns(columns, pkeys)

        table = TableSchema(
            table_id=self.table_counter,
            table_name=table_name,
            bound=bound,
            lineage="test",
        )
        self.table_counter += 1

        for col_name, col_type in columns:
            column = ColumnSchema(
                column_id=self.column_counter,
                column_name=col_name,
                column_type=col_type,
                table_name=table_name,
            )
            self.column_counter += 1
            table.append(column)

        return {
            "TableName": table_name,
            "TableSchema": table,
            "Bound": bound,
            "PKeys": pkey_specs,
            "FKeys": [],
            "Others": other_specs,
        }
def create_test_env(bound=3, time_budget=60):
    """Build an Environment with `employees` and `sales` tables and sample rows.

    The schema dicts follow the example.py layout: each table lists its
    primary-key columns under "PKeys", foreign keys under "FKeys" and the
    remaining columns under "Others". Every primary key is also declared
    ``distinct`` in the constraints.

    Args:
        bound: Passed through to ``Environment`` (default 3).
        time_budget: Passed through to ``Environment`` (default 60).

    Returns:
        The ``Environment``, with ``env.db.tables`` replaced by the sample
        rows below.
    """
    schema = [
        {
            "TableName": "employees",
            "PKeys": [{"Name": "id", "Type": "int"}],
            "FKeys": [],
            "Others": [
                {"Name": "name", "Type": "varchar"},
                {"Name": "salary", "Type": "int"},
                {"Name": "dept", "Type": "varchar"},
                {"Name": "gender", "Type": "varchar"},
            ],
        },
        {
            "TableName": "sales",
            "PKeys": [{"Name": "product", "Type": "varchar"}],
            "FKeys": [],
            "Others": [
                {"Name": "amount", "Type": "int"},
                {"Name": "region", "Type": "varchar"},
            ],
        },
    ]
    constraints = [
        {'distinct': ['employees.id']},
        {'distinct': ['sales.product']},
    ]
    env = Environment(schema, constraints, bound=bound, time_budget=time_budget)

    # Load sample rows.
    # NOTE(review): keys 1 and 2 assume Environment numbers tables from 1 in
    # schema order — confirm.
    # Rows must satisfy the constraints above. `sales.product` is the primary
    # key and is declared distinct, so every sales row needs its own product.
    env.db.tables = {
        1: [  # employees
            {"id": 1, "name": "Alice", "salary": 80000, "dept": "Engineering", "gender": "F"},
            {"id": 2, "name": "Bob", "salary": 120000, "dept": "Engineering", "gender": "M"},
            {"id": 3, "name": "Charlie", "salary": 70000, "dept": "HR", "gender": "F"},
        ],
        2: [  # sales
            {"product": "Widget", "amount": 100, "region": "North"},
            {"product": "Gizmo", "amount": 200, "region": "South"},
            {"product": "Gadget", "amount": 150, "region": "North"},
        ],
    }
    return env
def test_if_expression():
    """IF(cond, a, b) must not be judged equivalent to a constant projection.

    Fails by raising (AssertionError, or whatever ``env.check`` raises), so
    pytest reports a failure instead of silently passing.
    """
    env = create_test_env()
    # Basic IF test
    q1 = "SELECT IF(salary > 100000, 'High', 'Low') as salary_level FROM employees"
    q2 = "SELECT 'High' as salary_level FROM employees"
    is_equivalent, _, _, _, _ = env.check(q1, q2)
    assert not is_equivalent, "IF should conditionally select values"
    print("✅ Basic IF test passed")


def test_filter_clause():
    """A FILTER-ed aggregate must not be judged equivalent to the plain aggregate.

    Fails by raising (AssertionError, or whatever ``env.check`` raises), so
    pytest reports a failure instead of silently passing.
    """
    env = create_test_env()
    # Filtered aggregate test
    q1 = "SELECT SUM(amount) FILTER (WHERE region = 'North') as north_sales FROM sales"
    q2 = "SELECT SUM(amount) as total_sales FROM sales"
    is_equivalent, _, _, _, _ = env.check(q1, q2)
    assert not is_equivalent, "Filtered aggregate should differ"
    print("✅ FILTER clause test passed")


def _run_test(test_fn, label):
    """Run *test_fn* for the script runner.

    Returns True if it completes. If it raises, prints the failure message
    and traceback and returns False, so later tests still run.
    """
    try:
        test_fn()
    except Exception as e:
        print(f"❌ {label} test failed: {str(e)}")
        traceback.print_exc()
        return False
    return True


if __name__ == "__main__":
    print("🚀 Testing Polygon Extensions...")
    # Run every test even if an earlier one fails, then report once.
    results = [
        _run_test(test_if_expression, "IF expression"),
        _run_test(test_filter_clause, "FILTER clause"),
    ]
    success = all(results)
    if success:
        print("🎉 All tests passed successfully!")
    else:
        print("🔴 Some tests failed")
    sys.exit(0 if success else 1)