Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Upcoming (TBD)
==============

Features
---------
* Silently accept forward slash to introduce special commands.


Internal
--------
* Add test coverage for `client_commands.py`.
Expand Down
12 changes: 8 additions & 4 deletions mycli/clibuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from prompt_toolkit.filters import Condition, Filter

from mycli.packages.special import iocommands
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.packages.special.main import (
CASE_INSENSITIVE_COMMANDS,
CASE_SENSITIVE_COMMANDS,
)


def cli_is_multiline(mycli) -> Filter:
Expand All @@ -26,12 +29,13 @@ def _multiline_exception(text: str) -> bool:
# Multi-statement favorite query is a special case. Because there will
# be a semicolon separating statements, we can't consider semicolon an
# EOL. Let's consider an empty line an EOL instead.
if first_word.startswith("\\fs"):
if first_word.startswith(("\\fs", "/fs")):
return orig.endswith("\n")

return (
# Special Command
first_word.startswith("\\")
or (first_word.startswith('/') and not first_word.startswith('/*'))
or text.endswith((
# Ended with the current delimiter (usually a semi-column)
iocommands.get_current_delimiter(),
Expand All @@ -44,10 +48,10 @@ def _multiline_exception(text: str) -> bool:
))
or
# non-backslashed special commands such as "exit" or "help" don't need semicolon
first_word in SPECIAL_COMMANDS
first_word in CASE_SENSITIVE_COMMANDS
or
# uppercase variants accepted
first_word.lower() in SPECIAL_COMMANDS
first_word.lower() in CASE_INSENSITIVE_COMMANDS
or
# just a plain enter without any text
(first_word == "")
Expand Down
3 changes: 2 additions & 1 deletion mycli/main_modes/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def complete_while_typing_filter() -> bool:
last_word = text[-MIN_COMPLETION_TRIGGER:]
if len(last_word) == text_len:
return text_len >= MIN_COMPLETION_TRIGGER
if text[:6].lower() in ['source', r'\.']:
# does \. make sense with text[:6] ?
if text[:6].lower() in ['source', r'\.', '/.']:
# Different word characters for paths; see comment below.
# In fact, it might be nice if paths had a different threshold.
return not bool(re.search(r'[\s!-,:-@\[-^\{\}-]', last_word))
Expand Down
25 changes: 16 additions & 9 deletions mycli/packages/completion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]
# but the statement won't have a first token
tok1 = statement.token_first()
# lenient because \. will parse as two tokens
if tok1 and tok1.value.startswith('\\'):
if tok1 and tok1.value.startswith(('\\', '/')) and not tok1.value.startswith('/*'):
return suggest_special(text_before_cursor)
elif tok1:
if tok1.value.lower() in SPECIAL_COMMANDS:
Expand All @@ -771,42 +771,49 @@ def suggest_special(text: str) -> list[dict[str, Any]]:
# Trying to complete the special command itself
return [{"type": "special"}]

if cmd in ("\\u", "\\r"):
if cmd in ("\\u", "/u", "\\r", "/r"):
return [{"type": "database"}]

if cmd.lower() in ('use', 'connect'):
if cmd.lower() in ('use', '/use', 'connect', '/connect'):
return [{'type': 'database'}]

if cmd in (r'\T', r'\Tr'):
if cmd in (r'\T', '/T', r'\Tr', '/Tr'):
return [{"type": "table_format"}]

if cmd.lower() in ('tableformat', 'redirectformat'):
if cmd.lower() in ('tableformat', '/tableformat', 'redirectformat', '/redirectformat'):
return [{"type": "table_format"}]

if cmd in ["\\f", "\\fs", "\\fd"]:
if cmd in ["\\f", "/f", "\\fs", "/fs", "\\fd", "/fd"]:
return [{"type": "favoritequery"}]

