diff --git a/src/auth/scopes.py b/src/auth/scopes.py index 06cb8aa..e26c515 100644 --- a/src/auth/scopes.py +++ b/src/auth/scopes.py @@ -44,6 +44,20 @@ class Scope(str, Enum): SECURITY_WRITE = "security:write" # Modify firewall rules (high risk) SECURITY_ADMIN = "security:admin" # Full security management + # Database scopes + DATABASE_READ = "database:read" # List backups + DATABASE_BACKUP = "database:backup" # Create backups + DATABASE_RESTORE = "database:restore" # Restore from backups + DATABASE_ADMIN = "database:admin" # Cleanup and management + + # Certificate scopes + CERTIFICATE_READ = "certificate:read" # View certificates + CERTIFICATE_ADMIN = "certificate:admin" # Request, renew, delete certificates + + # Metrics scopes + METRICS_READ = "metrics:read" # Export metrics + METRICS_ADMIN = "metrics:admin" # Write metrics to files + # Meta scopes ADMIN = "admin" # All permissions READ_ONLY = "readonly" # All read permissions @@ -294,6 +308,99 @@ class ToolScopeRequirement: risk_level="low", description="Get security scanner availability" ), + + # Database Tools + "backup_database": ToolScopeRequirement( + tool_name="backup_database", + required_scopes=[Scope.DATABASE_BACKUP], + risk_level="moderate", + description="Backup PostgreSQL or MySQL database" + ), + "restore_database": ToolScopeRequirement( + tool_name="restore_database", + required_scopes=[Scope.DATABASE_RESTORE], + risk_level="critical", + requires_approval=True, + description="Restore database from backup (destructive)" + ), + "list_database_backups": ToolScopeRequirement( + tool_name="list_database_backups", + required_scopes=[Scope.DATABASE_READ], + risk_level="low", + description="List available database backups" + ), + "cleanup_database_backups": ToolScopeRequirement( + tool_name="cleanup_database_backups", + required_scopes=[Scope.DATABASE_ADMIN], + risk_level="high", + description="Clean up old database backups" + ), + + # Certificate Tools + "check_ssl_certificate_status": 
ToolScopeRequirement( + tool_name="check_ssl_certificate_status", + required_scopes=[Scope.CERTIFICATE_READ], + risk_level="low", + description="Check SSL certificate status and expiration" + ), + "request_letsencrypt_certificate": ToolScopeRequirement( + tool_name="request_letsencrypt_certificate", + required_scopes=[Scope.CERTIFICATE_ADMIN], + risk_level="high", + requires_approval=True, + description="Obtain new Let's Encrypt certificate" + ), + "renew_letsencrypt_certificate": ToolScopeRequirement( + tool_name="renew_letsencrypt_certificate", + required_scopes=[Scope.CERTIFICATE_ADMIN], + risk_level="moderate", + description="Renew Let's Encrypt certificate" + ), + "list_letsencrypt_certificates": ToolScopeRequirement( + tool_name="list_letsencrypt_certificates", + required_scopes=[Scope.CERTIFICATE_READ], + risk_level="low", + description="List all Let's Encrypt certificates" + ), + "delete_letsencrypt_certificate": ToolScopeRequirement( + tool_name="delete_letsencrypt_certificate", + required_scopes=[Scope.CERTIFICATE_ADMIN], + risk_level="high", + requires_approval=True, + description="Delete Let's Encrypt certificate" + ), + "setup_certificate_auto_renewal": ToolScopeRequirement( + tool_name="setup_certificate_auto_renewal", + required_scopes=[Scope.CERTIFICATE_ADMIN], + risk_level="moderate", + description="Setup automatic certificate renewal" + ), + + # Metrics Tools + "export_prometheus_metrics": ToolScopeRequirement( + tool_name="export_prometheus_metrics", + required_scopes=[Scope.METRICS_READ], + risk_level="low", + description="Export Prometheus metrics" + ), + "get_prometheus_system_metrics": ToolScopeRequirement( + tool_name="get_prometheus_system_metrics", + required_scopes=[Scope.METRICS_READ], + risk_level="low", + description="Export system metrics in Prometheus format" + ), + "get_prometheus_docker_metrics": ToolScopeRequirement( + tool_name="get_prometheus_docker_metrics", + required_scopes=[Scope.METRICS_READ], + risk_level="low", + 
description="Export Docker metrics in Prometheus format" + ), + "write_metrics_textfile": ToolScopeRequirement( + tool_name="write_metrics_textfile", + required_scopes=[Scope.METRICS_ADMIN], + risk_level="moderate", + description="Write metrics to textfile for node_exporter" + ), } @@ -320,6 +427,9 @@ def expand_scopes(scopes: List[str]) -> Set[str]: Scope.CONTAINER_READ, Scope.FILE_READ, Scope.SECURITY_READ, + Scope.DATABASE_READ, + Scope.CERTIFICATE_READ, + Scope.METRICS_READ, ]) return expanded diff --git a/src/tools/__init__.py b/src/tools/__init__.py index 1ef9c89..ffb7778 100644 --- a/src/tools/__init__.py +++ b/src/tools/__init__.py @@ -19,6 +19,9 @@ def register_all_tools(mcp: FastMCP): image_tools, inventory_tools, security_tools, + database_tools, + certificate_tools, + metrics_tools, prompts, ) @@ -31,6 +34,9 @@ def register_all_tools(mcp: FastMCP): image_tools.register_tools(mcp) inventory_tools.register_tools(mcp) security_tools.register_tools(mcp) + database_tools.register_tools(mcp) + certificate_tools.register_tools(mcp) + metrics_tools.register_tools(mcp) prompts.register_prompts(mcp) logger.info("All MCP tools registered successfully") diff --git a/src/tools/certificate_tools.py b/src/tools/certificate_tools.py new file mode 100644 index 0000000..da2c8e6 --- /dev/null +++ b/src/tools/certificate_tools.py @@ -0,0 +1,815 @@ +"""Let's Encrypt certificate automation tools. + +Provides automated SSL certificate management using certbot/acme.sh +for homelab services. 
+""" + +from __future__ import annotations + +import asyncio +import json +import os +import logging +from typing import Dict, Any, Optional, List +from datetime import datetime, timedelta +from pathlib import Path +import subprocess +from asyncio import to_thread + +from fastmcp import FastMCP +from src.utils.audit import AuditLogger +from src.auth.middleware import secure_tool +from src.server.utils import format_error +from pydantic import BaseModel + + +audit = AuditLogger() +logger = logging.getLogger(__name__) + + +class CertificateInfo(BaseModel): + """Information about an SSL certificate.""" + domain: str + issuer: Optional[str] = None + valid_from: Optional[datetime] = None + valid_until: Optional[datetime] = None + days_remaining: Optional[int] = None + status: str = "unknown" # valid, expiring_soon, expired, not_found + cert_path: Optional[str] = None + key_path: Optional[str] = None + + +class CertificateRequest(BaseModel): + """Request to obtain a new certificate.""" + domain: str + additional_domains: Optional[List[str]] = [] + email: str + challenge_type: str = "http-01" # http-01, dns-01 + webroot_path: Optional[str] = None + dns_provider: Optional[str] = None + staging: bool = False # Use Let's Encrypt staging for testing + + +class CertificateResult(BaseModel): + """Result of certificate operation.""" + success: bool + domain: str + cert_path: Optional[str] = None + key_path: Optional[str] = None + fullchain_path: Optional[str] = None + error: Optional[str] = None + renewed: bool = False + + +async def check_certificate( + domain: str, + cert_path: Optional[str] = None +) -> Dict[str, Any]: + """Check certificate status and expiration. 
async def check_certificate(
    domain: str,
    cert_path: Optional[str] = None
) -> Dict[str, Any]:
    """Check certificate status and expiration.

    Looks for a local certbot-managed certificate first; if none is found,
    fetches the served certificate from ``domain:443`` via ``openssl s_client``
    and parses it.

    Args:
        domain: Domain name to check
        cert_path: Optional path to certificate file (auto-detect if not provided)

    Returns:
        CertificateInfo (as a dict) with certificate status and metadata
    """
    temp_cert_path = None  # temp file for remote checks; always cleaned up below
    try:
        # Auto-detect certbot path if not provided
        if not cert_path:
            for path in (
                f"/etc/letsencrypt/live/{domain}/cert.pem",
                f"/etc/letsencrypt/live/{domain}/fullchain.pem",
            ):
                if os.path.exists(path):
                    cert_path = path
                    break

        if not cert_path or not os.path.exists(cert_path):
            # No local certificate: fetch the one served by the remote host.
            cmd = [
                "openssl", "s_client",
                "-connect", f"{domain}:443",
                "-servername", domain,
                "-showcerts",
            ]
            result = await to_thread(
                subprocess.run,
                cmd,
                input="",
                capture_output=True,
                text=True,
                timeout=10,
            )

            cert_text = result.stdout
            if "BEGIN CERTIFICATE" not in cert_text:
                return CertificateInfo(domain=domain, status="not_found").dict()

            # Write only the first (leaf) certificate to a temp file so
            # `openssl x509` can parse it.
            import tempfile
            start = cert_text.find("-----BEGIN CERTIFICATE-----")
            end = cert_text.find("-----END CERTIFICATE-----") + len("-----END CERTIFICATE-----")
            with tempfile.NamedTemporaryFile(mode='w', suffix='.pem', delete=False) as f:
                f.write(cert_text[start:end])
                temp_cert_path = f.name
            cert_path = temp_cert_path

        # Parse certificate details
        cmd = [
            "openssl", "x509",
            "-in", cert_path,
            "-noout",
            "-dates",
            "-issuer",
            "-subject",
        ]
        result = await to_thread(
            subprocess.run,
            cmd,
            capture_output=True,
            text=True,
            check=True,
        )

        issuer = None
        valid_from = None
        valid_until = None
        for line in result.stdout.splitlines():
            if line.startswith("notBefore="):
                date_str = line.replace("notBefore=", "").strip()
                valid_from = datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
            elif line.startswith("notAfter="):
                date_str = line.replace("notAfter=", "").strip()
                valid_until = datetime.strptime(date_str, "%b %d %H:%M:%S %Y %Z")
            elif line.startswith("issuer="):
                issuer = line.replace("issuer=", "").strip()

        # Classify by remaining validity; expiry comparisons are naive
        # datetimes (openssl prints GMT — TODO confirm timezone handling).
        days_remaining = None
        status = "unknown"
        if valid_until:
            days_remaining = (valid_until - datetime.now()).days
            if days_remaining < 0:
                status = "expired"
            elif days_remaining < 30:
                status = "expiring_soon"
            else:
                status = "valid"

        # Infer the certbot private-key path from the cert layout.
        key_path = None
        if cert_path.endswith("/cert.pem"):
            key_path = cert_path.replace("/cert.pem", "/privkey.pem")
        elif cert_path.endswith("/fullchain.pem"):
            key_path = cert_path.replace("/fullchain.pem", "/privkey.pem")

        # Fix: do not report the soon-to-be-deleted temp file as cert_path.
        on_disk = temp_cert_path is None and os.path.exists(cert_path)
        cert_info = CertificateInfo(
            domain=domain,
            issuer=issuer,
            valid_from=valid_from,
            valid_until=valid_until,
            days_remaining=days_remaining,
            status=status,
            cert_path=cert_path if on_disk else None,
            key_path=key_path if key_path and os.path.exists(key_path) else None,
        )

        audit.log("check_certificate", {
            "domain": domain
        }, {
            "success": True,
            "status": status,
            "days_remaining": days_remaining
        })

        return cert_info.dict()

    except Exception as e:
        # Was silently mapped to status="error"; keep that contract but log.
        logger.warning("check_certificate failed for %s: %s", domain, e)
        audit.log("check_certificate", {
            "domain": domain
        }, {
            "success": False,
            "error": str(e)
        })
        return CertificateInfo(
            domain=domain,
            status="error"
        ).dict()
    finally:
        # Fix: the temp certificate written for remote checks was leaked.
        if temp_cert_path and os.path.exists(temp_cert_path):
            os.unlink(temp_cert_path)
