Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions tools/find_rtkbase/scan_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,35 @@
log = logging.getLogger(__name__)
log.setLevel('ERROR')

def sort_hosts(hosts_list):
def first_port(host):
if host.get('port') is not None:
return host.get('port')
ports = host.get('PORTS') or []
return ports[0] if ports else 0

def host_sort_key(host):
return (
(host.get('server') or host.get('SERVER') or host.get('fqdn') or host.get('NAME') or '').casefold(),
host.get('ip') or host.get('IP') or '',
first_port(host),
)

return sorted(hosts_list, key=host_sort_key)


def iter_probe_addresses(result):
ip_address = result.get('IP')
server_name = result.get('SERVER')

if ip_address:
# Direct IP probes are faster and more reliable than mDNS on some VPN setups.
yield ip_address
yield ip_address
if server_name and server_name != ip_address:
yield server_name


class MyZeroConfListener:
def __init__(self):
self.services = []
Expand All @@ -27,15 +56,19 @@ def zeroconf_scan(name, prot_type, timeout=5):
log.debug("Scanning with zeroconf")
service_list = []
zeroconf = Zeroconf()
listener = MyZeroConfListener()
browser = ServiceBrowser(zeroconf, prot_type, listener)
time.sleep(timeout)
for service in listener.services:
if name.lower() in service.name.lower():
service_list.append({'NAME' : service.name,
'PORTS' : [service.port],
'SERVER' : service.server.rstrip('.'),
'IP' : '.'.join(str(byte) for byte in service.addresses[0])})
try:
listener = MyZeroConfListener()
browser = ServiceBrowser(zeroconf, prot_type, listener)
time.sleep(timeout)
for service in listener.services:
if name.lower() in service.name.lower():
service_list.append({'NAME' : service.name,
'PORTS' : [service.port],
'SERVER' : service.server.rstrip('.'),
'IP' : '.'.join(str(byte) for byte in service.addresses[0])})
finally:
zeroconf.close()
service_list = sort_hosts(service_list)
log.debug(f"filtered list for {name}")
log.debug(service_list)
return service_list
Expand Down Expand Up @@ -104,10 +137,10 @@ def get_rtkbase_infos(host_list):
if result.get('PORTS') and len(result.get('PORTS')) > 0:
try:
for port in result.get('PORTS'):
#try with mDns server name at first, then with the ip address if it fails
for address in (result.get('SERVER'), result.get('IP')):
ans = None
# Prefer direct IP probes before falling back to the advertised mDNS name.
for address in iter_probe_addresses(result):
try:
ans = None
if address is None:
continue
log.debug(f"{address}:{port} Api request")
Expand Down Expand Up @@ -202,6 +235,7 @@ def main(ports, allscan=False, iprange=None):
available_rtkbase = get_rtkbase_infos(scan_results)
#remove duplicate
available_rtkbase = remove_duplicate_hosts(available_rtkbase)
available_rtkbase = sort_hosts(available_rtkbase)
log.debug("RTKBase station found: ")
log.debug(available_rtkbase)
return available_rtkbase
Expand All @@ -211,4 +245,4 @@ def main(ports, allscan=False, iprange=None):
if args.debug:
log.setLevel('DEBUG')
log.debug(f"Arguments: {args}")
print(main(args.ports, args.allscan, args.iprange))
print(main(args.ports, args.allscan, args.iprange))
132 changes: 132 additions & 0 deletions tools/find_rtkbase/test_scan_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

import importlib.util
import unittest
from pathlib import Path
from unittest import mock


MODULE_PATH = Path(__file__).with_name("scan_network.py")
SPEC = importlib.util.spec_from_file_location("scan_network_under_test", MODULE_PATH)
scan_network = importlib.util.module_from_spec(SPEC)
assert SPEC.loader is not None
SPEC.loader.exec_module(scan_network)


class FakeService:
def __init__(self, name, port, server, address_bytes):
self.name = name
self.port = port
self.server = server
self.addresses = [address_bytes]


class FakeZeroconf:
def __init__(self, services):
self.services = services
self.closed = False

def get_service_info(self, service_type, name):
return self.services[name]

def close(self):
self.closed = True


class FakeResponse:
def __init__(self, status_code, payload):
self.status_code = status_code
self._payload = payload

def json(self):
return self._payload


class ScanNetworkTests(unittest.TestCase):
def test_zeroconf_scan_returns_results_sorted_by_server_name(self):
fake_services = {
"svc-b": FakeService(
"RTKBase Web Server Beta",
80,
"beta.local.",
bytes([192, 168, 1, 20]),
),
"svc-a": FakeService(
"RTKBase Web Server Alpha",
80,
"alpha.local.",
bytes([192, 168, 1, 10]),
),
}
fake_zeroconf = FakeZeroconf(fake_services)

def fake_browser(zeroconf, service_type, listener):
listener.add_service(zeroconf, service_type, "svc-b")
listener.add_service(zeroconf, service_type, "svc-a")
return object()

with (
mock.patch.object(scan_network, "Zeroconf", return_value=fake_zeroconf),
mock.patch.object(scan_network, "ServiceBrowser", side_effect=fake_browser),
mock.patch.object(scan_network.time, "sleep", return_value=None),
):
results = scan_network.zeroconf_scan("RTKBase Web Server", "_http._tcp.local.")

self.assertEqual(["alpha.local", "beta.local"], [item["SERVER"] for item in results])
self.assertTrue(fake_zeroconf.closed)

def test_get_rtkbase_infos_retries_ip_before_server_name(self):
calls = []

def fake_get(url, timeout):
calls.append(url)
if len(calls) < 3:
raise scan_network.requests.exceptions.ConnectionError("network")
return FakeResponse(
200,
{"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"},
)

host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}]
with mock.patch.object(scan_network.requests, "get", side_effect=fake_get):
results = scan_network.get_rtkbase_infos(host_list)

self.assertEqual(
[
"http://10.0.0.5:80/api/v1/infos",
"http://10.0.0.5:80/api/v1/infos",
"http://alpha.local:80/api/v1/infos",
],
calls,
)
self.assertEqual("alpha.local", results[0]["server"])
self.assertEqual("10.0.0.5", results[0]["ip"])

def test_get_rtkbase_infos_stops_after_successful_ip_retry(self):
calls = []

def fake_get(url, timeout):
calls.append(url)
if len(calls) == 1:
raise scan_network.requests.exceptions.ConnectionError("network")
return FakeResponse(
200,
{"app": "RTKBase", "app_version": "2.0.0", "fqdn": "alpha.local"},
)

host_list = [{"IP": "10.0.0.5", "SERVER": "alpha.local", "PORTS": [80]}]
with mock.patch.object(scan_network.requests, "get", side_effect=fake_get):
results = scan_network.get_rtkbase_infos(host_list)

self.assertEqual(
[
"http://10.0.0.5:80/api/v1/infos",
"http://10.0.0.5:80/api/v1/infos",
],
calls,
)
self.assertEqual(1, len(results))


if __name__ == "__main__":
unittest.main()