Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 109 additions & 103 deletions device-connectors/src/testflinger_device_connectors/devices/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,114 @@ def SerialLogger(host=None, port=None, filename=None):
return StubSerialLogger(host, port, filename)


def import_ssh_key(key: str, keyfile: str = "key.pub") -> None:
"""Import SSH key provided in Reserve data.

:param key: SSH key to import.
:param keyfile: Output file where to store the imported key
:raises RuntimeError: If failure during import ssh keys
"""
cmd = ["ssh-import-id", "-o", keyfile, key]
for retry in range(10):
try:
subprocess.run(
cmd,
timeout=30,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=True,
)
logger.info("Successfully imported key: %s", key)
break

except subprocess.TimeoutExpired:
pass
except subprocess.CalledProcessError as exc:
output = (exc.stdout or b"").decode()
if "status_code=404" in output:
raise RuntimeError(
f"Failed to import ssh key: {key}. User not found."
) from exc

logger.error("Unable to import ssh key from: %s", key)
logger.info("Retrying...")
time.sleep(min(2**retry, 100))
else:
raise RuntimeError(
f"Failed to import ssh key: {key}. Maximum retries reached"
)


def copy_ssh_key(
device_ip: str,
username: str,
password: Optional[str] = None,
key: Optional[str] = None,
):
"""If provided, copy the SSH `key` to the DUT,
otherwise copy the agent's using password authentication.

:raises RuntimeError in case it can't copy the SSH keys
"""
if not key and not password:
raise ValueError("Cannot copy the agent's SSH key w/o password")

if password:
cmd = ["sshpass", "-p", password]
else:
cmd = []

cmd.extend(["ssh-copy-id", "-f"])

if key:
cmd.extend(["-i", key])

cmd.extend(
[
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"{}@{}".format(username, device_ip),
]
)

for _retry in range(10):
# Retry ssh key copy just in case it's rebooting
try:
subprocess.check_call(cmd, timeout=30)
break
except (
subprocess.CalledProcessError,
subprocess.TimeoutExpired,
):
logger.error("Error copying ssh key to device for: %s", key)
logger.info("Retrying...")
time.sleep(60)

else:
logger.error("Failed to copy ssh key: %s", key)
raise RuntimeError


def copy_ssh_keys_to_devices(ssh_keys, device_ips, test_username="ubuntu"):
"""Copy list of ssh keys to list of devices."""
for key in ssh_keys:
with contextlib.suppress(FileNotFoundError):
os.unlink("key.pub")

try:
# Import SSH Keys with ssh-import-id
import_ssh_key(key, keyfile="key.pub")

# Attempt to copy keys only if import succeeds
with contextlib.suppress(RuntimeError):
for device_ip in device_ips:
copy_ssh_key(device_ip, test_username, key="key.pub")
except RuntimeError as exc:
logger.error(exc)


class StubSerialLogger:
"""Fake SerialLogger when we don't have Serial Logger data defined."""

Expand Down Expand Up @@ -269,95 +377,6 @@ def allocate(self):
"""Allocate devices for multi-agent jobs (default method)."""
pass

def import_ssh_key(self, key: str, keyfile: str = "key.pub") -> None:
"""Import SSH key provided in Reserve data.

:param key: SSH key to import.
:param keyfile: Output file where to store the imported key
:raises RuntimeError: If failure during import ssh keys
"""
cmd = ["ssh-import-id", "-o", keyfile, key]
for retry in range(10):
try:
subprocess.run(
cmd,
timeout=30,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
check=True,
)
logger.info("Successfully imported key: %s", key)
break

except subprocess.TimeoutExpired:
pass
except subprocess.CalledProcessError as exc:
output = (exc.stdout or b"").decode()
if "status_code=404" in output:
raise RuntimeError(
f"Failed to import ssh key: {key}. User not found."
) from exc

logger.error("Unable to import ssh key from: %s", key)
logger.info("Retrying...")
time.sleep(min(2**retry, 100))
else:
raise RuntimeError(
f"Failed to import ssh key: {key}. Maximum retries reached"
)

def copy_ssh_key(
self,
device_ip: str,
username: str,
password: Optional[str] = None,
key: Optional[str] = None,
):
"""If provided, copy the SSH `key` to the DUT,
otherwise copy the agent's using password authentication.

:raises RuntimeError in case it can't copy the SSH keys
"""
if not key and not password:
raise ValueError("Cannot copy the agent's SSH key w/o password")

