Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 9 additions & 17 deletions src/ssh_mcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def get_server(self, name: str) -> ServerConfig:
"""
if name not in self._servers:
available = ", ".join(sorted(self._servers.keys()))
raise KeyError(
f"Server '{name}' not found. Available servers: {available}"
)
raise KeyError(f"Server '{name}' not found. Available servers: {available}")
return self._servers[name]

def get_group(self, name: str) -> GroupConfig:
Expand All @@ -82,9 +80,7 @@ def get_group(self, name: str) -> GroupConfig:
"""
if name not in self._groups:
available = ", ".join(sorted(self._groups.keys()))
raise KeyError(
f"Group '{name}' not found. Available groups: {available}"
)
raise KeyError(f"Group '{name}' not found. Available groups: {available}")
return self._groups[name]

def servers_in_group(self, group_name: str) -> list[ServerConfig]:
Expand All @@ -102,9 +98,7 @@ def servers_in_group(self, group_name: str) -> list[ServerConfig]:
# Validate group exists first
self.get_group(group_name)
return [
server
for server in self._servers.values()
if group_name in server.groups
server for server in self._servers.values() if group_name in server.groups
]

def all_servers(self) -> list[ServerConfig]:
Expand Down Expand Up @@ -199,29 +193,27 @@ def _validate(self) -> None:
for server_name, server in self._servers.items():
# Every server must reference at least one group
if not server.groups:
logger.warning(
"Server '%s' has no groups assigned", server_name
)
logger.warning("Server '%s' has no groups assigned", server_name)

# Every group referenced by a server must be defined
for group in server.groups:
if group not in group_names:
logger.warning(
"Server '%s' references undefined group '%s'",
server_name, group,
server_name,
group,
)

# Server names must not collide with group names
if server_name in group_names:
logger.warning(
"Server name '%s' collides with group name", server_name
)
logger.warning("Server name '%s' collides with group name", server_name)

# jump_host value must reference another defined server
if server.jump_host and server.jump_host not in server_names:
logger.warning(
"Server '%s' references undefined jump_host '%s'",
server_name, server.jump_host,
server_name,
server.jump_host,
)

# Detect circular jump host chains
Expand Down
12 changes: 7 additions & 5 deletions src/ssh_mcp/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def format_server_table(servers: list[ServerConfig], filter_label: str = "") ->
return "\n".join(lines)


def format_group_table(
groups: list[GroupConfig], server_counts: dict[str, int]
) -> str:
def format_group_table(groups: list[GroupConfig], server_counts: dict[str, int]) -> str:
"""Format a list of groups into a text table.

Args:
Expand Down Expand Up @@ -96,7 +94,9 @@ def format_group_table(
# Build rows
for group in groups:
count = server_counts.get(group.name, 0)
lines.append(f"{group.name:<{max_name}} {count:<{max_count}} {group.description}")
lines.append(
f"{group.name:<{max_name}} {count:<{max_count}} {group.description}"
)

return "\n".join(lines)

Expand Down Expand Up @@ -175,7 +175,9 @@ def format_group_results(results: list[ExecResult], group_name: str) -> str:
Summary: 2 succeeded, 0 failed
"""
if not results:
return f"Executing on group '{group_name}' (0 servers)...\n\nNo servers in group."
return (
f"Executing on group '{group_name}' (0 servers)...\n\nNo servers in group."
)

lines = [
f"Executing on group '{group_name}' ({len(results)} servers)...",
Expand Down
8 changes: 4 additions & 4 deletions src/ssh_mcp/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ async def execute_with_semaphore(server: ServerConfig) -> ExecResult:
for future in asyncio.as_completed(actual_tasks):
result = await future
results.append(result)
if result.error or (result.exit_code is not None and result.exit_code != 0):
if result.error or (
result.exit_code is not None and result.exit_code != 0
):
for task in actual_tasks:
if not task.done():
task.cancel()
Expand Down Expand Up @@ -376,9 +378,7 @@ async def execute_with_semaphore(server: ServerConfig) -> ExecResult:
)
]

