Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/drivers.act
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ DRIVER_STATE_ROLLING_BACK: str = "rolling_back"
DRIVER_STATE_ERROR: str = "error"
DRIVER_STATE_DISCONNECTED: str = "disconnected"

class _Driver(object):
class Driver(object):
"""Base implementation class for router drivers"""

device_type: str
Expand All @@ -71,6 +71,7 @@ class _Driver(object):
initialize: proc() -> None
is_ready: proc() -> bool
get_state: proc() -> str
is_valid_transition: proc(from_state: str, to_state: str) -> bool
handle_data: proc(data: bytes) -> None
execute_command: proc(cb: action(err: ?Exception, response: ?str) -> None, command: str) -> None
# TODO: default arg?
Expand All @@ -80,7 +81,7 @@ class _Driver(object):
rollback_configuration: proc(cb: action(err: ?Exception, session_log: str) -> None, commits_back: int) -> None
get_device_info: proc() -> dict[str, str]

class _BaseDriver(_Driver):
class _BaseDriver(Driver):
"""Base driver implementation with common functionality"""

_output_buffer: str
Expand Down Expand Up @@ -120,15 +121,15 @@ class _BaseDriver(_Driver):

def _transition_to_state(self, new_state: str) -> None:
"""Transition to new state with validation and logging"""
if self._is_valid_transition(self._state, new_state):
if self.is_valid_transition(self._state, new_state):
old_state = self._state
self._state = new_state
self.log.debug("Driver state transition", {"from": old_state, "to": new_state})
else:
self.log.error("Invalid state transition attempted", {"from": self._state, "to": new_state})
self._transition_to_state(DRIVER_STATE_ERROR)

def _is_valid_transition(self, from_state: str, to_state: str) -> bool:
def is_valid_transition(self, from_state: str, to_state: str) -> bool:
"""Check if state transition is valid"""
valid_transitions = {
DRIVER_STATE_INITIALIZING: [DRIVER_STATE_READY, DRIVER_STATE_ERROR, DRIVER_STATE_DISCONNECTED],
Expand Down Expand Up @@ -434,7 +435,7 @@ def auto_detect(ssh_client: ssh.SSHClient) -> str:
# that analyzes the SSH banner, initial prompts, etc.
return DEVICE_TYPE_UNKNOWN

def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.Logger) -> ?_Driver:
def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.Logger) -> ?Driver:
"""
Create appropriate driver instance based on device type

Expand All @@ -447,17 +448,17 @@ def create_driver(device_type: str, ssh_client: SSHClientWrapper, log: logging.L
Driver instance or None if unsupported
"""
if device_type == DEVICE_TYPE_JUNIPER:
return _JuniperDriver(device_type, ssh_client, log)
return JuniperDriver(device_type, ssh_client, log)
elif device_type == DEVICE_TYPE_CISCO_IOSXR:
return _CiscoIOSXRDriver(device_type, ssh_client, log)
return CiscoIOSXRDriver(device_type, ssh_client, log)
elif device_type == DEVICE_TYPE_CISCO_IOSXE:
return _CiscoIOSXEDriver(device_type, ssh_client, log)
return CiscoIOSXEDriver(device_type, ssh_client, log)
else:
log.warning("Unsupported device type", {"device_type": device_type})
return None


class _JuniperDriver(_BaseDriver):
class JuniperDriver(_BaseDriver):
"""Juniper JUNOS driver implementation"""

def initialize(self) -> None:
Expand Down Expand Up @@ -525,7 +526,7 @@ class _JuniperDriver(_BaseDriver):
"os": "JUNOS"
}

class _CiscoIOSXRDriver(_BaseDriver):
class CiscoIOSXRDriver(_BaseDriver):
"""Cisco IOS XR driver implementation"""

def initialize(self) -> None:
Expand Down Expand Up @@ -594,7 +595,7 @@ class _CiscoIOSXRDriver(_BaseDriver):
"os": "IOS XR"
}

class _CiscoIOSXEDriver(_BaseDriver):
class CiscoIOSXEDriver(_BaseDriver):
"""
Cisco IOS XE driver implementation

Expand Down
2 changes: 1 addition & 1 deletion src/router_client.act
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ actor Client(auth: WorldCap,
_log.info("Initializing router client", {"address": address, "protocol_": protocol_})

var _ssh_client: ?ssh.SSHClient = None
var _driver: ?drivers._Driver = None
var _driver: ?drivers.Driver = None

def _on_ssh_connect(client: ssh.SSHClient) -> None:
"""Handle SSH connection establishment"""
Expand Down
16 changes: 8 additions & 8 deletions src/test_drivers.act
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ actor _test_juniper_driver_initialization(t: testing.AsyncT):
var log: logging.Logger = logging.Logger(t.log_handler)

def run_test() -> None:
driver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)
driver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)

# Verify initial state
if driver.get_state() != drivers.DRIVER_STATE_INITIALIZING:
Expand Down Expand Up @@ -46,7 +46,7 @@ actor _test_cisco_driver_initialization(t: testing.AsyncT):
var log: logging.Logger = logging.Logger(t.log_handler)

def run_test() -> None:
driver = drivers._CiscoIOSXRDriver(drivers.DEVICE_TYPE_CISCO_IOSXR, test_ssh, log)
driver = drivers.CiscoIOSXRDriver(drivers.DEVICE_TYPE_CISCO_IOSXR, test_ssh, log)

# Verify initial state
if driver.get_state() != drivers.DRIVER_STATE_INITIALIZING:
Expand Down Expand Up @@ -74,7 +74,7 @@ actor _test_command_execution_state_flow(t: testing.AsyncT):

var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper()
var log: logging.Logger = logging.Logger(t.log_handler)
var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)
var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)

def run_test() -> None:
driver.initialize()
Expand Down Expand Up @@ -124,7 +124,7 @@ actor _test_configuration_and_commit_flow(t: testing.AsyncT):

var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper()
var log: logging.Logger = logging.Logger(t.log_handler)
var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)
var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)

def run_test() -> None:
driver.initialize()
Expand Down Expand Up @@ -168,7 +168,7 @@ actor _test_commit_failure_automatic_rollback(t: testing.AsyncT):

var test_ssh: drivers.TestSSHWrapper = drivers.TestSSHWrapper()
var log: logging.Logger = logging.Logger(t.log_handler)
var driver: drivers._JuniperDriver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)
var driver: drivers.JuniperDriver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)

def run_test() -> None:
driver.initialize()
Expand Down Expand Up @@ -213,7 +213,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT):
var log: logging.Logger = logging.Logger(t.log_handler)

def run_test() -> None:
driver = drivers._JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)
driver = drivers.JuniperDriver(drivers.DEVICE_TYPE_JUNIPER, test_ssh, log)

# Test valid transitions
valid_tests = [
Expand All @@ -227,7 +227,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT):
]

for from_state, to_state in valid_tests:
if not driver._is_valid_transition(from_state, to_state):
if not driver.is_valid_transition(from_state, to_state):
t.failure(Exception("Expected valid transition: " + from_state + " -> " + to_state))
return

Expand All @@ -241,7 +241,7 @@ actor _test_driver_state_transition_validation(t: testing.AsyncT):
]

for from_state, to_state in invalid_tests:
if driver._is_valid_transition(from_state, to_state):
if driver.is_valid_transition(from_state, to_state):
t.failure(Exception("Expected invalid transition: " + from_state + " -> " + to_state))
return

Expand Down