async def obtain_certificate(
    domain: str,
    email: str,
    additional_domains: Optional[List[str]] = None,
    challenge_type: str = "http-01",
    webroot_path: Optional[str] = None,
    dns_provider: Optional[str] = None,
    staging: bool = False,
    dry_run: bool = False
) -> Dict[str, Any]:
    """Obtain a new Let's Encrypt certificate via certbot.

    Args:
        domain: Primary domain name
        email: Email for certificate notifications
        additional_domains: Additional domains for SAN certificate
        challenge_type: Challenge type (http-01 or dns-01)
        webroot_path: Path for HTTP challenge (required for http-01)
        dns_provider: DNS provider plugin suffix for DNS challenge
            (required for dns-01; interpolated into ``--dns-<provider>``)
        staging: Use Let's Encrypt staging environment
        dry_run: Test without actually obtaining certificate

    Returns:
        CertificateResult (as a dict) with certificate paths and status
    """
    try:
        # Argument-injection guard: these user-supplied values end up on the
        # certbot command line, so reject anything that parses as a flag.
        for value in [domain, email, *(additional_domains or [])]:
            if not value or value.startswith("-"):
                return CertificateResult(
                    success=False,
                    domain=domain,
                    error=f"Invalid argument value: {value!r}"
                ).dict()
        # dns_provider becomes part of the flag name itself; keep it to
        # alphanumerics and hyphens.
        if dns_provider and not dns_provider.replace("-", "").isalnum():
            return CertificateResult(
                success=False,
                domain=domain,
                error=f"Invalid dns_provider: {dns_provider!r}"
            ).dict()

        # Check if certbot is installed
        which_result = await to_thread(
            subprocess.run,
            ["which", "certbot"],
            capture_output=True,
            text=True
        )
        if which_result.returncode != 0:
            return CertificateResult(
                success=False,
                domain=domain,
                error="certbot not installed. Install with: apt-get install certbot"
            ).dict()

        # Build certbot command: primary domain plus any SAN domains.
        cmd = ["certbot", "certonly", "-d", domain]
        for extra in additional_domains or []:
            cmd.extend(["-d", extra])
        cmd.extend(["--email", email, "--agree-tos", "--non-interactive"])

        # Challenge configuration
        if challenge_type == "http-01":
            if not webroot_path:
                return CertificateResult(
                    success=False,
                    domain=domain,
                    error="webroot_path required for http-01 challenge"
                ).dict()
            cmd.extend(["--webroot", "-w", webroot_path])
        elif challenge_type == "dns-01":
            if not dns_provider:
                return CertificateResult(
                    success=False,
                    domain=domain,
                    error="dns_provider required for dns-01 challenge"
                ).dict()
            cmd.append(f"--dns-{dns_provider}")
        else:
            return CertificateResult(
                success=False,
                domain=domain,
                error=f"Invalid challenge_type: {challenge_type}"
            ).dict()

        if staging:
            cmd.append("--staging")
        if dry_run:
            cmd.append("--dry-run")

        # Execute certbot; CalledProcessError carries stderr for reporting.
        await to_thread(
            subprocess.run,
            cmd,
            capture_output=True,
            text=True,
            check=True
        )

        # Standard certbot live-directory layout for the primary domain.
        cert_path = f"/etc/letsencrypt/live/{domain}/cert.pem"
        key_path = f"/etc/letsencrypt/live/{domain}/privkey.pem"
        fullchain_path = f"/etc/letsencrypt/live/{domain}/fullchain.pem"

        cert_result = CertificateResult(
            success=True,
            domain=domain,
            cert_path=cert_path if not dry_run else None,
            key_path=key_path if not dry_run else None,
            fullchain_path=fullchain_path if not dry_run else None,
            renewed=False
        )

        audit.log("obtain_certificate", {
            "domain": domain,
            "challenge_type": challenge_type,
            "staging": staging,
            "dry_run": dry_run
        }, {
            "success": True,
            "cert_path": cert_path if not dry_run else None
        })

        return cert_result.dict()

    except subprocess.CalledProcessError as e:
        error_msg = f"Certificate request failed: {e.stderr}"
        audit.log("obtain_certificate", {
            "domain": domain
        }, {
            "success": False,
            "error": error_msg
        })
        return CertificateResult(
            success=False,
            domain=domain,
            error=error_msg
        ).dict()

    except Exception as e:
        error_msg = f"Certificate request failed: {str(e)}"
        audit.log("obtain_certificate", {
            "domain": domain
        }, {
            "success": False,
            "error": error_msg
        })
        return CertificateResult(
            success=False,
            domain=domain,
            error=error_msg
        ).dict()
+ + Args: + domain: Specific domain to renew (None for all) + force: Force renewal even if not expiring soon + dry_run: Test renewal without actually renewing + + Returns: + CertificateResult with renewal status + """ + try: + # Build certbot command + cmd = ["certbot", "renew"] + + if domain: + cmd.extend(["--cert-name", domain]) + + if force: + cmd.append("--force-renewal") + + if dry_run: + cmd.append("--dry-run") + + # Execute certbot + result = await to_thread( + subprocess.run, + cmd, + capture_output=True, + text=True, + check=True + ) + + # Parse output to determine if renewal occurred + renewed = "Renewing" in result.stdout or "renewed" in result.stdout + + cert_result = CertificateResult( + success=True, + domain=domain or "all", + renewed=renewed + ) + + audit.log("renew_certificate", { + "domain": domain, + "force": force, + "dry_run": dry_run + }, { + "success": True, + "renewed": renewed + }) + + return cert_result.dict() + + except subprocess.CalledProcessError as e: + error_msg = f"Certificate renewal failed: {e.stderr}" + audit.log("renew_certificate", { + "domain": domain + }, { + "success": False, + "error": error_msg + }) + return CertificateResult( + success=False, + domain=domain or "all", + error=error_msg + ).dict() + + except Exception as e: + error_msg = f"Certificate renewal failed: {str(e)}" + audit.log("renew_certificate", { + "domain": domain + }, { + "success": False, + "error": error_msg + }) + return CertificateResult( + success=False, + domain=domain or "all", + error=error_msg + ).dict() + + +async def list_certificates() -> List[Dict[str, Any]]: + """List all Let's Encrypt certificates managed by certbot. 
async def list_certificates() -> List[Dict[str, Any]]:
    """List all Let's Encrypt certificates managed by certbot.

    Parses the human-readable output of ``certbot certificates``.

    Returns:
        List of certificate dicts with name, domains, expiry and paths;
        an empty list on any failure.
    """
    certificates: List[Dict[str, Any]] = []

    try:
        result = await to_thread(
            subprocess.run,
            ["certbot", "certificates"],
            capture_output=True,
            text=True,
            check=True
        )

        # Expected output per certificate:
        #   Certificate Name: example.com
        #   Domains: example.com www.example.com
        #   Expiry Date: 2024-06-01 12:00:00+00:00
        #   Certificate Path: /etc/letsencrypt/live/example.com/fullchain.pem
        #   Private Key Path: /etc/letsencrypt/live/example.com/privkey.pem
        current_cert: Dict[str, Any] = {}
        for line in result.stdout.splitlines():
            line = line.strip()
            if line.startswith("Certificate Name:"):
                if current_cert:
                    certificates.append(current_cert)
                current_cert = {
                    "name": line.split(":", 1)[1].strip()
                }
            elif line.startswith("Domains:"):
                current_cert["domains"] = [d.strip() for d in line.split(":", 1)[1].split()]
            elif line.startswith("Expiry Date:"):
                expiry_str = line.split(":", 1)[1].strip()
                try:
                    # Drop the UTC offset suffix before parsing.
                    expiry = datetime.strptime(
                        expiry_str.split("+")[0].strip(), "%Y-%m-%d %H:%M:%S"
                    )
                    days_remaining = (expiry - datetime.now()).days
                    current_cert["expiry_date"] = expiry.isoformat()
                    current_cert["days_remaining"] = days_remaining
                    # Fix: a negative remaining count was previously labelled
                    # "expiring_soon"; report expired certificates explicitly.
                    if days_remaining < 0:
                        current_cert["status"] = "expired"
                    elif days_remaining > 30:
                        current_cert["status"] = "valid"
                    else:
                        current_cert["status"] = "expiring_soon"
                except ValueError as exc:
                    # Fix: was a bare "except: pass" — log instead of hiding it.
                    logger.debug("Unparseable expiry date %r: %s", expiry_str, exc)
            elif line.startswith("Certificate Path:"):
                current_cert["cert_path"] = line.split(":", 1)[1].strip()
            elif line.startswith("Private Key Path:"):
                current_cert["key_path"] = line.split(":", 1)[1].strip()

        # Flush the final certificate stanza.
        if current_cert:
            certificates.append(current_cert)

        audit.log("list_certificates", {}, {
            "success": True,
            "count": len(certificates)
        })

        return certificates

    except subprocess.CalledProcessError as e:
        audit.log("list_certificates", {}, {
            "success": False,
            "error": str(e)
        })
        return []

    except Exception as e:
        audit.log("list_certificates", {}, {
            "success": False,
            "error": str(e)
        })
        return []
