Skip to content

Commit 19e3abb

Browse files
committed
Implement SQL Aggregate Functions (COUNT, SUM, AVG, MIN, MAX) with validation
1 parent 6e3d12c commit 19e3abb

3 files changed

Lines changed: 121 additions & 3 deletions

File tree

minidb/database.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def execute_query(self, query_string: str) -> Any:
317317
else:
318318
res = table.select_all(limit=limit)
319319

320-
# Apply column projection
320+
# Apply column projection or aggregates
321+
if self._is_aggregate_query(columns):
322+
return self._apply_aggregates(res, columns, table)
321323
return table.project_columns(res, columns)
322324

323325
if cmd_type == 'DELETE':
@@ -461,7 +463,7 @@ def execute_query(self, query_string: str) -> Any:
461463
}
462464

463465

464-
except (DBError, TypeError) as e:
466+
except (DBError, TypeError, ValueError) as e:
465467
return f"Error: {e}"
466468
except Exception as e:
467469
return f"Unexpected Error: {e}"
@@ -474,6 +476,110 @@ def get_tables(self) -> List[str]:
474476
"""
475477
return list(self.tables.keys())
476478

479+
def _is_aggregate_query(self, columns_str: str) -> bool:
480+
"""Determines if a SELECT clause contains aggregate functions.
481+
482+
Args:
483+
columns_str: The columns portion of the SQL query.
484+
485+
Returns:
486+
bool: True if an aggregate function is detected.
487+
"""
488+
aggr_funcs = ['COUNT', 'SUM', 'AVG', 'MIN', 'MAX']
489+
up_cols = columns_str.upper()
490+
return any(func + '(' in up_cols for func in aggr_funcs)
491+
492+
def _apply_aggregates(self, rows: List[Dict[str, Any]], columns_str: str, table_obj: 'Table') -> List[Dict[str, Any]]:
493+
"""Calculates SQL aggregates (SUM, AVG, etc.) in a single pass.
494+
495+
Args:
496+
rows: The result set after filtering.
497+
columns_str: The aggregate column specifications.
498+
table_obj: The target Table object for type verification.
499+
500+
Returns:
501+
List[Dict[str, Any]]: A list containing a single row with the aggregate results.
502+
503+
Raises:
504+
ValueError: If an aggregate is applied to an incompatible column type.
505+
"""
506+
import re
507+
# Find all patterns like AGGR_FUNC(column)
508+
pattern = re.compile(r'(\w+)\(([\w\*]+)\)', re.IGNORECASE)
509+
aggr_specs = pattern.findall(columns_str)
510+
511+
if not aggr_specs:
512+
return []
513+
514+
# Initialize accumulators
515+
accumulators = {}
516+
for func_raw, col in aggr_specs:
517+
func = func_raw.upper()
518+
label = f"{func}({col})"
519+
520+
# Validation: SUM and AVG require numeric columns
521+
if func in ['SUM', 'AVG']:
522+
col_type = table_obj.column_types.get(col)
523+
if col_type == 'str':
524+
raise ValueError(f"Cannot compute {func} on non-numeric column '{col}' (type: STR)")
525+
526+
if func == 'COUNT':
527+
accumulators[label] = {'sum': 0}
528+
elif func == 'SUM':
529+
accumulators[label] = {'sum': 0}
530+
elif func == 'AVG':
531+
accumulators[label] = {'sum': 0, 'count': 0}
532+
elif func == 'MIN':
533+
accumulators[label] = {'min': float('inf')}
534+
elif func == 'MAX':
535+
accumulators[label] = {'max': float('-inf')}
536+
537+
# Single pass execution
538+
total_filtered_rows = len(rows)
539+
for row in rows:
540+
for func_raw, col in aggr_specs:
541+
func = func_raw.upper()
542+
label = f"{func}({col})"
543+
544+
if func == 'COUNT':
545+
if col == '*' or row.get(col) is not None:
546+
accumulators[label]['sum'] += 1
547+
continue
548+
549+
val = row.get(col)
550+
if val is not None:
551+
if func == 'SUM':
552+
accumulators[label]['sum'] += val
553+
elif func == 'AVG':
554+
accumulators[label]['sum'] += val
555+
accumulators[label]['count'] += 1
556+
elif func == 'MIN':
557+
if val < accumulators[label]['min']:
558+
accumulators[label]['min'] = val
559+
elif func == 'MAX':
560+
if val > accumulators[label]['max']:
561+
accumulators[label]['max'] = val
562+
563+
# Finalize results
564+
result_row = {}
565+
for func_raw, col in aggr_specs:
566+
func = func_raw.upper()
567+
label = f"{func}({col})"
568+
acc = accumulators[label]
569+
570+
if func == 'COUNT':
571+
result_row[label] = acc['sum']
572+
elif func == 'SUM':
573+
result_row[label] = acc['sum'] if total_filtered_rows > 0 else 0
574+
elif func == 'AVG':
575+
result_row[label] = acc['sum'] / acc['count'] if acc.get('count', 0) > 0 else None
576+
elif func == 'MIN':
577+
result_row[label] = acc['min'] if acc['min'] != float('inf') else None
578+
elif func == 'MAX':
579+
result_row[label] = acc['max'] if acc['max'] != float('-inf') else None
580+
581+
return [result_row]
582+
477583
def _nested_loop_join(self, left_rows: List[Dict[str, Any]], right_rows: List[Dict[str, Any]],
478584
left_on: Tuple[str, str], right_on: Tuple[str, str]) -> List[Dict[str, Any]]:
479585
"""Performs a simple Nested Loop Join. Complexity: O(N*M).