if cmd in ["\\dt", "\\dt+"]:
if cmd in ["\\dt", "/dt", "\\dt+", "/dt+"]:
return [
{"type": "table", "schema": []},
{"type": "view", "schema": []},
{"type": "schema"},
]
elif cmd.lower() in [
r'\.',
r'/.',
'source',
'/source',
r'\o',
'/o',
r'\once',
r'tee',
'/once',
'tee',
'/tee',
]:
return [{"type": "file_name"}]
# todo: why is \edit case-sensitive?
elif cmd in [
r'\e',
'/e',
r'\edit',
'/edit',
]:
return [{"type": "file_name"}]
if cmd in ["\\llm", "\\ai"]:
if cmd in ["\\llm", "/llm", "\\ai", "/ai"]:
return [{"type": "llm"}]

return [{"type": "keyword"}, {"type": "special"}]
Expand Down
18 changes: 10 additions & 8 deletions mycli/packages/special/iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,20 +205,19 @@ def editor_command(command: str) -> bool:
:param command: string
"""
# special case: allow help on the \edit command
if re.match(r'^([Hh][Ee][Ll][Pp])\s+(\\e|\\edit)\s*(;|\\G|\\g)?\s*$', command):
if re.match(r'^/?([Hh][Ee][Ll][Pp])\s+(\\e|\\edit|/e|/edit)\s*(;|\\G|\\g)?\s*$', command):
return False
# It is possible to have `\e filename` or `SELECT * FROM \e`. So we check
# for both conditions.
return (
command.strip().endswith("\\e")
or command.strip().startswith("\\e ")
or command.strip().endswith("\\edit")
or command.strip().startswith("\\edit ")
command.strip().endswith(("\\e", "\\edit"))
or command.strip().startswith(("\\e ", "/e ", "\\edit ", "/edit "))
or command.strip() in (("\\e", "/e", "\\edit", "/edit"))
)


def get_filename(sql: str) -> str | None:
if sql.strip().startswith("\\e ") or sql.strip().startswith("\\edit "):
if sql.strip().startswith(("\\e ", "/e ")) or sql.strip().startswith(("\\edit ", "/edit ")):
command, _, filename = sql.partition(" ")
return filename.strip() or None
else:
Expand All @@ -229,6 +228,9 @@ def get_editor_query(sql: str) -> str:
"""Get the query part of an editor command."""
sql = sql.strip()

if sql in ('\\e', '/e', '\\edit', '/edit'):
return ''

# The reason we can't simply do .strip('\e') is that it strips characters,
# not a substring. So it'll strip "e" in the end of the sql also!
# Ex: "select * from style\e" -> "select * from styl".
Expand Down Expand Up @@ -281,7 +283,7 @@ def clip_command(command: str) -> bool:
"""
# It is possible to have `\clip` or `SELECT * FROM \clip`. So we check
# for both conditions.
return command.strip().endswith("\\clip") or command.strip().startswith("\\clip")
return command.strip().endswith("\\clip") or command.strip().startswith(("\\clip", "/clip"))


def get_clip_query(sql: str) -> str:
Expand All @@ -290,7 +292,7 @@ def get_clip_query(sql: str) -> str:

# The reason we can't simply do .strip('\clip') is that it strips characters,
# not a substring. So it'll strip "c" in the end of the sql also!
pattern = re.compile(r"(^\\clip|\\clip$)")
pattern = re.compile(r"(^\\clip|^/clip|\\clip$)")
while pattern.search(sql):
sql = pattern.sub("", sql)