"error": str(e) + }) + return [] + + +async def delete_certificate( + domain: str +) -> Dict[str, Any]: + """Delete a Let's Encrypt certificate. + + Args: + domain: Domain name of certificate to delete + + Returns: + Result with success status + """ + try: + # Run certbot delete command + result = await to_thread( + subprocess.run, + ["certbot", "delete", "--cert-name", domain], + capture_output=True, + text=True, + input="y\n", # Confirm deletion + check=True + ) + + audit.log("delete_certificate", { + "domain": domain + }, { + "success": True + }) + + return { + "success": True, + "domain": domain, + "message": f"Certificate for {domain} deleted successfully" + } + + except subprocess.CalledProcessError as e: + error_msg = f"Certificate deletion failed: {e.stderr}" + audit.log("delete_certificate", { + "domain": domain + }, { + "success": False, + "error": error_msg + }) + return { + "success": False, + "domain": domain, + "error": error_msg + } + + except Exception as e: + error_msg = f"Certificate deletion failed: {str(e)}" + audit.log("delete_certificate", { + "domain": domain + }, { + "success": False, + "error": error_msg + }) + return { + "success": False, + "domain": domain, + "error": error_msg + } + + +async def setup_auto_renewal() -> Dict[str, Any]: + """Setup automatic certificate renewal via cron/systemd timer. 
+ + Returns: + Result with setup status + """ + try: + # Check if systemd timer exists (preferred method) + timer_check = await to_thread( + subprocess.run, + ["systemctl", "list-timers", "certbot.timer"], + capture_output=True, + text=True + ) + + if "certbot.timer" in timer_check.stdout: + # Timer already exists, ensure it's enabled + await to_thread( + subprocess.run, + ["systemctl", "enable", "certbot.timer"], + capture_output=True, + text=True, + check=True + ) + await to_thread( + subprocess.run, + ["systemctl", "start", "certbot.timer"], + capture_output=True, + text=True, + check=True + ) + + audit.log("setup_auto_renewal", {}, { + "success": True, + "method": "systemd-timer" + }) + + return { + "success": True, + "method": "systemd-timer", + "message": "Automatic renewal enabled via systemd timer" + } + + # Fall back to cron + cron_entry = "0 0,12 * * * root certbot renew --quiet" + cron_file = "/etc/cron.d/certbot" + + # Check if cron file exists + if os.path.exists(cron_file): + audit.log("setup_auto_renewal", {}, { + "success": True, + "method": "cron", + "already_exists": True + }) + + return { + "success": True, + "method": "cron", + "message": "Automatic renewal already configured via cron" + } + + # Create cron file + with open(cron_file, 'w') as f: + f.write(cron_entry + "\n") + + audit.log("setup_auto_renewal", {}, { + "success": True, + "method": "cron" + }) + + return { + "success": True, + "method": "cron", + "message": "Automatic renewal configured via cron" + } + + except Exception as e: + error_msg = f"Auto-renewal setup failed: {str(e)}" + audit.log("setup_auto_renewal", {}, { + "success": False, + "error": error_msg + }) + return { + "success": False, + "error": error_msg + } + + +def register_tools(mcp: FastMCP): + """Register Let's Encrypt certificate tools with MCP instance.""" + + @mcp.tool() + @secure_tool("certificate:read") + async def check_ssl_certificate_status( + domain: str, + cert_path: str = None + ) -> dict: + """Check SSL 
certificate status and expiration. + + Args: + domain: Domain name to check + cert_path: Optional path to certificate file (auto-detect if not provided) + + Returns: + CertificateInfo with certificate status and metadata + """ + try: + result = await check_certificate(domain=domain, cert_path=cert_path) + return result + except Exception as e: + return format_error(e, "check_ssl_certificate_status") + + @mcp.tool() + @secure_tool("certificate:admin") + async def request_letsencrypt_certificate( + domain: str, + email: str, + additional_domains: list = None, + challenge_type: str = "http-01", + webroot_path: str = None, + dns_provider: str = None, + staging: bool = False, + dry_run: bool = False + ) -> dict: + """Obtain a new Let's Encrypt certificate. + + Args: + domain: Primary domain name + email: Email for certificate notifications + additional_domains: Additional domains for SAN certificate + challenge_type: Challenge type (http-01 or dns-01, default: http-01) + webroot_path: Path for HTTP challenge (required for http-01) + dns_provider: DNS provider for DNS challenge (required for dns-01) + staging: Use Let's Encrypt staging environment (default: False) + dry_run: Test without actually obtaining certificate (default: False) + + Returns: + CertificateResult with certificate paths and status + """ + try: + result = await obtain_certificate( + domain=domain, + email=email, + additional_domains=additional_domains, + challenge_type=challenge_type, + webroot_path=webroot_path, + dns_provider=dns_provider, + staging=staging, + dry_run=dry_run + ) + return result + except Exception as e: + return format_error(e, "request_letsencrypt_certificate") + + @mcp.tool() + @secure_tool("certificate:admin") + async def renew_letsencrypt_certificate( + domain: str = None, + force: bool = False, + dry_run: bool = False + ) -> dict: + """Renew Let's Encrypt certificate(s). 
+ + Args: + domain: Specific domain to renew (None for all) + force: Force renewal even if not expiring soon (default: False) + dry_run: Test renewal without actually renewing (default: False) + + Returns: + CertificateResult with renewal status + """ + try: + result = await renew_certificate(domain=domain, force=force, dry_run=dry_run) + return result + except Exception as e: + return format_error(e, "renew_letsencrypt_certificate") + + @mcp.tool() + @secure_tool("certificate:read") + async def list_letsencrypt_certificates() -> list: + """List all Let's Encrypt certificates managed by certbot. + + Returns: + List of certificates with their status + """ + try: + result = await list_certificates() + return result + except Exception as e: + return format_error(e, "list_letsencrypt_certificates") + + @mcp.tool() + @secure_tool("certificate:admin") + async def delete_letsencrypt_certificate(domain: str) -> dict: + """Delete a Let's Encrypt certificate. + + Args: + domain: Domain name of certificate to delete + + Returns: + Result with success status + """ + try: + result = await delete_certificate(domain=domain) + return result + except Exception as e: + return format_error(e, "delete_letsencrypt_certificate") + + @mcp.tool() + @secure_tool("certificate:admin") + async def setup_certificate_auto_renewal() -> dict: + """Setup automatic certificate renewal via cron/systemd timer. + + Returns: + Result with setup status + """ + try: + result = await setup_auto_renewal() + return result + except Exception as e: + return format_error(e, "setup_certificate_auto_renewal") + + logger.info("Registered 6 certificate tools") diff --git a/src/tools/database_tools.py b/src/tools/database_tools.py new file mode 100644 index 0000000..a2ad36e --- /dev/null +++ b/src/tools/database_tools.py @@ -0,0 +1,931 @@ +"""Database backup and restore tools for PostgreSQL and MySQL. + +Provides automated backup, restore, and scheduling functionality for +homelab database instances. 
+""" + +from __future__ import annotations + +import asyncio +import json +import os +import logging +from typing import Dict, Any, Optional, List +from datetime import datetime +from pathlib import Path +import subprocess +from asyncio import to_thread + +from fastmcp import FastMCP +from src.utils.audit import AuditLogger +from src.auth.middleware import secure_tool +from src.server.utils import format_error +from pydantic import BaseModel + + +audit = AuditLogger() +logger = logging.getLogger(__name__) + + +class BackupConfig(BaseModel): + """Configuration for database backups.""" + database_type: str # postgresql, mysql + host: str = "localhost" + port: Optional[int] = None + database: Optional[str] = None # None means all databases + username: str + password: Optional[str] = None + backup_path: str = "/var/backups/databases" + retention_days: int = 7 + compress: bool = True + + +class BackupResult(BaseModel): + """Result of a backup operation.""" + success: bool + backup_file: Optional[str] = None + size_bytes: Optional[int] = None + duration_seconds: Optional[float] = None + error: Optional[str] = None + timestamp: datetime = None + + +class RestoreResult(BaseModel): + """Result of a restore operation.""" + success: bool + restored_database: Optional[str] = None + error: Optional[str] = None + duration_seconds: Optional[float] = None + + +async def backup_postgresql( + host: str = "localhost", + port: int = 5432, + database: Optional[str] = None, + username: str = "postgres", + password: Optional[str] = None, + backup_path: str = "/var/backups/databases", + compress: bool = True +) -> Dict[str, Any]: + """Backup PostgreSQL database(s). 
def _pg_stream_dump(cmd, env, backup_file, compress):
    """Run a pg_dump/pg_dumpall command, streaming stdout to ``backup_file``.

    Compression is done in-process with gzip, so no shell pipeline is needed.
    The previous "cmd | gzip > file" shell string reported only gzip's exit
    status (masking dump failures) and interpolated the unquoted path into a
    shell command. Raises CalledProcessError when the dump tool exits non-zero.
    """
    import gzip

    proc = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        opener = gzip.open if compress else open
        with opener(backup_file, "wb") as out:
            for chunk in iter(lambda: proc.stdout.read(64 * 1024), b""):
                out.write(chunk)
    finally:
        stderr = proc.stderr.read().decode(errors="replace")
        proc.wait()
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd, stderr=stderr)