minidb/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self) -> None:
1515
'CREATE': re.compile(r"CREATE\s+TABLE\s+(\w+)\s*\((.*)\)", re.IGNORECASE),
1616
'INSERT': re.compile(r"INSERT\s+INTO\s+(\w+)\s+VALUES\s*\((.*)\)", re.IGNORECASE),
1717
'SELECT_JOIN': re.compile(r"SELECT\s+\*\s+FROM\s+(\w+)\s+JOIN\s+(\w+)\s+ON\s+(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)", re.IGNORECASE),
18-
'SELECT': re.compile(r"SELECT\s+(\*|[\w,\s]+)\s+FROM\s+(\w+)(?:\s+WHERE\s+(\w+)\s*(>=|<=|!=|>|<|=|\s+IN\s+)\s*(.*?))?(?:\s+LIMIT\s+(\d+))?$", re.IGNORECASE),
18+
'SELECT': re.compile(r"SELECT\s+(\*|[\w,\s\(\)\*]+)\s+FROM\s+(\w+)(?:\s+WHERE\s+(\w+)\s*(>=|<=|!=|>|<|=|\s+IN\s+)\s*(.*?))?(?:\s+LIMIT\s+(\d+))?$", re.IGNORECASE),
1919
'DELETE': re.compile(r"DELETE\s+FROM\s+(\w+)\s+WHERE\s+(\w+)\s*(>=|<=|!=|>|<|=)\s*(.*)", re.IGNORECASE),
2020
'UPDATE': re.compile(r"UPDATE\s+(\w+)\s+SET\s+(\w+)\s*=\s*(.*)\s+WHERE\s+(\w+)\s*(>=|<=|!=|>|<|=)\s*(.*)", re.IGNORECASE),
2121
'ALTER_TABLE': re.compile(r"ALTER\s+TABLE\s+(\w+)\s+ADD\s+(\w+)\s+(\w+)", re.IGNORECASE),

templates/documentation.html

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ <h5 class="fw-bold mt-4 small text-uppercase">Nested Subqueries</h5>
126126
<pre><code class="sql">-- Find students enrolled in 'Python'
127127
SELECT * FROM students
128128
WHERE id IN (SELECT student_id FROM enrollments WHERE course = 'Python')</code></pre>
129+
130+
<h5 class="fw-bold mt-4 small text-uppercase">Aggregate Functions</h5>
131+
<p>MiniDB provides built-in support for standard SQL aggregates to perform calculations on your
132+
data. These functions are executed in a single pass for maximum efficiency.</p>
133+
<ul>
134+
<li><code>COUNT(*)</code> or <code>COUNT(col)</code> — Total row count, or the number of non-null values in <code>col</code>.</li>
135+
<li><code>SUM(col)</code> — Arithmetic total (numeric columns only).</li>
136+
<li><code>AVG(col)</code> — Arithmetic mean (numeric columns only).</li>
137+
<li><code>MIN(col) / MAX(col)</code> — Smallest and largest non-null values in <code>col</code>.</li>
138+
</ul>
139+
<pre><code class="sql">-- Combined Analysis
140+
SELECT AVG(salary), MAX(score), COUNT(*) FROM employees WHERE department = 'Sales'</code></pre>
129141
</div>
130142

131143
<!-- UPDATE/DELETE -->

0 commit comments

Comments
 (0)