From 6a226c22bcd20de49f74437cf459304c544e871d Mon Sep 17 00:00:00 2001 From: blackax Date: Tue, 3 Mar 2026 20:55:27 -0600 Subject: [PATCH] fix(lint): apply ruff formatting and remove unused imports Fix CI pipeline failures caused by ruff format and lint check violations. --- src/ssh_mcp/config.py | 26 +++++++++----------------- src/ssh_mcp/formatting.py | 12 +++++++----- src/ssh_mcp/ssh.py | 8 ++++---- tests/conftest.py | 1 - tests/test_config.py | 18 +++++------------- tests/test_formatting.py | 9 ++------- tests/test_ssh.py | 27 +++++++++++++++++---------- 7 files changed, 44 insertions(+), 57 deletions(-) diff --git a/src/ssh_mcp/config.py b/src/ssh_mcp/config.py index a1baaf6..3a86cd7 100644 --- a/src/ssh_mcp/config.py +++ b/src/ssh_mcp/config.py @@ -63,9 +63,7 @@ def get_server(self, name: str) -> ServerConfig: """ if name not in self._servers: available = ", ".join(sorted(self._servers.keys())) - raise KeyError( - f"Server '{name}' not found. Available servers: {available}" - ) + raise KeyError(f"Server '{name}' not found. Available servers: {available}") return self._servers[name] def get_group(self, name: str) -> GroupConfig: @@ -82,9 +80,7 @@ def get_group(self, name: str) -> GroupConfig: """ if name not in self._groups: available = ", ".join(sorted(self._groups.keys())) - raise KeyError( - f"Group '{name}' not found. Available groups: {available}" - ) + raise KeyError(f"Group '{name}' not found. Available groups: {available}") return self._groups[name] def servers_in_group(self, group_name: str) -> list[ServerConfig]: @@ -102,9 +98,7 @@ def servers_in_group(self, group_name: str) -> list[ServerConfig]: # Validate group exists first self.get_group(group_name) return [ - server - for server in self._servers.values() - if group_name in server.groups + server for server in self._servers.values() if group_name in server.groups ] def all_servers(self) -> list[ServerConfig]: @@ -199,29 +193,27 @@ def _validate(self) -> None: for server_name, server in self._servers.items(): # Every server must reference at least one group if not server.groups: - logger.warning( - "Server '%s' has no groups assigned", server_name - ) + logger.warning("Server '%s' has no groups assigned", server_name) # Every group referenced by a server must be defined for group in server.groups: if group not in group_names: logger.warning( "Server '%s' references undefined group '%s'", - server_name, group, + server_name, + group, ) # Server names must not collide with group names if server_name in group_names: - logger.warning( - "Server name '%s' collides with group name", server_name - ) + logger.warning("Server name '%s' collides with group name", server_name) # jump_host value must reference another defined server if server.jump_host and server.jump_host not in server_names: logger.warning( "Server '%s' references undefined jump_host '%s'", - server_name, server.jump_host, + server_name, + server.jump_host, ) # Detect circular jump host chains diff --git a/src/ssh_mcp/formatting.py b/src/ssh_mcp/formatting.py index 485a032..e4bc85b 100644 --- a/src/ssh_mcp/formatting.py +++ b/src/ssh_mcp/formatting.py @@ -60,9 +60,7 @@ def format_server_table(servers: list[ServerConfig], filter_label: str = "") -> return "\n".join(lines) -def format_group_table( - groups: list[GroupConfig], server_counts: dict[str, int] -) -> str: +def format_group_table(groups: list[GroupConfig], server_counts: dict[str, int]) -> str: """Format a list of groups into a text table. Args: @@ -96,7 +94,9 @@ def format_group_table( # Build rows for group in groups: count = server_counts.get(group.name, 0) - lines.append(f"{group.name:<{max_name}} {count:<{max_count}} {group.description}") + lines.append( + f"{group.name:<{max_name}} {count:<{max_count}} {group.description}" + ) return "\n".join(lines) @@ -175,7 +175,9 @@ def format_group_results(results: list[ExecResult], group_name: str) -> str: Summary: 2 succeeded, 0 failed """ if not results: - return f"Executing on group '{group_name}' (0 servers)...\n\nNo servers in group." + return ( + f"Executing on group '{group_name}' (0 servers)...\n\nNo servers in group." + ) lines = [ f"Executing on group '{group_name}' ({len(results)} servers)...", diff --git a/src/ssh_mcp/ssh.py b/src/ssh_mcp/ssh.py index 11693d1..20ba0ee 100644 --- a/src/ssh_mcp/ssh.py +++ b/src/ssh_mcp/ssh.py @@ -320,7 +320,9 @@ async def execute_with_semaphore(server: ServerConfig) -> ExecResult: for future in asyncio.as_completed(actual_tasks): result = await future results.append(result) - if result.error or (result.exit_code is not None and result.exit_code != 0): + if result.error or ( + result.exit_code is not None and result.exit_code != 0 + ): for task in actual_tasks: if not task.done(): task.cancel() @@ -376,9 +378,7 @@ async def execute_with_semaphore(server: ServerConfig) -> ExecResult: ) ] - async def upload( - self, server_name: str, local_path: str, remote_path: str - ) -> str: + async def upload(self, server_name: str, local_path: str, remote_path: str) -> str: """Upload file to remote server via SFTP. Args: diff --git a/tests/conftest.py b/tests/conftest.py index 608a807..9c7b869 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ from __future__ import annotations -import tempfile from pathlib import Path import pytest diff --git a/tests/test_config.py b/tests/test_config.py index 3e4f9b0..5a750b8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -118,9 +118,7 @@ def test_get_group_unknown_raises_keyerror(self, tmp_config_file: Path) -> None: assert "not found" in error_msg.lower() assert "available groups" in error_msg.lower() - def test_servers_in_group_returns_correct_list( - self, tmp_config_file: Path - ) -> None: + def test_servers_in_group_returns_correct_list(self, tmp_config_file: Path) -> None: """Test servers_in_group returns all servers in a group.""" registry = ServerRegistry(str(tmp_config_file)) @@ -131,9 +129,7 @@ def test_servers_in_group_returns_correct_list( for server in servers: assert "test-prod" in server.groups - def test_servers_in_group_unknown_group_raises( - self, tmp_config_file: Path - ) -> None: + def test_servers_in_group_unknown_group_raises(self, tmp_config_file: Path) -> None: """Test servers_in_group raises KeyError for unknown group.""" registry = ServerRegistry(str(tmp_config_file)) @@ -169,7 +165,7 @@ def test_validation_warnings_logged( ) -> None: """Test validation warnings are logged for invalid config.""" with caplog.at_level(logging.WARNING): - registry = ServerRegistry(str(invalid_config_file)) + ServerRegistry(str(invalid_config_file)) # Check that warnings were logged assert len(caplog.records) > 0 @@ -229,9 +225,7 @@ def test_server_with_overrides_loaded_correctly( class TestServerRegistrySettings: """Tests for ServerRegistry settings loading.""" - def test_settings_property_returns_settings( - self, tmp_config_file: Path - ) -> None: + def test_settings_property_returns_settings(self, tmp_config_file: Path) -> None: """Test settings property returns Settings instance.""" registry = ServerRegistry(str(tmp_config_file)) @@ -252,9 +246,7 @@ def test_settings_defaults_applied_when_not_overridden( assert settings.max_output_bytes == 51200 assert settings.connection_idle_timeout == 300 - def test_settings_ssh_config_path_expanded( - self, tmp_config_file: Path - ) -> None: + def test_settings_ssh_config_path_expanded(self, tmp_config_file: Path) -> None: """Test tilde expansion in ssh_config_path.""" registry = ServerRegistry(str(tmp_config_file)) diff --git a/tests/test_formatting.py b/tests/test_formatting.py index 324915f..8bead0f 100644 --- a/tests/test_formatting.py +++ b/tests/test_formatting.py @@ -6,7 +6,6 @@ from __future__ import annotations -import pytest from ssh_mcp.formatting import ( format_exec_result, @@ -141,9 +140,7 @@ def test_format_group_table_zero_servers(self) -> None: class TestFormatExecResult: """Tests for format_exec_result function.""" - def test_format_exec_result_success( - self, sample_exec_result: ExecResult - ) -> None: + def test_format_exec_result_success(self, sample_exec_result: ExecResult) -> None: """Test formatting successful execution result.""" output = format_exec_result(sample_exec_result) @@ -158,9 +155,7 @@ def test_format_exec_result_success( assert "Exit code: 0" in output assert "150ms" in output - def test_format_exec_result_with_error( - self, sample_exec_error: ExecResult - ) -> None: + def test_format_exec_result_with_error(self, sample_exec_error: ExecResult) -> None: """Test formatting execution result with error.""" output = format_exec_result(sample_exec_error) diff --git a/tests/test_ssh.py b/tests/test_ssh.py index 2182801..307afa0 100644 --- a/tests/test_ssh.py +++ b/tests/test_ssh.py @@ -6,14 +6,18 @@ from __future__ import annotations -import asyncio -from unittest.mock import MagicMock import pytest from ssh_mcp.config import ServerRegistry from ssh_mcp.models import Settings -from ssh_mcp.ssh import SSHManager, _DANGEROUS_PATTERNS, _SENSITIVE_PATHS, _is_dangerous_command, _validate_remote_path +from ssh_mcp.ssh import ( + SSHManager, + _DANGEROUS_PATTERNS, + _SENSITIVE_PATHS, + _is_dangerous_command, + _validate_remote_path, +) # --------------------------------------------------------------------------- @@ -160,11 +164,15 @@ def test_rm_rf_slash_pattern(self) -> None: def test_mkfs_pattern(self) -> None: assert _DANGEROUS_PATTERNS[1].search("mkfs.ext4 /dev/sda") is not None assert _DANGEROUS_PATTERNS[1].search("mkfs") is not None - assert _DANGEROUS_PATTERNS[1].search("ls mkfs_backup") is not None # substring match + assert ( + _DANGEROUS_PATTERNS[1].search("ls mkfs_backup") is not None + ) # substring match def test_dd_if_pattern(self) -> None: assert _DANGEROUS_PATTERNS[2].search("dd if=/dev/zero of=/dev/sda") is not None - assert _DANGEROUS_PATTERNS[2].search("dd if=input.bin of=output.bin") is not None + assert ( + _DANGEROUS_PATTERNS[2].search("dd if=input.bin of=output.bin") is not None + ) assert _DANGEROUS_PATTERNS[2].search("dd bs=512 count=1") is None def test_redirect_dev_sd_pattern(self) -> None: @@ -185,7 +193,9 @@ def test_chmod_777_slash_pattern(self) -> None: def test_fork_bomb_pattern(self) -> None: # Pattern uses \s* (zero-or-more), so spaces are optional assert _DANGEROUS_PATTERNS[5].search(":(){ :|:& };:") is not None - assert _DANGEROUS_PATTERNS[5].search(":(){ :|:&};:") is not None # no trailing space also matches + assert ( + _DANGEROUS_PATTERNS[5].search(":(){ :|:&};:") is not None + ) # no trailing space also matches # Pattern requires the exact structure; unrelated strings do not match assert _DANGEROUS_PATTERNS[5].search("echo hello") is None assert _DANGEROUS_PATTERNS[5].search("ls -la") is None @@ -290,7 +300,6 @@ class TestSSHManagerInit: def _make_registry(self) -> ServerRegistry: """Return a minimal ServerRegistry backed by a real config file.""" import tempfile - from pathlib import Path config_content = """ [settings] @@ -303,9 +312,7 @@ def _make_registry(self) -> ServerRegistry: description = "Test server" groups = ["test"] """ - tmp = tempfile.NamedTemporaryFile( - suffix=".toml", mode="w", delete=False - ) + tmp = tempfile.NamedTemporaryFile(suffix=".toml", mode="w", delete=False) tmp.write(config_content) tmp.flush() tmp.close()