async def backup_postgresql(
    host: str = "localhost",
    port: int = 5432,
    database: Optional[str] = None,
    username: str = "postgres",
    password: Optional[str] = None,
    backup_path: str = "/var/backups/databases",
    compress: bool = True
) -> Dict[str, Any]:
    """Backup PostgreSQL database(s).

    Args:
        host: Database host
        port: Database port
        database: Specific database to backup (None for all)
        username: Database username
        password: Database password (passed via PGPASSWORD env var)
        backup_path: Path to store backups
        compress: Whether to gzip-compress the backup

    Returns:
        BackupResult (as a dict) with backup file path and metadata
    """
    start_time = datetime.now()
    timestamp = start_time.strftime("%Y%m%d_%H%M%S")

    # Ensure backup directory exists
    Path(backup_path).mkdir(parents=True, exist_ok=True)

    # Build backup filename
    db_name = database or "all_databases"
    backup_file = f"{backup_path}/postgresql_{db_name}_{timestamp}.sql"
    if compress:
        backup_file += ".gz"

    try:
        env = os.environ.copy()
        if password:
            env["PGPASSWORD"] = password

        # Fix: the old code always requested pg_dump's custom format
        # ('-F', 'c' regardless of `compress`) while naming the file
        # .sql/.sql.gz — compressed single-DB backups were custom-format
        # archives falsely named .gz. Emit plain SQL to stdout for both
        # tools and compress uniformly in _pg_stream_dump.
        if database:
            cmd = ["pg_dump", "-h", host, "-p", str(port), "-U", username, database]
        else:
            cmd = ["pg_dumpall", "-h", host, "-p", str(port), "-U", username]

        await to_thread(_pg_stream_dump, cmd, env, backup_file, compress)

        # Get backup file size
        size_bytes = os.path.getsize(backup_file)
        duration = (datetime.now() - start_time).total_seconds()

        backup_result = BackupResult(
            success=True,
            backup_file=backup_file,
            size_bytes=size_bytes,
            duration_seconds=duration,
            timestamp=start_time
        )

        audit.log("backup_postgresql", {
            "host": host,
            "database": database,
            "backup_file": backup_file
        }, {
            "success": True,
            "size_bytes": size_bytes,
            "duration_seconds": duration
        })

        return backup_result.dict()

    except subprocess.CalledProcessError as e:
        error_msg = f"Backup failed: {e.stderr}"
        backup_result = BackupResult(
            success=False,
            error=error_msg,
            timestamp=start_time
        )
        audit.log("backup_postgresql", {
            "host": host,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return backup_result.dict()

    except Exception as e:
        error_msg = f"Backup failed: {str(e)}"
        backup_result = BackupResult(
            success=False,
            error=error_msg,
            timestamp=start_time
        )
        audit.log("backup_postgresql", {
            "host": host,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return backup_result.dict()
def _mysql_stream_dump(cmd, env, backup_file, compress):
    """Run mysqldump, streaming its stdout to ``backup_file``.

    Compression happens in-process via gzip, so no shell pipeline is needed.
    The previous "mysqldump ... | gzip > file" shell string reported gzip's
    exit status (masking dump failures) and interpolated the unquoted path —
    and the password — into a shell command. Raises CalledProcessError when
    mysqldump exits non-zero.
    """
    import gzip

    proc = subprocess.Popen(cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        opener = gzip.open if compress else open
        with opener(backup_file, "wb") as out:
            for chunk in iter(lambda: proc.stdout.read(64 * 1024), b""):
                out.write(chunk)
    finally:
        stderr = proc.stderr.read().decode(errors="replace")
        proc.wait()
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd, stderr=stderr)


async def backup_mysql(
    host: str = "localhost",
    port: int = 3306,
    database: Optional[str] = None,
    username: str = "root",
    password: Optional[str] = None,
    backup_path: str = "/var/backups/databases",
    compress: bool = True
) -> Dict[str, Any]:
    """Backup MySQL/MariaDB database(s).

    Args:
        host: Database host
        port: Database port
        database: Specific database to backup (None for all)
        username: Database username
        password: Database password (passed via MYSQL_PWD, never the cmdline)
        backup_path: Path to store backups
        compress: Whether to gzip-compress the backup

    Returns:
        BackupResult (as a dict) with backup file path and metadata
    """
    start_time = datetime.now()
    timestamp = start_time.strftime("%Y%m%d_%H%M%S")

    # Ensure backup directory exists
    Path(backup_path).mkdir(parents=True, exist_ok=True)

    # Build backup filename
    db_name = database or "all_databases"
    backup_file = f"{backup_path}/mysql_{db_name}_{timestamp}.sql"
    if compress:
        backup_file += ".gz"

    try:
        env = os.environ.copy()
        if password:
            # Fix: "-p{password}" exposed the password in `ps` output and was
            # interpolated into a shell=True string (injection risk). MYSQL_PWD
            # keeps it off the command line entirely.
            env["MYSQL_PWD"] = password

        cmd = [
            "mysqldump",
            "-h", host,
            "-P", str(port),
            "-u", username,
            "--single-transaction",
            "--quick",
            "--lock-tables=false",
        ]
        if database:
            cmd.append(database)
        else:
            cmd.append("--all-databases")

        await to_thread(_mysql_stream_dump, cmd, env, backup_file, compress)

        # Get backup file size
        size_bytes = os.path.getsize(backup_file)
        duration = (datetime.now() - start_time).total_seconds()

        backup_result = BackupResult(
            success=True,
            backup_file=backup_file,
            size_bytes=size_bytes,
            duration_seconds=duration,
            timestamp=start_time
        )

        audit.log("backup_mysql", {
            "host": host,
            "database": database,
            "backup_file": backup_file
        }, {
            "success": True,
            "size_bytes": size_bytes,
            "duration_seconds": duration
        })

        return backup_result.dict()

    except subprocess.CalledProcessError as e:
        error_msg = f"Backup failed: {e.stderr}"
        backup_result = BackupResult(
            success=False,
            error=error_msg,
            timestamp=start_time
        )
        audit.log("backup_mysql", {
            "host": host,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return backup_result.dict()

    except Exception as e:
        error_msg = f"Backup failed: {str(e)}"
        backup_result = BackupResult(
            success=False,
            error=error_msg,
            timestamp=start_time
        )
        audit.log("backup_mysql", {
            "host": host,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return backup_result.dict()
async def restore_postgresql(
    backup_file: str,
    host: str = "localhost",
    port: int = 5432,
    database: Optional[str] = None,
    username: str = "postgres",
    password: Optional[str] = None,
    drop_existing: bool = False
) -> Dict[str, Any]:
    """Restore PostgreSQL database from backup.

    Args:
        backup_file: Path to backup file (.dump custom format, .sql, or .sql.gz)
        host: Database host
        port: Database port
        database: Target database name (for single DB restores)
        username: Database username
        password: Database password
        drop_existing: Whether to drop existing database objects before restore
            (custom-format restores only; passes ``pg_restore -c``)

    Returns:
        RestoreResult with success status and metadata
    """
    import shlex  # local import: quote user-supplied values in shell pipelines

    start_time = datetime.now()

    if not os.path.exists(backup_file):
        return RestoreResult(
            success=False,
            error=f"Backup file not found: {backup_file}"
        ).dict()

    try:
        env = os.environ.copy()
        if password:
            # PGPASSWORD keeps the password off the command line / ps output.
            env["PGPASSWORD"] = password

        # Classify the backup. Check .gz FIRST: previously any ".gz" file that
        # was not exactly ".sql.gz" fell into the custom-format branch and was
        # handed (still compressed) to pg_restore, which cannot read gzip.
        is_gzip = backup_file.endswith(".gz")
        is_custom_format = not is_gzip and (
            backup_file.endswith(".dump") or not backup_file.endswith(".sql")
        )

        if is_custom_format and database:
            # Custom-format archive -> pg_restore into a specific database.
            # Execute it here explicitly (the previous flow could build this
            # command without ever running it and still report success).
            cmd = [
                "pg_restore",
                "-h", host,
                "-p", str(port),
                "-U", username,
                "-d", database,
            ]
            if drop_existing:
                cmd.append("-c")  # Clean (drop) database objects before recreating
            cmd.append(backup_file)

            await to_thread(
                subprocess.run,
                cmd,
                env=env,
                capture_output=True,
                text=True,
                check=True
            )
        elif is_gzip:
            # Gzipped SQL dump: stream through gunzip into psql. Every
            # interpolated value is shell-quoted to prevent injection, since
            # this command runs with shell=True.
            psql_cmd = (
                f"psql -h {shlex.quote(host)} -p {shlex.quote(str(port))} "
                f"-U {shlex.quote(username)}"
            )
            if database:
                psql_cmd += f" -d {shlex.quote(database)}"
            cmd_str = f"gunzip -c {shlex.quote(backup_file)} | {psql_cmd}"

            await to_thread(
                subprocess.run,
                cmd_str,
                shell=True,
                env=env,
                capture_output=True,
                text=True,
                check=True
            )
        else:
            # Plain SQL dump -> psql -f (argv list, no shell involved).
            cmd = [
                "psql",
                "-h", host,
                "-p", str(port),
                "-U", username,
            ]
            if database:
                cmd.extend(["-d", database])
            cmd.extend(["-f", backup_file])

            await to_thread(
                subprocess.run,
                cmd,
                env=env,
                capture_output=True,
                text=True,
                check=True
            )

        duration = (datetime.now() - start_time).total_seconds()

        restore_result = RestoreResult(
            success=True,
            restored_database=database or "all",
            duration_seconds=duration
        )

        audit.log("restore_postgresql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": True,
            "duration_seconds": duration
        })

        return restore_result.dict()

    except subprocess.CalledProcessError as e:
        error_msg = f"Restore failed: {e.stderr}"
        restore_result = RestoreResult(
            success=False,
            error=error_msg
        )
        audit.log("restore_postgresql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return restore_result.dict()

    except Exception as e:
        error_msg = f"Restore failed: {str(e)}"
        restore_result = RestoreResult(
            success=False,
            error=error_msg
        )
        audit.log("restore_postgresql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return restore_result.dict()
async def restore_mysql(
    backup_file: str,
    host: str = "localhost",
    port: int = 3306,
    database: Optional[str] = None,
    username: str = "root",
    password: Optional[str] = None
) -> Dict[str, Any]:
    """Restore MySQL/MariaDB database from backup.

    Args:
        backup_file: Path to backup file (.sql or .sql.gz)
        host: Database host
        port: Database port
        database: Target database name
        username: Database username
        password: Database password

    Returns:
        RestoreResult with success status and metadata
    """
    import shlex  # local import: quote user-supplied values in shell pipelines

    start_time = datetime.now()

    if not os.path.exists(backup_file):
        return RestoreResult(
            success=False,
            error=f"Backup file not found: {backup_file}"
        ).dict()

    try:
        env = os.environ.copy()
        if password:
            # MYSQL_PWD keeps the password out of `ps` output and out of the
            # shell command string (the old `-p<password>` leaked both ways
            # and was an injection vector under shell=True).
            env["MYSQL_PWD"] = password

        mysql_cmd = (
            f"mysql -h {shlex.quote(host)} -P {shlex.quote(str(port))} "
            f"-u {shlex.quote(username)}"
        )
        if database:
            mysql_cmd += f" {shlex.quote(database)}"

        if backup_file.endswith(".gz"):
            # Plain pipe instead of the bash-only process substitution
            # `< <(gunzip -c ...)`, so the default /bin/sh works too.
            cmd_str = f"gunzip -c {shlex.quote(backup_file)} | {mysql_cmd}"
        else:
            cmd_str = f"{mysql_cmd} < {shlex.quote(backup_file)}"

        await to_thread(
            subprocess.run,
            cmd_str,
            shell=True,
            env=env,
            capture_output=True,
            text=True,
            check=True
        )

        duration = (datetime.now() - start_time).total_seconds()

        restore_result = RestoreResult(
            success=True,
            restored_database=database or "all",
            duration_seconds=duration
        )

        audit.log("restore_mysql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": True,
            "duration_seconds": duration
        })

        return restore_result.dict()

    except subprocess.CalledProcessError as e:
        error_msg = f"Restore failed: {e.stderr}"
        restore_result = RestoreResult(
            success=False,
            error=error_msg
        )
        audit.log("restore_mysql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return restore_result.dict()

    except Exception as e:
        error_msg = f"Restore failed: {str(e)}"
        restore_result = RestoreResult(
            success=False,
            error=error_msg
        )
        audit.log("restore_mysql", {
            "backup_file": backup_file,
            "database": database
        }, {
            "success": False,
            "error": error_msg
        })
        return restore_result.dict()
async def list_backups(
    backup_path: str = "/var/backups/databases",
    database_type: Optional[str] = None
) -> List[Dict[str, Any]]:
    """List available database backups.

    Args:
        backup_path: Path to backup directory
        database_type: Filter by database type (postgresql, mysql)

    Returns:
        List of backup files with metadata, newest first
    """
    backups = []

    try:
        if not os.path.exists(backup_path):
            return backups

        for file in os.listdir(backup_path):
            file_path = os.path.join(backup_path, file)

            # Skip subdirectories and other non-regular entries; previously a
            # directory inside the backup path was reported as a backup.
            if not os.path.isfile(file_path):
                continue

            # Filter by database type if specified. Filenames are prefixed by
            # the backup helpers (e.g. "mysql_<db>_<timestamp>.sql").
            if database_type:
                if database_type == "postgresql" and not file.startswith("postgresql_"):
                    continue
                if database_type == "mysql" and not file.startswith("mysql_"):
                    continue

            # Collect file metadata
            stat = os.stat(file_path)
            backups.append({
                "filename": file,
                "path": file_path,
                "size_bytes": stat.st_size,
                "created_at": datetime.fromtimestamp(stat.st_ctime).isoformat(),
                "modified_at": datetime.fromtimestamp(stat.st_mtime).isoformat()
            })

        # Sort by creation time (newest first)
        backups.sort(key=lambda x: x["created_at"], reverse=True)

        audit.log("list_backups", {
            "backup_path": backup_path,
            "database_type": database_type
        }, {
            "success": True,
            "count": len(backups)
        })

        return backups

    except Exception as e:
        # Best-effort: return whatever was collected before the failure.
        audit.log("list_backups", {
            "backup_path": backup_path,
            "database_type": database_type
        }, {
            "success": False,
            "error": str(e)
        })
        return backups
async def cleanup_old_backups(
    backup_path: str = "/var/backups/databases",
    retention_days: int = 7,
    database_type: Optional[str] = None,
    dry_run: bool = True
) -> Dict[str, Any]:
    """Clean up old database backups based on retention policy.

    Args:
        backup_path: Path to backup directory
        retention_days: Number of days to retain backups
        database_type: Filter by database type (postgresql, mysql)
        dry_run: If True, only list files that would be deleted

    Returns:
        Dict with cleanup results (deleted files, bytes freed, errors)
    """
    from datetime import timedelta

    deleted_files = []
    errors = []
    total_space_freed = 0

    try:
        if not os.path.exists(backup_path):
            return {
                "success": True,
                "dry_run": dry_run,
                "deleted_count": 0,
                "space_freed_bytes": 0,
                "deleted_files": []
            }

        cutoff_time = datetime.now() - timedelta(days=retention_days)

        for file in os.listdir(backup_path):
            file_path = os.path.join(backup_path, file)

            # Never touch subdirectories or other non-regular entries;
            # os.remove() on a directory would previously surface as an error.
            if not os.path.isfile(file_path):
                continue

            # Filter by database type if specified
            if database_type:
                if database_type == "postgresql" and not file.startswith("postgresql_"):
                    continue
                if database_type == "mysql" and not file.startswith("mysql_"):
                    continue

            # Age check uses mtime, matching what list_backups reports.
            stat = os.stat(file_path)
            file_time = datetime.fromtimestamp(stat.st_mtime)

            if file_time < cutoff_time:
                if not dry_run:
                    try:
                        os.remove(file_path)
                        total_space_freed += stat.st_size
                        deleted_files.append(file)
                    except Exception as e:
                        errors.append(f"Failed to delete {file}: {str(e)}")
                else:
                    # Dry run still tallies what WOULD be freed.
                    total_space_freed += stat.st_size
                    deleted_files.append(file)

        result = {
            "success": len(errors) == 0,
            "dry_run": dry_run,
            "deleted_count": len(deleted_files),
            "space_freed_bytes": total_space_freed,
            "deleted_files": deleted_files,
            "errors": errors
        }

        audit.log("cleanup_old_backups", {
            "backup_path": backup_path,
            "retention_days": retention_days,
            "database_type": database_type,
            "dry_run": dry_run
        }, result)

        return result

    except Exception as e:
        audit.log("cleanup_old_backups", {
            "backup_path": backup_path,
            "retention_days": retention_days
        }, {
            "success": False,
            "error": str(e)
        })
        return {
            "success": False,
            "error": str(e),
            "deleted_files": deleted_files,
            "errors": errors
        }
def register_tools(mcp: FastMCP):
    """Register database backup and restore tools with MCP instance.

    Each tool is registered via @mcp.tool() and access-controlled by
    @secure_tool with the matching "database:*" scope; the thin wrappers
    dispatch on database_type and fill in engine-specific defaults
    (port 5432/user "postgres" for PostgreSQL, port 3306/user "root" for
    MySQL) before delegating to the module-level async helpers.
    """

    @mcp.tool()
    @secure_tool("database:backup")
    async def backup_database(
        database_type: str,
        host: str = "localhost",
        port: Optional[int] = None,
        database: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        backup_path: str = "/var/backups/databases",
        compress: bool = True
    ) -> dict:
        """Backup PostgreSQL or MySQL database.

        Args:
            database_type: Database type (postgresql or mysql)
            host: Database host (default: localhost)
            port: Database port (default: 5432 for PostgreSQL, 3306 for MySQL)
            database: Specific database to backup (None for all databases)
            username: Database username (default: postgres/root)
            password: Database password (optional, will use environment variable if not provided)
            backup_path: Path to store backups (default: /var/backups/databases)
            compress: Whether to compress the backup (default: True)

        Returns:
            BackupResult with backup file path and metadata
        """
        try:
            if database_type == "postgresql":
                result = await backup_postgresql(
                    host=host,
                    port=port or 5432,
                    database=database,
                    username=username or "postgres",
                    password=password,
                    backup_path=backup_path,
                    compress=compress
                )
            elif database_type == "mysql":
                result = await backup_mysql(
                    host=host,
                    port=port or 3306,
                    database=database,
                    username=username or "root",
                    password=password,
                    backup_path=backup_path,
                    compress=compress
                )
            else:
                return {"error": f"Invalid database_type: {database_type}. Must be 'postgresql' or 'mysql'"}

            return result
        except Exception as e:
            return format_error(e, "backup_database")

    @mcp.tool()
    @secure_tool("database:restore")
    async def restore_database(
        database_type: str,
        backup_file: str,
        host: str = "localhost",
        port: Optional[int] = None,
        database: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        drop_existing: bool = False
    ) -> dict:
        """Restore PostgreSQL or MySQL database from backup.

        Args:
            database_type: Database type (postgresql or mysql)
            backup_file: Path to backup file
            host: Database host (default: localhost)
            port: Database port (default: 5432 for PostgreSQL, 3306 for MySQL)
            database: Target database name (required for single database restores)
            username: Database username (default: postgres/root)
            password: Database password (optional)
            drop_existing: Whether to drop existing database before restore (PostgreSQL only)

        Returns:
            RestoreResult with success status and metadata
        """
        try:
            if database_type == "postgresql":
                result = await restore_postgresql(
                    backup_file=backup_file,
                    host=host,
                    port=port or 5432,
                    database=database,
                    username=username or "postgres",
                    password=password,
                    drop_existing=drop_existing
                )
            elif database_type == "mysql":
                # NOTE(review): drop_existing is intentionally not forwarded —
                # the MySQL restore path has no equivalent of pg_restore -c.
                result = await restore_mysql(
                    backup_file=backup_file,
                    host=host,
                    port=port or 3306,
                    database=database,
                    username=username or "root",
                    password=password
                )
            else:
                return {"error": f"Invalid database_type: {database_type}. Must be 'postgresql' or 'mysql'"}

            return result
        except Exception as e:
            return format_error(e, "restore_database")

    @mcp.tool()
    @secure_tool("database:read")
    async def list_database_backups(
        backup_path: str = "/var/backups/databases",
        database_type: Optional[str] = None
    ) -> list:
        """List available database backups.

        Args:
            backup_path: Path to backup directory (default: /var/backups/databases)
            database_type: Filter by database type (postgresql or mysql, optional)

        Returns:
            List of backup files with metadata
        """
        try:
            result = await list_backups(
                backup_path=backup_path,
                database_type=database_type
            )
            return result
        except Exception as e:
            # NOTE(review): format_error presumably returns a dict, which
            # breaks the declared `-> list` on the error path — confirm
            # callers tolerate this.
            return format_error(e, "list_database_backups")

    @mcp.tool()
    @secure_tool("database:admin")
    async def cleanup_database_backups(
        backup_path: str = "/var/backups/databases",
        retention_days: int = 7,
        database_type: Optional[str] = None,
        dry_run: bool = True
    ) -> dict:
        """Clean up old database backups based on retention policy.

        Args:
            backup_path: Path to backup directory (default: /var/backups/databases)
            retention_days: Number of days to retain backups (default: 7)
            database_type: Filter by database type (postgresql or mysql, optional)
            dry_run: If True, only list files that would be deleted (default: True)

        Returns:
            Dict with cleanup results including deleted files and space freed
        """
        try:
            result = await cleanup_old_backups(
                backup_path=backup_path,
                retention_days=retention_days,
                database_type=database_type,
                dry_run=dry_run
            )
            return result
        except Exception as e:
            return format_error(e, "cleanup_database_backups")

    logger.info("Registered 4 database tools")
+""" + +from __future__ import annotations + +import asyncio +import json +import os +import logging +import psutil +from typing import Dict, Any, Optional, List +from datetime import datetime +import subprocess +from asyncio import to_thread + +from fastmcp import FastMCP +from src.utils.audit import AuditLogger +from src.auth.middleware import secure_tool +from src.server.utils import format_error + + +audit = AuditLogger() +logger = logging.getLogger(__name__) + + +class PrometheusMetric: + """Helper class to format Prometheus metrics.""" + + @staticmethod + def gauge(name: str, value: float, labels: Optional[Dict[str, str]] = None, help_text: Optional[str] = None) -> str: + """Format a gauge metric.""" + lines = [] + if help_text: + lines.append(f"# HELP {name} {help_text}") + lines.append(f"# TYPE {name} gauge") + + if labels: + label_str = ",".join([f'{k}="{v}"' for k, v in labels.items()]) + lines.append(f"{name}{{{label_str}}} {value}") + else: + lines.append(f"{name} {value}") + + return "\n".join(lines) + + @staticmethod + def counter(name: str, value: float, labels: Optional[Dict[str, str]] = None, help_text: Optional[str] = None) -> str: + """Format a counter metric.""" + lines = [] + if help_text: + lines.append(f"# HELP {name} {help_text}") + lines.append(f"# TYPE {name} counter") + + if labels: + label_str = ",".join([f'{k}="{v}"' for k, v in labels.items()]) + lines.append(f"{name}{{{label_str}}} {value}") + else: + lines.append(f"{name} {value}") + + return "\n".join(lines) + + @staticmethod + def histogram(name: str, buckets: Dict[float, int], sum_value: float, count: int, labels: Optional[Dict[str, str]] = None, help_text: Optional[str] = None) -> str: + """Format a histogram metric.""" + lines = [] + if help_text: + lines.append(f"# HELP {name} {help_text}") + lines.append(f"# TYPE {name} histogram") + + label_base = "" + if labels: + label_base = ",".join([f'{k}="{v}"' for k, v in labels.items()]) + + for le, count_val in 
async def get_system_metrics() -> str:
    """Export system metrics in Prometheus format.

    The psutil CPU sampling calls with interval=1 sleep for the sampling
    window, so they are executed in a worker thread via to_thread rather
    than blocking the event loop (previously two back-to-back 1-second
    blocking calls).

    Returns:
        Prometheus-formatted metrics text
    """
    metrics = []

    try:
        # CPU metrics — run the 1-second sampling off the event loop.
        cpu_percent = await to_thread(psutil.cpu_percent, 1)
        cpu_count = psutil.cpu_count()
        cpu_freq = psutil.cpu_freq()

        metrics.append(PrometheusMetric.gauge(
            "node_cpu_usage_percent",
            cpu_percent,
            help_text="CPU usage percentage"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_cpu_count",
            cpu_count,
            help_text="Number of CPU cores"
        ))
        if cpu_freq:
            metrics.append(PrometheusMetric.gauge(
                "node_cpu_frequency_mhz",
                cpu_freq.current,
                help_text="Current CPU frequency in MHz"
            ))

        # Per-CPU metrics (second 1-second sample, also off-loop).
        # NOTE(review): this re-emits HELP/TYPE for node_cpu_usage_percent,
        # which was already emitted unlabeled above — strict Prometheus
        # parsers may reject duplicate TYPE lines; confirm with the scraper.
        cpu_percents = await to_thread(psutil.cpu_percent, 1, True)
        for i, percent in enumerate(cpu_percents):
            metrics.append(PrometheusMetric.gauge(
                "node_cpu_usage_percent",
                percent,
                labels={"cpu": str(i)},
                help_text="Per-CPU usage percentage"
            ))

        # Memory metrics
        mem = psutil.virtual_memory()
        metrics.append(PrometheusMetric.gauge(
            "node_memory_total_bytes",
            mem.total,
            help_text="Total physical memory"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_memory_available_bytes",
            mem.available,
            help_text="Available memory"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_memory_used_bytes",
            mem.used,
            help_text="Used memory"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_memory_usage_percent",
            mem.percent,
            help_text="Memory usage percentage"
        ))

        # Swap metrics
        swap = psutil.swap_memory()
        metrics.append(PrometheusMetric.gauge(
            "node_swap_total_bytes",
            swap.total,
            help_text="Total swap memory"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_swap_used_bytes",
            swap.used,
            help_text="Used swap memory"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_swap_usage_percent",
            swap.percent,
            help_text="Swap usage percentage"
        ))

        # Disk metrics — per mounted partition.
        for partition in psutil.disk_partitions():
            try:
                usage = psutil.disk_usage(partition.mountpoint)
            except OSError:
                # Inaccessible/virtual mounts (empty CD drives, permission
                # issues) — skip; previously a bare `except:` swallowed
                # everything including KeyboardInterrupt.
                continue

            labels = {
                "device": partition.device,
                "mountpoint": partition.mountpoint,
                "fstype": partition.fstype
            }

            metrics.append(PrometheusMetric.gauge(
                "node_disk_total_bytes",
                usage.total,
                labels=labels,
                help_text="Total disk space"
            ))
            metrics.append(PrometheusMetric.gauge(
                "node_disk_used_bytes",
                usage.used,
                labels=labels,
                help_text="Used disk space"
            ))
            metrics.append(PrometheusMetric.gauge(
                "node_disk_free_bytes",
                usage.free,
                labels=labels,
                help_text="Free disk space"
            ))
            metrics.append(PrometheusMetric.gauge(
                "node_disk_usage_percent",
                usage.percent,
                labels=labels,
                help_text="Disk usage percentage"
            ))

        # Disk I/O metrics (monotonic counters since boot)
        disk_io = psutil.disk_io_counters()
        if disk_io:
            metrics.append(PrometheusMetric.counter(
                "node_disk_read_bytes_total",
                disk_io.read_bytes,
                help_text="Total bytes read from disk"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_disk_write_bytes_total",
                disk_io.write_bytes,
                help_text="Total bytes written to disk"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_disk_reads_total",
                disk_io.read_count,
                help_text="Total read operations"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_disk_writes_total",
                disk_io.write_count,
                help_text="Total write operations"
            ))

        # Network metrics — per interface.
        net_io = psutil.net_io_counters(pernic=True)
        for interface, stats in net_io.items():
            labels = {"interface": interface}

            metrics.append(PrometheusMetric.counter(
                "node_network_receive_bytes_total",
                stats.bytes_recv,
                labels=labels,
                help_text="Total bytes received"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_network_transmit_bytes_total",
                stats.bytes_sent,
                labels=labels,
                help_text="Total bytes transmitted"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_network_receive_packets_total",
                stats.packets_recv,
                labels=labels,
                help_text="Total packets received"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_network_transmit_packets_total",
                stats.packets_sent,
                labels=labels,
                help_text="Total packets transmitted"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_network_receive_errors_total",
                stats.errin,
                labels=labels,
                help_text="Total receive errors"
            ))
            metrics.append(PrometheusMetric.counter(
                "node_network_transmit_errors_total",
                stats.errout,
                labels=labels,
                help_text="Total transmit errors"
            ))

        # Load average (Unix only; raises OSError elsewhere, caught below)
        load_avg = os.getloadavg()
        metrics.append(PrometheusMetric.gauge(
            "node_load1",
            load_avg[0],
            help_text="1-minute load average"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_load5",
            load_avg[1],
            help_text="5-minute load average"
        ))
        metrics.append(PrometheusMetric.gauge(
            "node_load15",
            load_avg[2],
            help_text="15-minute load average"
        ))

        # Boot time and derived uptime
        boot_time = psutil.boot_time()
        metrics.append(PrometheusMetric.gauge(
            "node_boot_time_seconds",
            boot_time,
            help_text="System boot time in Unix time"
        ))

        uptime_seconds = datetime.now().timestamp() - boot_time
        metrics.append(PrometheusMetric.gauge(
            "node_uptime_seconds",
            uptime_seconds,
            help_text="System uptime in seconds"
        ))

        audit.log("get_system_metrics", {}, {"success": True})

        return "\n".join(metrics) + "\n"

    except Exception as e:
        audit.log("get_system_metrics", {}, {
            "success": False,
            "error": str(e)
        })
        return f"# Error collecting metrics: {str(e)}\n"
audit.log("get_system_metrics", {}, { + "success": False, + "error": str(e) + }) + return f"# Error collecting metrics: {str(e)}\n" + + +async def get_docker_metrics() -> str: + """Export Docker container metrics in Prometheus format. + + Returns: + Prometheus-formatted metrics text + """ + metrics = [] + + try: + # Try using docker SDK first + try: + import docker + client = docker.from_env() + containers = client.containers.list(all=True) + + for container in containers: + try: + labels = { + "name": container.name, + "id": container.id[:12], + "image": container.image.tags[0] if container.image.tags else "unknown" + } + + # Container state + state_value = 1 if container.status == "running" else 0 + metrics.append(PrometheusMetric.gauge( + "docker_container_running", + state_value, + labels=labels, + help_text="Container running status (1=running, 0=stopped)" + )) + + # Get stats if running + if container.status == "running": + stats = container.stats(stream=False) + + # CPU usage + cpu_delta = stats["cpu_stats"]["cpu_usage"]["total_usage"] - stats["precpu_stats"]["cpu_usage"]["total_usage"] + system_delta = stats["cpu_stats"]["system_cpu_usage"] - stats["precpu_stats"]["system_cpu_usage"] + cpu_percent = (cpu_delta / system_delta) * len(stats["cpu_stats"]["cpu_usage"]["percpu_usage"]) * 100.0 if system_delta > 0 else 0 + + metrics.append(PrometheusMetric.gauge( + "docker_container_cpu_usage_percent", + cpu_percent, + labels=labels, + help_text="Container CPU usage percentage" + )) + + # Memory usage + mem_usage = stats["memory_stats"].get("usage", 0) + mem_limit = stats["memory_stats"].get("limit", 0) + mem_percent = (mem_usage / mem_limit) * 100.0 if mem_limit > 0 else 0 + + metrics.append(PrometheusMetric.gauge( + "docker_container_memory_usage_bytes", + mem_usage, + labels=labels, + help_text="Container memory usage in bytes" + )) + metrics.append(PrometheusMetric.gauge( + "docker_container_memory_limit_bytes", + mem_limit, + labels=labels, + 
help_text="Container memory limit in bytes" + )) + metrics.append(PrometheusMetric.gauge( + "docker_container_memory_usage_percent", + mem_percent, + labels=labels, + help_text="Container memory usage percentage" + )) + + # Network I/O + networks = stats.get("networks", {}) + for net_name, net_stats in networks.items(): + net_labels = {**labels, "network": net_name} + + metrics.append(PrometheusMetric.counter( + "docker_container_network_receive_bytes_total", + net_stats.get("rx_bytes", 0), + labels=net_labels, + help_text="Container network bytes received" + )) + metrics.append(PrometheusMetric.counter( + "docker_container_network_transmit_bytes_total", + net_stats.get("tx_bytes", 0), + labels=net_labels, + help_text="Container network bytes transmitted" + )) + + except Exception as e: + continue + + except ImportError: + # Fall back to docker CLI + result = await to_thread( + subprocess.run, + ["docker", "ps", "-a", "--format", "{{json .}}"], + capture_output=True, + text=True, + check=True + ) + + for line in result.stdout.splitlines(): + if not line.strip(): + continue + + container_info = json.loads(line) + name = container_info.get("Names", "unknown") + container_id = container_info.get("ID", "unknown") + image = container_info.get("Image", "unknown") + status = container_info.get("State", "unknown") + + labels = { + "name": name, + "id": container_id, + "image": image + } + + state_value = 1 if status == "running" else 0 + metrics.append(PrometheusMetric.gauge( + "docker_container_running", + state_value, + labels=labels, + help_text="Container running status (1=running, 0=stopped)" + )) + + audit.log("get_docker_metrics", {}, {"success": True}) + + return "\n".join(metrics) + "\n" if metrics else "# No Docker containers found\n" + + except Exception as e: + audit.log("get_docker_metrics", {}, { + "success": False, + "error": str(e) + }) + return f"# Error collecting Docker metrics: {str(e)}\n" + + +async def get_all_metrics() -> str: + """Export all 
async def start_metrics_server(
    port: int = 9100,
    bind_address: str = "0.0.0.0"
) -> Dict[str, Any]:
    """Start a simple HTTP server to expose Prometheus metrics.

    Serves three routes on an aiohttp application:
      /metrics - full Prometheus exposition (via get_all_metrics)
      /health  - liveness probe, returns "OK"
      /        - human-readable index

    NOTE(review): binds to 0.0.0.0 by default, exposing metrics on all
    interfaces with no authentication — confirm this is intended for the
    deployment network.
    NOTE(review): the AppRunner/TCPSite references are locals and are not
    returned or stored, so the server cannot be stopped later and may be
    at the mercy of garbage collection — verify lifetime in practice.

    Args:
        port: Port to bind to (default: 9100)
        bind_address: Address to bind to (default: 0.0.0.0)

    Returns:
        Dict with server info
    """
    try:
        # aiohttp is imported lazily so the module loads without it.
        from aiohttp import web

        async def metrics_handler(request):
            """Handle /metrics requests."""
            metrics = await get_all_metrics()
            # version=0.0.4 is the Prometheus text exposition content type.
            return web.Response(text=metrics, content_type="text/plain; version=0.0.4")

        async def health_handler(request):
            """Handle health check requests."""
            return web.Response(text="OK\n")

        app = web.Application()
        app.router.add_get("/metrics", metrics_handler)
        app.router.add_get("/health", health_handler)
        app.router.add_get("/", lambda r: web.Response(text="TailOpsMCP Metrics Exporter\n/metrics - Prometheus metrics\n/health - Health check\n"))

        # Standard aiohttp programmatic startup sequence.
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, bind_address, port)
        await site.start()

        audit.log("start_metrics_server", {
            "port": port,
            "bind_address": bind_address
        }, {
            "success": True
        })

        return {
            "success": True,
            "message": f"Metrics server started on {bind_address}:{port}",
            "endpoint": f"http://{bind_address}:{port}/metrics"
        }

    except Exception as e:
        error_msg = f"Failed to start metrics server: {str(e)}"
        audit.log("start_metrics_server", {
            "port": port,
            "bind_address": bind_address
        }, {
            "success": False,
            "error": error_msg
        })
        return {
            "success": False,
            "error": error_msg
        }
audit.log("start_metrics_server", { + "port": port, + "bind_address": bind_address + }, { + "success": False, + "error": error_msg + }) + return { + "success": False, + "error": error_msg + } + + +async def save_metrics_to_file( + output_path: str = "/var/lib/node_exporter/textfile_collector/tailopsmcp.prom" +) -> Dict[str, Any]: + """Save metrics to a file for node_exporter textfile collector. + + Args: + output_path: Path to write metrics file + + Returns: + Dict with result status + """ + try: + # Ensure directory exists + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Get all metrics + metrics = await get_all_metrics() + + # Write to temp file first, then rename (atomic operation) + temp_path = output_path + ".tmp" + with open(temp_path, 'w') as f: + f.write(metrics) + + os.rename(temp_path, output_path) + + audit.log("save_metrics_to_file", { + "output_path": output_path + }, { + "success": True + }) + + return { + "success": True, + "message": f"Metrics written to {output_path}", + "path": output_path + } + + except Exception as e: + error_msg = f"Failed to write metrics file: {str(e)}" + audit.log("save_metrics_to_file", { + "output_path": output_path + }, { + "success": False, + "error": error_msg + }) + return { + "success": False, + "error": error_msg + } + + +def register_tools(mcp: FastMCP): + """Register Prometheus metrics export tools with MCP instance.""" + + @mcp.tool() + @secure_tool("metrics:read") + async def export_prometheus_metrics( + include_docker: bool = True + ) -> str: + """Export system and Docker metrics in Prometheus format. 
def register_tools(mcp: FastMCP):
    """Register Prometheus metrics export tools with MCP instance.

    Read-only exports are gated on the "metrics:read" scope; writing the
    node_exporter textfile requires "metrics:admin".
    """

    @mcp.tool()
    @secure_tool("metrics:read")
    async def export_prometheus_metrics(
        include_docker: bool = True
    ) -> str:
        """Export system and Docker metrics in Prometheus format.

        Args:
            include_docker: Include Docker container metrics (default: True)

        Returns:
            Prometheus-formatted metrics text
        """
        try:
            # get_all_metrics() = system + docker + scrape timestamp;
            # get_system_metrics() alone when Docker is excluded.
            if include_docker:
                result = await get_all_metrics()
            else:
                result = await get_system_metrics()
            return result
        except Exception as e:
            return format_error(e, "export_prometheus_metrics")

    @mcp.tool()
    @secure_tool("metrics:read")
    async def get_prometheus_system_metrics() -> str:
        """Export system metrics in Prometheus format.

        Returns:
            Prometheus-formatted system metrics text
        """
        try:
            result = await get_system_metrics()
            return result
        except Exception as e:
            return format_error(e, "get_prometheus_system_metrics")

    @mcp.tool()
    @secure_tool("metrics:read")
    async def get_prometheus_docker_metrics() -> str:
        """Export Docker container metrics in Prometheus format.

        Returns:
            Prometheus-formatted Docker metrics text
        """
        try:
            result = await get_docker_metrics()
            return result
        except Exception as e:
            return format_error(e, "get_prometheus_docker_metrics")

    @mcp.tool()
    @secure_tool("metrics:admin")
    async def write_metrics_textfile(
        output_path: str = "/var/lib/node_exporter/textfile_collector/tailopsmcp.prom"
    ) -> dict:
        """Save metrics to a file for node_exporter textfile collector.

        Args:
            output_path: Path to write metrics file (default: /var/lib/node_exporter/textfile_collector/tailopsmcp.prom)

        Returns:
            Dict with result status
        """
        try:
            result = await save_metrics_to_file(output_path=output_path)
            return result
        except Exception as e:
            return format_error(e, "write_metrics_textfile")

    logger.info("Registered 4 metrics tools")
async def get_repo_status(stack_name: str) -> Dict[str, Any]:
    """Return repo status for a stack with actual git information.

    Looks the stack up in the inventory, then (if a local git checkout
    exists at the stack's path) fetches the tracked branch and compares
    local HEAD against origin/<branch>: latest/deployed commits,
    ahead/behind counts, and whether the working tree is dirty.
    Falls back to inventory-only data when no checkout exists or any git
    command fails.
    """
    inv = Inventory()
    stacks = inv.list_stacks()
    stack = stacks.get(stack_name)

    if not stack:
        audit.log("get_repo_status", {"stack_name": stack_name}, {"success": False, "error": "stack_not_found"})
        return RepoStatus(repo_url=None).dict()

    repo_url = stack.get("repo_url")
    stack_path = stack.get("path")
    deployed_commit = stack.get("deployed_commit")
    branch = stack.get("branch", "main")

    # If no path or git repo doesn't exist, return basic inventory info only.
    if not stack_path or not os.path.exists(os.path.join(stack_path, ".git")):
        res = RepoStatus(
            repo_url=repo_url,
            branch=branch,
            latest_commit=None,
            deployed_commit=deployed_commit
        )
        audit.log("get_repo_status", {"stack_name": stack_name}, {"success": True, "no_git": True})
        return res.dict()

    try:
        # Refresh the remote tracking ref so the comparison is current.
        # NOTE(review): this does a network round-trip on every status call.
        await to_thread(
            subprocess.run,
            ["git", "fetch", "origin", branch],
            cwd=stack_path,
            capture_output=True,
            text=True,
            check=True
        )

        # Get remote HEAD commit
        remote_rev = await to_thread(
            subprocess.run,
            ["git", "rev-parse", f"origin/{branch}"],
            cwd=stack_path,
            capture_output=True,
            text=True,
            check=True
        )
        latest_commit = remote_rev.stdout.strip()

        # Get current (locally checked-out) HEAD commit
        local_rev = await to_thread(
            subprocess.run,
            ["git", "rev-parse", "HEAD"],
            cwd=stack_path,
            capture_output=True,
            text=True,
            check=True
        )
        current_commit = local_rev.stdout.strip()

        # Any porcelain output means uncommitted changes in the working tree.
        status_result = await to_thread(
            subprocess.run,
            ["git", "status", "--porcelain"],
            cwd=stack_path,
            capture_output=True,
            text=True,
            check=True
        )
        has_uncommitted_changes = bool(status_result.stdout.strip())

        # `rev-list --left-right --count HEAD...origin/<branch>` prints
        # "<left>\t<right>" where left = commits only on HEAD (ahead) and
        # right = commits only on origin (behind).
        revlist = await to_thread(
            subprocess.run,
            ["git", "rev-list", "--left-right", "--count", f"HEAD...origin/{branch}"],
            cwd=stack_path,
            capture_output=True,
            text=True,
            check=True
        )
        ahead, behind = map(int, revlist.stdout.strip().split())

        res = RepoStatus(
            repo_url=repo_url,
            branch=branch,
            latest_commit=latest_commit,
            deployed_commit=current_commit,
            ahead_by=ahead,
            behind_by=behind,
            has_uncommitted_changes=has_uncommitted_changes
        )

        audit.log("get_repo_status", {"stack_name": stack_name}, {"success": True})
        return res.dict()

    except Exception as e:
        # Fall back to basic info if git commands fail (offline, detached
        # remote, corrupted checkout, ...).
        res = RepoStatus(
            repo_url=repo_url,
            branch=branch,
            latest_commit=None,
            deployed_commit=deployed_commit
        )
        audit.log("get_repo_status", {"stack_name": stack_name}, {"success": False, "error": str(e)})
        return res.dict()
"""Perform a deploy operation. For now this is a dry-run-only PoC. + """Perform a git-based stack deployment with docker-compose. - It returns planned changes and does not perform actual container operations - unless `dry_run` is False and a real implementation is provided. + This function: + - Clones or updates a git repository + - Checks out the specified commit/branch + - Runs docker-compose up (with optional image pull) + - Updates inventory with deployed commit info """ - # Build a minimal planned change list - planned = [f"Would pull images for stack {req.stack_name}", f"Would restart services for {req.stack_name}"] - result = DeployResult(success=True, dry_run=req.dry_run, planned_changes=planned, errors=[], deployed_commit=req.target_commit) + errors = [] + planned = [] + + # Get stack info from inventory + inv = Inventory() + stacks = inv.list_stacks() + stack = stacks.get(req.stack_name, {}) + + # Determine deployment path + deploy_base = os.getenv("STACK_DEPLOY_PATH", "/opt/stacks") + stack_path = stack.get("path") or os.path.join(deploy_base, req.stack_name) + repo_url = stack.get("repo_url") + + if not repo_url: + errors.append(f"No repo_url configured for stack {req.stack_name}") + result = DeployResult(success=False, dry_run=req.dry_run, planned_changes=planned, errors=errors, deployed_commit=None) + audit.log("deploy_stack", req.dict(), {"success": False, "error": "no_repo_url"}) + return result.dict() + + # Plan the deployment + repo_exists = os.path.exists(os.path.join(stack_path, ".git")) + if repo_exists: + planned.append(f"Update repository in {stack_path}") + planned.append(f"Fetch latest changes from {repo_url}") + else: + planned.append(f"Clone repository {repo_url} to {stack_path}") + + if req.target_commit: + planned.append(f"Checkout commit/branch: {req.target_commit}") + + if req.pull_images: + planned.append(f"Pull Docker images for {req.stack_name}") + + planned.append(f"Run docker-compose up -d for {req.stack_name}") + + if 
req.dry_run: + result = DeployResult(success=True, dry_run=True, planned_changes=planned, errors=[], deployed_commit=req.target_commit) + audit.log("deploy_stack", req.dict(), {"success": True, "dry_run": True}) + return result.dict() + + # Execute deployment + try: + # Ensure base directory exists + os.makedirs(deploy_base, exist_ok=True) + + # Clone or update repository + if repo_exists: + # Update existing repo + git_fetch = await to_thread( + subprocess.run, + ["git", "fetch", "--all"], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + git_reset = await to_thread( + subprocess.run, + ["git", "reset", "--hard", f"origin/{req.target_commit or stack.get('branch', 'main')}"], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + else: + # Clone new repo + git_clone = await to_thread( + subprocess.run, + ["git", "clone", repo_url, stack_path], + capture_output=True, + text=True, + check=True + ) + + if req.target_commit: + git_checkout = await to_thread( + subprocess.run, + ["git", "checkout", req.target_commit], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + + # Get current commit hash + git_rev = await to_thread( + subprocess.run, + ["git", "rev-parse", "HEAD"], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + deployed_commit = git_rev.stdout.strip() + + # Pull images if requested + if req.pull_images: + compose_pull = await to_thread( + subprocess.run, + ["docker", "compose", "pull"], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + + # Deploy with docker-compose + compose_up = await to_thread( + subprocess.run, + ["docker", "compose", "up", "-d", "--remove-orphans"], + cwd=stack_path, + capture_output=True, + text=True, + check=True + ) + + # Update inventory + stack["deployed_commit"] = deployed_commit + stack["path"] = stack_path + stack["deployed_at"] = datetime.now().isoformat() + inv._data["stacks"][req.stack_name] = stack + inv._save() + + 
result = DeployResult( + success=True, + dry_run=False, + planned_changes=planned, + errors=[], + deployed_commit=deployed_commit + ) + + audit.log("deploy_stack", req.dict(), { + "success": True, + "deployed_commit": deployed_commit, + "stack_path": stack_path + }) + + except subprocess.CalledProcessError as e: + errors.append(f"Command failed: {e.cmd}") + errors.append(f"Error: {e.stderr}") + result = DeployResult( + success=False, + dry_run=False, + planned_changes=planned, + errors=errors, + deployed_commit=None + ) + audit.log("deploy_stack", req.dict(), { + "success": False, + "error": str(e) + }) + except Exception as e: + errors.append(f"Deployment failed: {str(e)}") + result = DeployResult( + success=False, + dry_run=False, + planned_changes=planned, + errors=errors, + deployed_commit=None + ) + audit.log("deploy_stack", req.dict(), { + "success": False, + "error": str(e) + }) - audit.log("deploy_stack", req.dict(), {"success": True, "dry_run": req.dry_run}) return result.dict() async def rollback_stack(host: str, stack_name: str, to_commit: str, dry_run: bool = True) -> Dict[str, Any]: - planned = [f"Would checkout {to_commit} for {stack_name}", f"Would redeploy {stack_name}"] - res = DeployResult(success=True, dry_run=dry_run, planned_changes=planned, errors=[], deployed_commit=to_commit) - audit.log("rollback_stack", {"host": host, "stack_name": stack_name, "to_commit": to_commit}, {"success": True, "dry_run": dry_run}) - return res.dict() + """Rollback a stack to a previous commit.""" + errors = [] + planned = [ + f"Checkout commit {to_commit} for {stack_name}", + f"Redeploy {stack_name} with docker-compose" + ] + + if dry_run: + res = DeployResult(success=True, dry_run=True, planned_changes=planned, errors=[], deployed_commit=to_commit) + audit.log("rollback_stack", {"host": host, "stack_name": stack_name, "to_commit": to_commit}, {"success": True, "dry_run": True}) + return res.dict() + + # Execute rollback via deploy_stack + deploy_req = 
async def rollback_stack(host: str, stack_name: str, to_commit: str, dry_run: bool = True) -> Dict[str, Any]:
    """Rollback a stack to a previous commit.

    When ``dry_run`` is True, only the planned steps are reported. Otherwise
    the rollback is executed by delegating to ``deploy_stack`` with the
    target commit pinned and image pulling disabled.

    Args:
        host: Host the stack runs on (passed through to the deploy request).
        stack_name: Name of the stack to roll back.
        to_commit: Commit to roll back to.
        dry_run: Report planned steps without executing when True.

    Returns:
        ``DeployResult`` serialized to a dict.
    """
    plan = [
        f"Checkout commit {to_commit} for {stack_name}",
        f"Redeploy {stack_name} with docker-compose",
    ]
    audit_params = {"host": host, "stack_name": stack_name, "to_commit": to_commit}

    if dry_run:
        outcome = DeployResult(
            success=True,
            dry_run=True,
            planned_changes=plan,
            errors=[],
            deployed_commit=to_commit,
        )
        audit.log("rollback_stack", audit_params, {"success": True, "dry_run": True})
        return outcome.dict()

    # Execute the rollback as a regular deployment pinned to the old commit;
    # force=True so the deploy proceeds even if the target is behind HEAD.
    rollback_req = DeployRequest(
        host=host,
        stack_name=stack_name,
        target_commit=to_commit,
        pull_images=False,
        force=True,
        dry_run=False,
    )
    deployed = await deploy_stack(rollback_req)
    audit.log("rollback_stack", audit_params,
              {"success": deployed.get("success"), "dry_run": False})
    return deployed