async def upload(
self, server_name: str, local_path: str, remote_path: str
) -> str:
async def upload(self, server_name: str, local_path: str, remote_path: str) -> str:
"""Upload file to remote server via SFTP.

Args:
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from __future__ import annotations

import tempfile
from pathlib import Path

import pytest
Expand Down
18 changes: 5 additions & 13 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ def test_get_group_unknown_raises_keyerror(self, tmp_config_file: Path) -> None:
assert "not found" in error_msg.lower()
assert "available groups" in error_msg.lower()

def test_servers_in_group_returns_correct_list(
self, tmp_config_file: Path
) -> None:
def test_servers_in_group_returns_correct_list(self, tmp_config_file: Path) -> None:
"""Test servers_in_group returns all servers in a group."""
registry = ServerRegistry(str(tmp_config_file))

Expand All @@ -131,9 +129,7 @@ def test_servers_in_group_returns_correct_list(
for server in servers:
assert "test-prod" in server.groups

def test_servers_in_group_unknown_group_raises(
self, tmp_config_file: Path
) -> None:
def test_servers_in_group_unknown_group_raises(self, tmp_config_file: Path) -> None:
"""Test servers_in_group raises KeyError for unknown group."""
registry = ServerRegistry(str(tmp_config_file))

Expand Down Expand Up @@ -169,7 +165,7 @@ def test_validation_warnings_logged(
) -> None:
"""Test validation warnings are logged for invalid config."""
with caplog.at_level(logging.WARNING):
registry = ServerRegistry(str(invalid_config_file))
ServerRegistry(str(invalid_config_file))

# Check that warnings were logged
assert len(caplog.records) > 0
Expand Down Expand Up @@ -229,9 +225,7 @@ def test_server_with_overrides_loaded_correctly(
class TestServerRegistrySettings:
"""Tests for ServerRegistry settings loading."""

def test_settings_property_returns_settings(
self, tmp_config_file: Path
) -> None:
def test_settings_property_returns_settings(self, tmp_config_file: Path) -> None:
"""Test settings property returns Settings instance."""
registry = ServerRegistry(str(tmp_config_file))

Expand All @@ -252,9 +246,7 @@ def test_settings_defaults_applied_when_not_overridden(
assert settings.max_output_bytes == 51200
assert settings.connection_idle_timeout == 300

def test_settings_ssh_config_path_expanded(
self, tmp_config_file: Path
) -> None:
def test_settings_ssh_config_path_expanded(self, tmp_config_file: Path) -> None:
"""Test tilde expansion in ssh_config_path."""
registry = ServerRegistry(str(tmp_config_file))

Expand Down
9 changes: 2 additions & 7 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from __future__ import annotations

import pytest

from ssh_mcp.formatting import (
format_exec_result,
Expand Down Expand Up @@ -141,9 +140,7 @@ def test_format_group_table_zero_servers(self) -> None:
class TestFormatExecResult:
"""Tests for format_exec_result function."""

def test_format_exec_result_success(
self, sample_exec_result: ExecResult
) -> None:
def test_format_exec_result_success(self, sample_exec_result: ExecResult) -> None:
"""Test formatting successful execution result."""
output = format_exec_result(sample_exec_result)

Expand All @@ -158,9 +155,7 @@ def test_format_exec_result_success(
assert "Exit code: 0" in output
assert "150ms" in output

def test_format_exec_result_with_error(
self, sample_exec_error: ExecResult
) -> None:
def test_format_exec_result_with_error(self, sample_exec_error: ExecResult) -> None:
"""Test formatting execution result with error."""
output = format_exec_result(sample_exec_error)

Expand Down
27 changes: 17 additions & 10 deletions tests/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@

from __future__ import annotations

import asyncio
from unittest.mock import MagicMock

import pytest

from ssh_mcp.config import ServerRegistry
from ssh_mcp.models import Settings
from ssh_mcp.ssh import SSHManager, _DANGEROUS_PATTERNS, _SENSITIVE_PATHS, _is_dangerous_command, _validate_remote_path
from ssh_mcp.ssh import (
SSHManager,
_DANGEROUS_PATTERNS,
_SENSITIVE_PATHS,
_is_dangerous_command,
_validate_remote_path,
)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -160,11 +164,15 @@ def test_rm_rf_slash_pattern(self) -> None:
def test_mkfs_pattern(self) -> None:
assert _DANGEROUS_PATTERNS[1].search("mkfs.ext4 /dev/sda") is not None
assert _DANGEROUS_PATTERNS[1].search("mkfs") is not None
assert _DANGEROUS_PATTERNS[1].search("ls mkfs_backup") is not None # substring match
assert (
_DANGEROUS_PATTERNS[1].search("ls mkfs_backup") is not None
) # substring match

def test_dd_if_pattern(self) -> None:
assert _DANGEROUS_PATTERNS[2].search("dd if=/dev/zero of=/dev/sda") is not None
assert _DANGEROUS_PATTERNS[2].search("dd if=input.bin of=output.bin") is not None
assert (
_DANGEROUS_PATTERNS[2].search("dd if=input.bin of=output.bin") is not None
)
assert _DANGEROUS_PATTERNS[2].search("dd bs=512 count=1") is None

def test_redirect_dev_sd_pattern(self) -> None:
Expand All @@ -185,7 +193,9 @@ def test_chmod_777_slash_pattern(self) -> None:
def test_fork_bomb_pattern(self) -> None:
# Pattern uses \s* (zero-or-more), so spaces are optional
assert _DANGEROUS_PATTERNS[5].search(":(){ :|:& };:") is not None
assert _DANGEROUS_PATTERNS[5].search(":(){ :|:&};:") is not None # no trailing space also matches
assert (
_DANGEROUS_PATTERNS[5].search(":(){ :|:&};:") is not None
) # no trailing space also matches
# Pattern requires the exact structure; unrelated strings do not match
assert _DANGEROUS_PATTERNS[5].search("echo hello") is None
assert _DANGEROUS_PATTERNS[5].search("ls -la") is None
Expand Down Expand Up @@ -290,7 +300,6 @@ class TestSSHManagerInit:
def _make_registry(self) -> ServerRegistry:
"""Return a minimal ServerRegistry backed by a real config file."""
import tempfile
from pathlib import Path

config_content = """
[settings]
Expand All @@ -303,9 +312,7 @@ def _make_registry(self) -> ServerRegistry:
description = "Test server"
groups = ["test"]
"""
tmp = tempfile.NamedTemporaryFile(
suffix=".toml", mode="w", delete=False
)
tmp = tempfile.NamedTemporaryFile(suffix=".toml", mode="w", delete=False)
tmp.write(config_content)
tmp.flush()
tmp.close()
Expand Down