diff --git a/tools/find_rtkbase/scan_network.py b/tools/find_rtkbase/scan_network.py index dd69c9d1..2bcd2a36 100644 --- a/tools/find_rtkbase/scan_network.py +++ b/tools/find_rtkbase/scan_network.py @@ -11,6 +11,35 @@ log = logging.getLogger(__name__) log.setLevel('ERROR') +def sort_hosts(hosts_list): + def first_port(host): + if host.get('port') is not None: + return host.get('port') + ports = host.get('PORTS') or [] + return ports[0] if ports else 0 + + def host_sort_key(host): + return ( + (host.get('server') or host.get('SERVER') or host.get('fqdn') or host.get('NAME') or '').casefold(), + host.get('ip') or host.get('IP') or '', + first_port(host), + ) + + return sorted(hosts_list, key=host_sort_key) + + +def iter_probe_addresses(result): + ip_address = result.get('IP') + server_name = result.get('SERVER') + + if ip_address: + # Direct IP probes are faster and more reliable than mDNS on some VPN setups. + yield ip_address + yield ip_address + if server_name and server_name != ip_address: + yield server_name + + class MyZeroConfListener: def __init__(self): self.services = [] @@ -27,15 +56,19 @@ def zeroconf_scan(name, prot_type, timeout=5): log.debug("Scanning with zeroconf") service_list = [] zeroconf = Zeroconf() - listener = MyZeroConfListener() - browser = ServiceBrowser(zeroconf, prot_type, listener) - time.sleep(timeout) - for service in listener.services: - if name.lower() in service.name.lower(): - service_list.append({'NAME' : service.name, - 'PORTS' : [service.port], - 'SERVER' : service.server.rstrip('.'), - 'IP' : '.'.join(str(byte) for byte in service.addresses[0])}) + try: + listener = MyZeroConfListener() + browser = ServiceBrowser(zeroconf, prot_type, listener) + time.sleep(timeout) + for service in listener.services: + if name.lower() in service.name.lower(): + service_list.append({'NAME' : service.name, + 'PORTS' : [service.port], + 'SERVER' : service.server.rstrip('.'), + 'IP' : '.'.join(str(byte) for byte in service.addresses[0])}) + finally: + zeroconf.close() + service_list = sort_hosts(service_list) log.debug(f"filtered list for {name}") log.debug(service_list) return service_list @@ -104,10 +137,10 @@ def get_rtkbase_infos(host_list): if result.get('PORTS') and len(result.get('PORTS')) > 0: try: for port in result.get('PORTS'): - #try with mDns server name at first, then with the ip address if it fails - for address in (result.get('SERVER'), result.get('IP')): + ans = None + # Prefer direct IP probes before falling back to the advertised mDNS name. + for address in iter_probe_addresses(result): try: - ans = None if address is None: continue log.debug(f"{address}:{port} Api request") @@ -202,6 +235,7 @@ def main(ports, allscan=False, iprange=None): available_rtkbase = get_rtkbase_infos(scan_results) #remove duplicate available_rtkbase = remove_duplicate_hosts(available_rtkbase) + available_rtkbase = sort_hosts(available_rtkbase) log.debug("RTKBase station found: ") log.debug(available_rtkbase) return available_rtkbase @@ -211,4 +245,4 @@ def main(ports, allscan=False, iprange=None): if args.debug: log.setLevel('DEBUG') log.debug(f"Arguments: {args}") - print(main(args.ports, args.allscan, args.iprange)) \ No newline at end of file + print(main(args.ports, args.allscan, args.iprange)) diff --git a/tools/find_rtkbase/test_scan_network.py b/tools/find_rtkbase/test_scan_network.py new file mode 100644 index 00000000..f6d9b333 --- /dev/null +++ b/tools/find_rtkbase/test_scan_network.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import importlib.util +import unittest +from pathlib import Path +from unittest import mock + + +MODULE_PATH = Path(__file__).with_name("scan_network.py") +SPEC = importlib.util.spec_from_file_location("scan_network_under_test", MODULE_PATH) +scan_network = importlib.util.module_from_spec(SPEC) +assert SPEC.loader is not None +SPEC.loader.exec_module(scan_network) + + +class FakeService: + def __init__(self, name, port, server, address_bytes): + self.name = name + self.port = port + self.server = server + self.addresses = [address_bytes] + + +class FakeZeroconf: + def __init__(self, services): + self.services = services + self.closed = False + + def get_service_info(self, service_type, name): + return self.services[name] + + def close(self): + self.closed = True + + +class FakeResponse: + def __init__(self, status_code, payload): + self.status_code = status_code + self._payload = payload + + def json(self): + return self._payload + + +class ScanNetworkTests(unittest.TestCase): + def test_zeroconf_scan_returns_results_sorted_by_server_name(self): + fake_services = { + "svc-b": FakeService( + "RTKBase Web Server Beta", + 80, + "beta.local.", + bytes([192, 168, 1, 20]), + ), + "svc-a": FakeService( + "RTKBase Web Server Alpha", + 80, + "alpha.local.", + bytes([192, 168, 1, 10]), + ), + } + fake_zeroconf = FakeZeroconf(fake_services) + + def fake_browser(zeroconf, service_type, listener): + listener.add_service(zeroconf, service_type, "svc-b") + listener.add_service(zeroconf, service_type, "svc-a") + return object() + + with ( + mock.patch.object(scan_network, "Zeroconf", return_value=fake_zeroconf), + mock.patch.object(scan_network, "ServiceBrowser", side_effect=fake_browser), + mock.patch.object(scan_network.time, "sleep", return_value=None), + ): + results = scan_network.zeroconf_scan("RTKBase Web Server", "_http._tcp.local.") + + self.assertEqual(["alpha.local", "beta.local"], [item["SERVER"] for item in results]) + self.assertTrue(fake_zeroconf.closed) + + def test_get_rtkbase_infos_retries_ip_before_server_name(self): + calls = [] + + def fake_get(url, timeout): + calls.append(url) + if len(calls) < 3: + raise scan_network.requests.exceptions.ConnectionError("network") + return FakeResponse( + 200, + {"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"}, + ) + + host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}] + with mock.patch.object(scan_network.requests, "get", side_effect=fake_get): + results = scan_network.get_rtkbase_infos(host_list) + + self.assertEqual( + [ + "http://10.0.0.5:80/api/v1/infos", + "http://10.0.0.5:80/api/v1/infos", + "http://alpha.local:80/api/v1/infos", + ], + calls, + ) + self.assertEqual("alpha.local", results[0]["server"]) + self.assertEqual("10.0.0.5", results[0]["ip"]) + + def test_get_rtkbase_infos_stops_after_successful_ip_retry(self): + calls = [] + + def fake_get(url, timeout): + calls.append(url) + if len(calls) == 1: + raise scan_network.requests.exceptions.ConnectionError("network") + return FakeResponse( + 200, + {"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"}, + ) + + host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}] + with mock.patch.object(scan_network.requests, "get", side_effect=fake_get): + results = scan_network.get_rtkbase_infos(host_list) + + self.assertEqual( + [ + "http://10.0.0.5:80/api/v1/infos", + "http://10.0.0.5:80/api/v1/infos", + ], + calls, + ) + self.assertEqual(1, len(results)) + + +if __name__ == "__main__": + unittest.main()