diff --git a/src/drivers.act b/src/drivers.act index 4587c6d..c221912 100644 --- a/src/drivers.act +++ b/src/drivers.act @@ -62,7 +62,7 @@ DRIVER_STATE_ROLLING_BACK: str = "rolling_back" DRIVER_STATE_ERROR: str = "error" DRIVER_STATE_DISCONNECTED: str = "disconnected" -class _Driver(object): +class Driver(object): """Base implementation class for router drivers""" device_type: str @@ -71,6 +71,7 @@ class _Driver(object): initialize: proc() -> None is_ready: proc() -> bool get_state: proc() -> str + is_valid_transition: proc(from_state: str, to_state: str) -> bool handle_data: proc(data: bytes) -> None execute_command: proc(cb: action(err: ?Exception, response: ?str) -> None, command: str) -> None # TODO: default arg? @@ -80,7 +81,7 @@ class _Driver(object): rollback_configuration: proc(cb: action(err: ?Exception, session_log: str) -> None, commits_back: int) -> None get_device_info: proc() -> dict[str, str] -class _BaseDriver(_Driver): +class _BaseDriver(Driver): """Base driver implementation with common functionality""" _output_buffer: str @@ -120,7 +121,7 @@ class _BaseDriver(_Driver): def _transition_to_state(self, new_state: str) -> None: """Transition to new state with validation and logging""" - if self._is_valid_transition(self._state, new_state): + if self.is_valid_transition(self._state, new_state): old_state = self._state self._state = new_state self.log.debug("Driver state transition", {"from": old_state, "to": new_state}) @@ -128,7 +129,7 @@ class _BaseDriver(_Driver): self.log.error("Invalid state transition attempted", {"from": self._state, "to": new_state}) self._transition_to_state(DRIVER_STATE_ERROR) - def _is_valid_transition(self, from_state: str, to_state: str) -> bool: + def is_valid_transition(self, from_state: str, to_state: str) -> bool: """Check if state transition is valid""" valid_transitions = { DRIVER_STATE_INITIALIZING: [DRIVER_STATE_READY, DRIVER_STATE_ERROR, DRIVER_STATE_DISCONNECTED], @@ -434,7 +435,7 @@ def auto_detect(ssh_client: ssh.SSHClient) -> str: # that analyzes the SSH banner, initial prompts, etc. return DEVICE_TYPE_UNKNOWN -def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.Logger) -> ?_Driver: +def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.Logger) -> ?Driver: """ Create appropriate driver instance based on device type @@ -447,17 +448,17 @@ def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.L Driver instance or None if unsupported """ if device_type == DEVICE_TYPE_JUNIPER: - return _JuniperDriver(device_type, ssh_client, log) + return JuniperDriver(device_type, ssh_client, log) elif device_type == DEVICE_TYPE_CISCO_IOSXR: - return _CiscoIOSXRDriver(device_type, ssh_client, log) + return CiscoIOSXRDriver(device_type, ssh_client, log) elif device_type == DEVICE_TYPE_CISCO_IOSXE: - return _CiscoIOSXEDriver(device_type, ssh_client, log) + return CiscoIOSXEDriver(device_type, ssh_client, log) else: log.warning("Unsupported device type", {"device_type": device_type}) return None -class _JuniperDriver(_BaseDriver): +class JuniperDriver(_BaseDriver): """Juniper JUNOS driver implementation""" def initialize(self) -> None: @@ -525,7 +526,7 @@ class _JuniperDriver(_BaseDriver): "os": "JUNOS" } -class _CiscoIOSXRDriver(_BaseDriver): +class CiscoIOSXRDriver(_BaseDriver): """Cisco IOS XR driver implementation""" def initialize(self) -> None: @@ -594,7 +595,7 @@ class _CiscoIOSXRDriver(_BaseDriver): "os": "IOS XR" } -class _CiscoIOSXEDriver(_BaseDriver): +class CiscoIOSXEDriver(_BaseDriver): """ Cisco IOS XE driver implementation diff --git a/src/router_client.act b/src/router_client.act index 43e3d14..9a8cd81 100644 --- a/src/router_client.act +++ b/src/router_client.act @@ -34,7 +34,7 @@ actor Client(auth: WorldCap, _log.info("Initializing router client", {"address": address, "protocol_": protocol_}) var _ssh_client: ?ssh.SSHClient = None - var _driver: ?drivers._Driver = None + var _driver: ?drivers.Driver = None def _on_ssh_connect(client: ssh.SSHClient) -> None: """Handle SSH connection establishment""" diff --git a/src/test_drivers.act b/src/test_drivers.act index 0764ab0..dc2161c 100644 --- a/src/test_drivers.act +++ b/src/test_drivers.act @@ -16,7 +16,7 @@ actor _test_juniper_driver_initialization(t: testing.AsyncT): var log: logging.Logger = logging.Logger(t.log_handler) def run_test() -> None: - driver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) + driver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) # Verify initial state if driver.get_state() != drivers.DRIVER_STATE_INITIALIZING: @@ -46,7 +46,7 @@ actor _test_cisco_driver_initialization(t: testing.AsyncT): var log: logging.Logger = logging.Logger(t.log_handler) def run_test() -> None: - driver = drivers._CiscoIOSXRDriver(drivers.DEVICE_TYPE_CISCO_IOSXR, test_ssh, log) + driver = drivers.CiscoIOSXRDriver(drivers.DEVICE_TYPE_CISCO_IOSXR, test_ssh, log) # Verify initial state if driver.get_state() != drivers.DRIVER_STATE_INITIALIZING: @@ -74,7 +74,7 @@ actor _test_command_execution_state_flow(t: testing.AsyncT): var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper() var log: logging.Logger = logging.Logger(t.log_handler) - var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) + var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) def run_test() -> None: driver.initialize() @@ -124,7 +124,7 @@ actor _test_configuration_and_commit_flow(t: testing.AsyncT): var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper() var log: logging.Logger = logging.Logger(t.log_handler) - var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) + var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) def run_test() -> None: driver.initialize() @@ -168,7 +168,7 @@ actor _test_commit_failure_automatic_rollback(t: testing.AsyncT): var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper() var log: logging.Logger = logging.Logger(t.log_handler) - var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) + var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) def run_test() -> None: driver.initialize() @@ -213,7 +213,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT): var log: logging.Logger = logging.Logger(t.log_handler) def run_test() -> None: - driver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) + driver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log) # Test valid transitions valid_tests = [ @@ -227,7 +227,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT): ] for from_state, to_state in valid_tests: - if not driver._is_valid_transition(from_state, to_state): + if not driver.is_valid_transition(from_state, to_state): t.failure(Exception("Expected valid transition: " + from_state + " -> " + to_state)) return @@ -241,7 +241,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT): ] for from_state, to_state in invalid_tests: - if driver._is_valid_transition(from_state, to_state): + if driver.is_valid_transition(from_state, to_state): t.failure(Exception("Expected invalid transition: " + from_state + " -> " + to_state)) return