Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit b593284

Browse files
authored
Merge pull request #23 from datafold/RoderickJDunn-optimizer-hints-v1
Tiny fix for PR #19 - Optimizer hints v1
2 parents 40159a3 + 201c4e9 commit b593284

File tree

10 files changed

+88
-13
lines changed

10 files changed

+88
-13
lines changed

docs/intro.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,8 @@ List of available abstract mixins:
463463

464464
- `AbstractMixin_TimeTravel` - Only snowflake & bigquery
465465

466+
- `AbstractMixin_OptimizerHints` - Only oracle & mysql
467+
466468
More will be added in the future.
467469

468470
Note that it's still possible to use user-defined mixins that aren't on this list.

sqeleton/abcs/mixins.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,13 @@ def time_travel(
145145
146146
Must specify exactly one of `timestamp`, `offset` or `statement`.
147147
"""
148+
149+
150+
class AbstractMixin_OptimizerHints(AbstractMixin):
151+
@abstractmethod
152+
def optimizer_hints(self, optimizer_hints: str) -> str:
153+
"""Creates a compatible optimizer_hints string
154+
155+
Parameters:
156+
optimizer_hints - string of optimizer hints
157+
"""

sqeleton/databases/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@
3636
Boolean,
3737
)
3838
from ..abcs.mixins import Compilable
39-
from ..abcs.mixins import AbstractMixin_Schema, AbstractMixin_RandomSample, AbstractMixin_NormalizeValue
39+
from ..abcs.mixins import (
40+
AbstractMixin_Schema,
41+
AbstractMixin_RandomSample,
42+
AbstractMixin_NormalizeValue,
43+
AbstractMixin_OptimizerHints,
44+
)
4045
from ..bound_exprs import bound_table
4146

4247
logger = logging.getLogger("database")
@@ -134,6 +139,11 @@ def random_sample_ratio_approx(self, tbl: AbstractTable, ratio: float) -> Abstra
134139
return tbl.where(Random() < ratio)
135140

136141

142+
class Mixin_OptimizerHints(AbstractMixin_OptimizerHints):
143+
def optimizer_hints(self, hints: str) -> str:
144+
return f"/*+ {hints} */ "
145+
146+
137147
class BaseDialect(AbstractDialect):
138148
SUPPORTS_PRIMARY_KEY = False
139149
SUPPORTS_INDEXES = False

sqeleton/databases/mysql.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
AbstractMixin_Regex,
1818
AbstractMixin_RandomSample,
1919
)
20-
from .base import ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable
20+
from .base import Mixin_OptimizerHints, ThreadedDatabase, import_helper, ConnectError, BaseDialect, Compilable
2121
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, TIMESTAMP_PRECISION_POS, Mixin_Schema, Mixin_RandomSample
2222
from ..queries.ast_classes import BinBoolOp
2323

@@ -54,7 +54,7 @@ def test_regex(self, string: Compilable, pattern: Compilable) -> Compilable:
5454
return BinBoolOp("REGEXP", [string, pattern])
5555

5656

57-
class Dialect(BaseDialect, Mixin_Schema):
57+
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
5858
name = "MySQL"
5959
ROUNDS_ON_PREC_LOSS = True
6060
SUPPORTS_PRIMARY_KEY = True
@@ -109,6 +109,9 @@ def type_repr(self, t) -> str:
109109
def explain_as_text(self, query: str) -> str:
110110
return f"EXPLAIN FORMAT=TREE {query}"
111111

112+
def optimizer_hints(self, s: str):
113+
return f"/*+ {s} */ "
114+
112115
def set_timezone_to_utc(self) -> str:
113116
return "SET @@session.time_zone='+00:00'"
114117

sqeleton/databases/oracle.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
from ..abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue, AbstractMixin_Schema
1818
from ..abcs import Compilable
1919
from ..queries import this, table, SKIP
20-
from .base import BaseDialect, ThreadedDatabase, import_helper, ConnectError, QueryError, Mixin_RandomSample
20+
from .base import (
21+
BaseDialect,
22+
Mixin_OptimizerHints,
23+
ThreadedDatabase,
24+
import_helper,
25+
ConnectError,
26+
QueryError,
27+
Mixin_RandomSample,
28+
)
2129
from .base import TIMESTAMP_PRECISION_POS
2230

2331
SESSION_TIME_ZONE = None # Changed by the tests
@@ -72,7 +80,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
7280
)
7381

7482

75-
class Dialect(BaseDialect, Mixin_Schema):
83+
class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints):
7684
name = "Oracle"
7785
SUPPORTS_PRIMARY_KEY = True
7886
SUPPORTS_INDEXES = True

sqeleton/databases/snowflake.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def table_information(self) -> Compilable:
136136
def set_timezone_to_utc(self) -> str:
137137
return "ALTER SESSION SET TIMEZONE = 'UTC'"
138138

139+
def optimizer_hints(self, hints: str) -> str:
140+
raise NotImplementedError("Optimizer hints not yet implemented in snowflake")
141+
139142

140143
class Snowflake(Database):
141144
dialect = Dialect()

sqeleton/queries/ast_classes.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ class ITable(AbstractTable):
9191
source_table: Any
9292
schema: Schema = None
9393

94-
def select(self, *exprs, distinct=SKIP, **named_exprs):
94+
def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs):
9595
"""Create a new table with the specified fields"""
9696
exprs = args_as_tuple(exprs)
9797
exprs = _drop_skips(exprs)
9898
named_exprs = _drop_skips_dict(named_exprs)
9999
exprs += _named_exprs_as_aliases(named_exprs)
100100
resolve_names(self.source_table, exprs)
101-
return Select.make(self, columns=exprs, distinct=distinct)
101+
return Select.make(self, columns=exprs, distinct=distinct, optimizer_hints=optimizer_hints)
102102

103103
def where(self, *exprs):
104104
exprs = args_as_tuple(exprs)
@@ -682,6 +682,7 @@ class Select(ExprNode, ITable, Root):
682682
having_exprs: Sequence[Expr] = None
683683
limit_expr: int = None
684684
distinct: bool = False
685+
optimizer_hints: Sequence[Expr] = None
685686

686687
@property
687688
def schema(self):
@@ -699,7 +700,8 @@ def compile(self, parent_c: Compiler) -> str:
699700

700701
columns = ", ".join(map(c.compile, self.columns)) if self.columns else "*"
701702
distinct = "DISTINCT " if self.distinct else ""
702-
select = f"SELECT {distinct}{columns}"
703+
optimizer_hints = c.dialect.optimizer_hints(self.optimizer_hints) if self.optimizer_hints else ""
704+
select = f"SELECT {optimizer_hints}{distinct}{columns}"
703705

704706
if self.table:
705707
select += " FROM " + c.compile(self.table)
@@ -729,15 +731,19 @@ def compile(self, parent_c: Compiler) -> str:
729731
return select
730732

731733
@classmethod
732-
def make(cls, table: ITable, distinct: bool = SKIP, **kwargs):
734+
def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs):
733735
assert "table" not in kwargs
734736

735737
if not isinstance(table, cls): # If not Select
736738
if distinct is not SKIP:
737739
kwargs["distinct"] = distinct
740+
if optimizer_hints is not SKIP:
741+
kwargs["optimizer_hints"] = optimizer_hints
738742
return cls(table, **kwargs)
739743

740744
# We can safely assume isinstance(table, Select)
745+
if optimizer_hints is not SKIP:
746+
kwargs["optimizer_hints"] = optimizer_hints
741747

742748
if distinct is not SKIP:
743749
if distinct == False and table.distinct:
@@ -752,7 +758,7 @@ def make(cls, table: ITable, distinct: bool = SKIP, **kwargs):
752758
if getattr(table, k) is not None:
753759
if k == "where_exprs": # Additive attribute
754760
kwargs[k] = getattr(table, k) + v
755-
elif k == "distinct":
761+
elif k in ["distinct", "optimizer_hints"]:
756762
pass
757763
else:
758764
raise ValueError(k)

sqeleton/repl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def repl(uri):
4848
continue
4949
try:
5050
path = db.parse_table_name(table_name)
51-
print('->', path)
51+
print("->", path)
5252
schema = db.query_table_schema(path)
5353
except Exception as e:
5454
logging.error(e)

sqeleton/utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
1-
from typing import Iterable, Iterator, MutableMapping, Union, Any, Sequence, Dict, Hashable, TypeVar, TYPE_CHECKING, List
1+
from typing import (
2+
Iterable,
3+
Iterator,
4+
MutableMapping,
5+
Union,
6+
Any,
7+
Sequence,
8+
Dict,
9+
Hashable,
10+
TypeVar,
11+
TYPE_CHECKING,
12+
List,
13+
)
214
from abc import abstractmethod
315
from weakref import ref
416
import math
@@ -256,7 +268,6 @@ def __eq__(self, other):
256268
return NotImplemented
257269
return self._str == other._str
258270

259-
260271
def new(self, *args, **kw):
261272
return type(self)(*args, **kw, max_len=self._max_len)
262273

tests/test_query.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def normalize_spaces(s: str):
1616
class MockDialect(AbstractDialect):
1717
name = "MockDialect"
1818

19+
PLACEHOLDER_TABLE = None
1920
ROUNDS_ON_PREC_LOSS = False
2021

2122
def quote(self, s: str) -> str:
@@ -50,6 +51,9 @@ def timestamp_value(self, t: datetime) -> str:
5051
def set_timezone_to_utc(self) -> str:
5152
return "set timezone 'UTC'"
5253

54+
def optimizer_hints(self, s: str):
55+
return f"/*+ {s} */ "
56+
5357
def load_mixins(self):
5458
raise NotImplementedError()
5559

@@ -189,6 +193,24 @@ def test_select_distinct(self):
189193
q = c.compile(t.select(this.b, distinct=True).select(distinct=False))
190194
self.assertEqual(q, "SELECT * FROM (SELECT DISTINCT b FROM a) tmp2")
191195

196+
def test_select_with_optimizer_hints(self):
197+
c = Compiler(MockDatabase())
198+
t = table("a")
199+
200+
q = c.compile(t.select(this.b, optimizer_hints="PARALLEL(a 16)"))
201+
assert q == "SELECT /*+ PARALLEL(a 16) */ b FROM a"
202+
203+
q = c.compile(t.where(this.b > 10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
204+
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM a WHERE (b > 10)")
205+
206+
q = c.compile(t.limit(10).select(this.b, optimizer_hints="PARALLEL(a 16)"))
207+
self.assertEqual(q, "SELECT /*+ PARALLEL(a 16) */ b FROM (SELECT * FROM a LIMIT 10) tmp1")
208+
209+
q = c.compile(t.select(this.a).group_by(this.b).agg(this.c).select(optimizer_hints="PARALLEL(a 16)"))
210+
self.assertEqual(
211+
q, "SELECT /*+ PARALLEL(a 16) */ * FROM (SELECT b, c FROM (SELECT a FROM a) tmp2 GROUP BY 1) tmp3"
212+
)
213+
192214
def test_table_ops(self):
193215
c = Compiler(MockDatabase())
194216
a = table("a").select(this.x)

0 commit comments

Comments
 (0)