Expand Down
4 changes: 2 additions & 2 deletions mycli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def handle_llm(
_, command_verbosity, arg = parse_special_command(text)
if not LLM_IMPORTED:
raise FinishIteration(results=[SQLResult(preamble=NEED_DEPENDENCIES)])
if arg.strip().lower() in ['', 'help', '?', r'\?']:
if arg.strip().lower() in ['', 'help', '/help', '?', r'\?', '/?']:
raise FinishIteration(results=[SQLResult(preamble=USAGE)])
parts = shlex.split(arg)
restart = False
Expand Down Expand Up @@ -286,7 +286,7 @@ def handle_llm(

def is_llm_command(command: str) -> bool:
cmd, _, _ = parse_special_command(command)
return cmd in ("\\llm", "\\ai")
return cmd in ("\\llm", "/llm", "\\ai", "/ai")


def truncate_list_elements(row: list, prompt_field_truncate: int, prompt_section_truncate: int) -> list:
Expand Down
36 changes: 35 additions & 1 deletion mycli/packages/special/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ def register_special_command(
case_sensitive: bool = False,
aliases: list[SpecialCommandAlias] | None = None,
) -> None:
if command.startswith('\\'):
forwardslash_command = '/' + command.removeprefix('\\')
else:
forwardslash_command = '/' + command
cmd = command.lower() if not case_sensitive else command
fcmd = forwardslash_command.lower() if not case_sensitive else forwardslash_command
COMMANDS[cmd] = SpecialCommand(
handler,
command,
Expand All @@ -117,17 +122,36 @@ def register_special_command(
case_sensitive=case_sensitive,
aliases=aliases,
)
COMMANDS[fcmd] = SpecialCommand(
handler,
command,
usage,
description,
arg_type=arg_type,
hidden=True,
case_sensitive=case_sensitive,
aliases=aliases,
)
if case_sensitive:
CASE_SENSITIVE_COMMANDS.add(command)
CASE_SENSITIVE_COMMANDS.add(forwardslash_command)
else:
CASE_INSENSITIVE_COMMANDS.add(command.lower())
CASE_INSENSITIVE_COMMANDS.add(forwardslash_command.lower())
aliases = [] if aliases is None else aliases
for alias in aliases:
if alias.command.startswith('\\'):
forwardslash_alias_command = '/' + alias.command.removeprefix('\\')
else:
forwardslash_alias_command = '/' + alias.command
cmd = alias.command.lower() if not alias.case_sensitive else alias.command
fcmd = forwardslash_alias_command.lower() if not alias.case_sensitive else forwardslash_alias_command
if alias.case_sensitive:
CASE_SENSITIVE_COMMANDS.add(alias.command)
CASE_SENSITIVE_COMMANDS.add(forwardslash_alias_command)
else:
CASE_INSENSITIVE_COMMANDS.add(alias.command.lower())
CASE_INSENSITIVE_COMMANDS.add(forwardslash_alias_command.lower())
COMMANDS[cmd] = SpecialCommand(
handler,
command,
Expand All @@ -138,6 +162,16 @@ def register_special_command(
hidden=True,
aliases=None,
)
COMMANDS[fcmd] = SpecialCommand(
handler,
command,
usage,
description,
arg_type=arg_type,
case_sensitive=alias.case_sensitive,
hidden=True,
aliases=None,
)


def execute(cur: Cursor, sql: str) -> list[SQLResult]:
Expand All @@ -158,7 +192,7 @@ def execute(cur: Cursor, sql: str) -> list[SQLResult]:

# "help <SQL KEYWORD> is a special case. We want built-in help, not
# mycli help here.
if command.lower() == "help" and arg:
if command.lower().startswith(("help", "/help", "\\?", "/?", "?")) and arg:
return show_keyword_help(cur=cur, arg=arg)

if special_cmd.arg_type == ArgType.NO_QUERY:
Expand Down
21 changes: 17 additions & 4 deletions mycli/packages/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,20 @@ def need_completion_refresh(queries: str) -> bool:
for query in sqlparse.split(queries):
try:
first_token = query.split()[0]
if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"):
if first_token.lower() in (
"alter",
"create",
"use",
"/use",
"\\r",
"\\u",
"/r",
"/u",
"connect",
"/connect",
"drop",
"rename",
):
return True
except Exception:
continue
Expand All @@ -447,9 +460,9 @@ def need_completion_reset(queries: str) -> bool:
try:
tokens = query.split()
first_token = tokens[0]
if first_token.lower() in ("use", "\\u"):
if first_token.lower() in ("use", "/use", "\\u", "/u"):
return True
if first_token.lower() in ("\\r", "connect") and len(tokens) > 1:
if first_token.lower() in ("\\r", "/r", "connect", "/connect") and len(tokens) > 1:
return True
except Exception:
continue
Expand Down Expand Up @@ -502,7 +515,7 @@ def classify_sandbox_statement(text: str) -> tuple[str | None, str | None]:
return ('quit', None)

# \q
if len(tokens) == 2 and types[0] == tt.BACKSLASH and texts[1] == 'Q':
if len(tokens) == 2 and types[0] in (tt.BACKSLASH, tt.SLASH) and texts[1] in ('Q', 'QUIT', 'EXIT'):
return ('quit', None)

# ALTER USER ...
Expand Down
2 changes: 1 addition & 1 deletion mycli/sqlexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def run(self, statement: str) -> Generator[SQLResult, None, None]:
# Split the sql into separate queries and run each one.
# Unless it's saving a favorite query, in which case we
# want to save them all together.
if statement.startswith("\\fs"):
if statement.startswith(("\\fs", "/fs")):
components: Iterable[str] = [statement]
else:
components = iocommands.split_queries(statement)
Expand Down
9 changes: 6 additions & 3 deletions test/pytests/test_clibuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def test_multiline_exception_detects_commands_terminators_and_plain_sql(
expected: bool,
) -> None:
monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: '//')
monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()})
monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'})
monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help', 'exit'})

assert clibuffer._multiline_exception(text) is expected

Expand All @@ -85,7 +86,8 @@ def test_multiline_exception_recognizes_non_backslashed_special_commands_with_ge
text: str,
) -> None:
monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';')
monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()})
monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'})
monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help', 'exit'})