if password:
cmd = ["sshpass", "-p", password]
else:
cmd = []

cmd.extend(["ssh-copy-id", "-f"])

if key:
cmd.extend(["-i", key])

cmd.extend(
[
"-o",
"StrictHostKeyChecking=no",
"-o",
"UserKnownHostsFile=/dev/null",
"{}@{}".format(username, device_ip),
]
)

for _retry in range(10):
# Retry ssh key copy just in case it's rebooting
try:
subprocess.check_call(cmd, timeout=30)
break
except (
subprocess.CalledProcessError,
subprocess.TimeoutExpired,
):
logger.error("Error copying ssh key to device for: %s", key)
logger.info("Retrying...")
time.sleep(60)

else:
logger.error("Failed to copy ssh key: %s", key)
raise RuntimeError

def reserve(self, args):
"""Reserve systems (default method)."""
with open(args.config) as configfile:
Expand All @@ -373,20 +392,7 @@ def reserve(self, args):
device_ip = config["device_ip"]
reserve_data = job_data["reserve_data"]
ssh_keys = reserve_data.get("ssh_keys", [])
for key in ssh_keys:
with contextlib.suppress(FileNotFoundError):
os.unlink("key.pub")

try:
# Import SSH Keys with ssh-import-id
self.import_ssh_key(key, keyfile="key.pub")

# Attempt to copy keys only if import succeeds
with contextlib.suppress(RuntimeError):
self.copy_ssh_key(device_ip, test_username, key="key.pub")
except RuntimeError as exc:
logger.error(exc)

copy_ssh_keys_to_devices(ssh_keys, [device_ip], test_username)
# default reservation timeout is 1 hour
timeout = int(reserve_data.get("timeout", "3600"))
serial_host = config.get("serial_host")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
import logging
import os
import time
from datetime import datetime, timedelta

import requests

from testflinger_device_connectors.devices import ProvisioningError
from testflinger_device_connectors.devices import (
ProvisioningError,
copy_ssh_keys_to_devices,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -80,6 +84,45 @@ def provision(self):

self.save_job_list_file()

def reserve(self):
"""Push ssh keys to each device in reservation phase."""
logger.info("BEGIN multi device reservation")
job_data = self.job_data
try:
test_username = job_data["test_data"]["test_username"]
except KeyError:
test_username = "ubuntu"
reserve_data = job_data["reserve_data"]
ssh_keys = reserve_data.get("ssh_keys", [])
with open("job_list.json", "r") as json_file:
job_list = json.load(json_file)
device_ips = [job["device_info"]["device_ip"] for job in job_list]
copy_ssh_keys_to_devices(ssh_keys, device_ips, test_username)
print("*** TESTFLINGER SYSTEMS RESERVED ***")
print("You can now connect to the following devices:")
for job in job_list:
device_ip = job["device_info"]["device_ip"]
print(f"{test_username}@{device_ip}")

timeout = int(reserve_data.get("timeout", "3600"))
now = datetime.now().astimezone().isoformat()
expire_time = (
datetime.now().astimezone() + timedelta(seconds=timeout)
).isoformat()
print("Current time: [{}]".format(now))
print("Reservation expires at: [{}]".format(expire_time))
print(
"Reservation will automatically timeout in {} seconds".format(
timeout
)
)
job_id = job_data.get("job_id", "<job_id>")
print(
"To end the reservation sooner use: "
+ "testflinger-cli cancel {}".format(job_id)
)
time.sleep(timeout)

def terminate_if_parent_completed(self):
"""If parent job is completed or cancelled, cancel sub jobs."""
if self.this_job_completed():
Expand Down Expand Up @@ -142,7 +185,8 @@ def create_jobs(self):
updated_job = self.inject_parent_jobid(updated_job)

try:
job_id = self.client.submit_job(updated_job)
# Use agent job submission for credential inheritance
job_id = self.client.submit_agent_job(updated_job)
except requests.exceptions.HTTPError as exc:
logger.error("Unable to create job: %s", exc.response.text)
self.cancel_jobs(self.jobs)
Expand Down
Loading
Loading