Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
from pathlib import Path

from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cli_common import apologize_and_exit, get_git_repo_or_none, parse_config_file_or_exit
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.extension import install_vscode_extension
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__ as version

Expand Down Expand Up @@ -163,10 +162,7 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:


def process_pyproject_config(args: Namespace) -> Namespace:
try:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
pyproject_config, pyproject_file_path = parse_config_file_or_exit(args.config_file)
supported_keys = [
"module_root",
"tests_root",
Expand Down Expand Up @@ -248,21 +244,21 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
no_pr = getattr(args, "no_pr", False)

if not no_pr:
import git

from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import require_github_app_or_exit

# Ensure that the user can actually open PRs on the repo.
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
maybe_git_repo = get_git_repo_or_none()
if maybe_git_repo is None:
mode = "--all" if hasattr(args, "all") else "--file"
logger.exception(
logger.error(
f"I couldn't find a git repository in the current directory. "
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
# After None check and apologize_and_exit(), we know git_repo is not None
git_repo = maybe_git_repo
assert git_repo is not None # For mypy
git_remote = getattr(args, "git_remote", None)
if not check_and_push_branch(git_repo, git_remote=git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
Expand Down
119 changes: 43 additions & 76 deletions codeflash/cli_cmds/cli_common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import shutil
import sys
from typing import Callable, cast

import click
import inquirer
from typing import TYPE_CHECKING, Any, Optional

from codeflash.cli_cmds.console import console, logger

if TYPE_CHECKING:
from pathlib import Path

from git import Repo


def apologize_and_exit() -> None:
console.rule()
Expand All @@ -20,78 +21,44 @@ def apologize_and_exit() -> None:
sys.exit(1)


def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwargs: str | bool) -> str | bool:
new_args = []
new_kwargs = {}

if len(args) == 1:
message = str(args[0])
else:
message = str(kwargs["message"])
new_kwargs = kwargs.copy()
split_messages = split_string_to_cli_width(message, is_confirm=func == inquirer.confirm)
for split_message in split_messages[:-1]:
click.echo(split_message)

last_message = split_messages[-1]

if len(args) == 1:
new_args.append(last_message)
else:
new_kwargs["message"] = last_message

return func(*new_args, **new_kwargs)


def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
cli_width, _ = shutil.get_terminal_size()
# split string to lines that accommodate "[?] " prefix
cli_width -= len("[?] ")
lines = split_string_to_fit_width(string, cli_width)
def get_git_repo_or_none(search_path: Optional[Path] = None) -> Optional[Repo]:
"""Get git repository or None if not in a git repo."""
import git

# split last line to additionally accommodate ": " or " (y/N): " suffix
cli_width -= len(" (y/N):") if is_confirm else len(": ")
last_lines = split_string_to_fit_width(lines[-1], cli_width)
try:
if search_path:
return git.Repo(search_path, search_parent_directories=True)
return git.Repo(search_parent_directories=True)
except git.InvalidGitRepositoryError:
return None

lines = lines[:-1] + last_lines

if len(lines) > 1:
for i in range(len(lines[:-1])):
# Add yellow color to question mark in "[?] " prefix
lines[i] = "[\033[33m?\033[0m] " + lines[i]
return lines


def inquirer_wrapper_path(*args: str, **kwargs: str) -> dict[str, str] | None:
new_args = []
message = kwargs["message"]
new_kwargs = kwargs.copy()
split_messages = split_string_to_cli_width(message)
for split_message in split_messages[:-1]:
click.echo(split_message)

last_message = split_messages[-1]
new_kwargs["message"] = last_message
new_args.append(args[0])

return cast("dict[str, str]", inquirer.prompt([inquirer.Path(*new_args, **new_kwargs)]))


def split_string_to_fit_width(string: str, width: int) -> list[str]:
words = string.split()
lines = []
current_line = [words[0]]
current_length = len(words[0])

for word in words[1:]:
word_length = len(word)
if current_length + word_length + 1 <= width:
current_line.append(word)
current_length += word_length + 1
def require_git_repo_or_exit(search_path: Optional[Path] = None, error_message: Optional[str] = None) -> Repo:
"""Get git repository or exit with error."""
repo = get_git_repo_or_none(search_path)
if repo is None:
if error_message:
logger.error(error_message)
else:
lines.append(" ".join(current_line))
current_line = [word]
current_length = word_length

lines.append(" ".join(current_line))
return lines
logger.error(
"I couldn't find a git repository in the current directory. "
"A git repository is required for this operation."
)
apologize_and_exit()
# After checking for None and calling apologize_and_exit(), we know repo is not None
# but mypy doesn't understand apologize_and_exit() never returns, so we assert
assert repo is not None
return repo


def parse_config_file_or_exit(config_file: Optional[Path] = None, **kwargs: Any) -> tuple[dict[str, Any], Path]: # noqa: ANN401
"""Parse config file or exit with error."""
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.config_parser import parse_config_file

try:
return parse_config_file(config_file, **kwargs)
except ValueError as e:
exit_with_message(f"Error parsing config file: {e}", error_on_exit=True)
# exit_with_message never returns when error_on_exit=True, but mypy doesn't know that
raise # pragma: no cover
Loading
Loading