assert clibuffer._multiline_exception(text) is True

Expand All @@ -107,7 +109,8 @@ def test_cli_is_multiline_uses_buffer_text_when_multiline_mode_is_enabled(

monkeypatch.setattr(clibuffer, 'get_app', lambda: app)
monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';')
monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object()})
monkeypatch.setattr(clibuffer, 'CASE_SENSITIVE_COMMANDS', {'Camel'})
monkeypatch.setattr(clibuffer, 'CASE_INSENSITIVE_COMMANDS', {'help'})

multiline_filter = clibuffer.cli_is_multiline(mycli)

Expand Down
2 changes: 2 additions & 0 deletions test/pytests/test_special_iocommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_editor_command(monkeypatch):
assert mycli.packages.special.editor_command(r"hello\edit")
assert mycli.packages.special.editor_command(r"\e hello")
assert mycli.packages.special.editor_command(r"\edit hello")
assert mycli.packages.special.editor_command('/edit')

assert not mycli.packages.special.editor_command(r"HELP \e")
assert not mycli.packages.special.editor_command(r"help \edit\g")
Expand All @@ -182,6 +183,7 @@ def test_editor_command(monkeypatch):
assert not mycli.packages.special.editor_command(r"\edithello")

assert mycli.packages.special.get_filename(r"\e filename") == "filename"
assert mycli.packages.special.get_editor_query('/edit') == ''

if os.name != "nt":
assert mycli.packages.special.open_external_editor(sql=r"select 1") == ('select 1', None)
Expand Down
4 changes: 3 additions & 1 deletion test/pytests/test_special_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_register_special_command_tracks_case_insensitive_commands(restore_comma
)

assert special_main.CASE_SENSITIVE_COMMANDS == set()
assert special_main.CASE_INSENSITIVE_COMMANDS == {'demo', '\\d'}
assert special_main.CASE_INSENSITIVE_COMMANDS == {'demo', '/demo', '\\d', '/d'}


def test_special_command_decorator_registers_case_sensitive_command(restore_commands: None) -> None:
Expand All @@ -134,8 +134,10 @@ def handler() -> None:

assert special_main.COMMANDS['Camel'].handler is handler
assert 'Camel' in special_main.CASE_SENSITIVE_COMMANDS
assert '/Camel' in special_main.CASE_SENSITIVE_COMMANDS
assert special_main.CASE_INSENSITIVE_COMMANDS == set()
assert 'camel' not in special_main.COMMANDS
assert '/camel' not in special_main.COMMANDS


def test_execute_raises_when_command_is_missing() -> None:
Expand Down
Loading