diff --git a/aexpect/remote.py b/aexpect/remote.py index efc9e61..7df757f 100644 --- a/aexpect/remote.py +++ b/aexpect/remote.py @@ -662,6 +662,7 @@ def remote_copy( log_function=None, transfer_timeout=600, login_timeout=300, + tries=1, ): """ Transfer files using rsync or SCP, given a command line. @@ -677,25 +678,67 @@ def remote_copy( :param login_timeout: The maximal time duration (in seconds) to wait for each step of the login procedure (i.e. the "Are you sure" prompt or the password prompt) + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ - LOG.debug( - "Trying to copy with command '%s', timeout %ss", - command, - transfer_timeout, - ) - if log_filename: - output_func = log_function - output_params = (log_filename,) - else: - output_func = None - output_params = () method = "rsync" if "rsync" in command else "scp" - with Expect( - command, output_func=output_func, output_params=output_params - ) as session: - _remote_copy( - session, password_list, transfer_timeout, login_timeout, method - ) + + for attempt in range(tries): + try: + LOG.debug( + "Trying to copy with command '%s', timeout %ss (attempt %d/%d)", + command, + transfer_timeout, + attempt + 1, + tries, + ) + if log_filename: + output_func = log_function + output_params = (log_filename,) + else: + output_func = None + output_params = () + with Expect( + command, output_func=output_func, output_params=output_params + ) as session: + _remote_copy( + session, + password_list, + transfer_timeout, + login_timeout, + method, + ) + return # transfer is successful + except ( + TransferTimeoutError, + AuthenticationTimeoutError, + ExpectTimeoutError, + ) as error: + if attempt < tries - 1: + LOG.debug( + "Transient error on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise + except (TransferFailedError, SCPError, RsyncError) as error: + # For transfer failures, only retry on specific conditions + if "Connection" in str(error) or "timeout" in str(error).lower(): + if attempt < tries - 1: + LOG.debug( + "Connection error on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise + else: + raise def scp_to_remote( @@ -711,6 +754,7 @@ def scp_to_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files to a remote host (guest) through scp. @@ -729,6 +773,8 @@ def scp_to_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if limit: limit = f"-l {limit}" @@ -753,7 +799,12 @@ def scp_to_remote( ) password_list = [password] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -770,6 +821,7 @@ def scp_from_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files from a remote host (guest). @@ -788,6 +840,8 @@ def scp_from_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if limit: limit = f"-l {limit}" @@ -810,7 +864,14 @@ def scp_from_remote( rf"{shlex.quote(local_path)}" ) password_list = [password] - remote_copy(command, password_list, log_filename, log_function, timeout) + remote_copy( + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, + ) def scp_between_remotes( @@ -830,6 +891,7 @@ def scp_between_remotes( timeout=600, src_inter=None, dst_inter=None, + tries=1, ): """ Copy files from a remote host (guest) to another remote host (guest). @@ -851,6 +913,8 @@ def scp_between_remotes( to complete. :param src_inter: The interface on local that the src neighbour attached :param dst_inter: The interface on the src that the dst neighbour attached + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :return: True on success and False on failure. """ @@ -883,7 +947,12 @@ def scp_between_remotes( ) password_list = [s_passwd, d_passwd] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -900,6 +969,7 @@ def rsync_to_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files to a remote host (guest) through rsync. @@ -918,6 +988,8 @@ def rsync_to_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :raise: Whatever remote_rsync() raises """ if limit: @@ -941,7 +1013,12 @@ def rsync_to_remote( ) password_list = [password] return remote_copy( - command, password_list, log_filename, log_function, timeout + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, ) @@ -958,6 +1035,7 @@ def rsync_from_remote( log_function=None, timeout=600, interface=None, + tries=1, ): """ Copy files from a remote host (guest) through rsync. @@ -976,6 +1054,8 @@ def rsync_from_remote( to complete. :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address). + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :raise: Whatever remote_rsync() raises """ if limit: @@ -997,7 +1077,14 @@ def rsync_from_remote( f"{username}@{host}:{quote_path(remote_path)} {shlex.quote(local_path)}" ) password_list = [password] - remote_copy(command, password_list, log_filename, log_function, timeout) + remote_copy( + command, + password_list, + log_filename, + log_function, + timeout, + tries=tries, + ) # noinspection PyBroadException @@ -1020,6 +1107,7 @@ def nc_copy_between_remotes( s_session=None, d_session=None, file_transfer_timeout=600, + tries=1, ): """ Copy files from guest to guest using netcat. @@ -1045,59 +1133,91 @@ def nc_copy_between_remotes( :param d_session: A shell session object for dst or None. :param check_sum: Whether to run checksum for the operation. :param file_transfer_timeout: Timeout for file transfer. - + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. :return: True on success and False on failure. """ - check_string = "NCFT" - if not s_session: - s_session = remote_login( - c_type, src, s_port, s_name, s_passwd, c_prompt - ) - if not d_session: - d_session = remote_login( - c_type, dst, s_port, d_name, d_passwd, c_prompt - ) + for attempt in range(tries): + try: + check_string = "NCFT" + if not s_session: + s_session = remote_login( + c_type, src, s_port, s_name, s_passwd, c_prompt + ) + if not d_session: + d_session = remote_login( + c_type, dst, s_port, d_name, d_passwd, c_prompt + ) - try: - s_session.cmd(f"iptables -I INPUT -p {d_protocol} -j ACCEPT") - d_session.cmd(f"iptables -I OUTPUT -p {d_protocol} -j ACCEPT") - except Exception: # pylint: disable=W0703 - pass - - LOG.info("Transfer data using netcat from %s to %s", src, dst) - cmd = f"nc -w {timeout}" - if d_protocol == "udp": - cmd += " -u" - receive_cmd = f"echo {check_string} | {cmd} -l {d_port} > {d_path}" - d_session.sendline(receive_cmd) - send_cmd = f"{cmd} {dst} {d_port} < {s_path}" - status, output = s_session.cmd_status_output( - send_cmd, timeout=file_transfer_timeout - ) - if status: - err = f"Fail to transfer file between {src} -> {dst}." - if check_string not in output: - err += ( - "src did not receive check " - f"string {check_string} sent by dst." + try: + s_session.cmd(f"iptables -I INPUT -p {d_protocol} -j ACCEPT") + d_session.cmd(f"iptables -I OUTPUT -p {d_protocol} -j ACCEPT") + except Exception: # pylint: disable=W0703 + pass + + LOG.info( + "Transfer data using netcat from %s to %s (attempt %d/%d)", + src, + dst, + attempt + 1, + tries, ) - err += f"send nc command {send_cmd}, output {output}" - err += f"Receive nc command {receive_cmd}." - raise NetcatTransferFailedError(status, err) - - if check_sum: - LOG.info("md5sum cmd = md5sum %s", s_path) - output = s_session.cmd(f"md5sum {s_path}") - src_md5 = output.split()[0] - dst_md5 = d_session.cmd(f"md5sum {d_path}").split()[0] - if src_md5.strip() != dst_md5.strip(): - err_msg = ( - "Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}" + cmd = f"nc -w {timeout}" + if d_protocol == "udp": + cmd += " -u" + receive_cmd = f"echo {check_string} | {cmd} -l {d_port} > {d_path}" + d_session.sendline(receive_cmd) + send_cmd = f"{cmd} {dst} {d_port} < {s_path}" + status, output = s_session.cmd_status_output( + send_cmd, timeout=file_transfer_timeout ) - raise NetcatTransferIntegrityError(err_msg) - return True + if status: + err = f"Fail to transfer file between {src} -> {dst}." + if check_string not in output: + err += ( + "src did not receive check " + f"string {check_string} sent by dst." + ) + err += f"send nc command {send_cmd}, output {output}" + err += f"Receive nc command {receive_cmd}." + raise NetcatTransferFailedError(status, err) + + if check_sum: + LOG.info("md5sum cmd = md5sum %s", s_path) + output = s_session.cmd(f"md5sum {s_path}") + src_md5 = output.split()[0] + dst_md5 = d_session.cmd(f"md5sum {d_path}").split()[0] + if src_md5.strip() != dst_md5.strip(): + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) + raise NetcatTransferIntegrityError(err_msg) + return True + except ( + NetcatTransferTimeoutError, + NetcatTransferFailedError, + UDPError, + ) as error: + if attempt < tries - 1: + LOG.debug( + "Transfer failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + # reset sessions for retry + if s_session: + s_session.close() + if d_session: + d_session.close() + s_session = None + d_session = None + else: + raise + return False def udp_copy_between_remotes( @@ -1114,6 +1234,7 @@ def udp_copy_between_remotes( c_prompt="\n", d_port="9000", timeout=600, + tries=1, ): """ Copy files from guest to guest using udp. @@ -1131,9 +1252,9 @@ def udp_copy_between_remotes( :param c_prompt: command line prompt of remote host(guest) :param d_port: the port data transfer :param timeout: data transfer timeout + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ - s_session = remote_login(c_type, src, s_port, s_name, s_passwd, c_prompt) - d_session = remote_login(c_type, dst, s_port, d_name, d_passwd, c_prompt) def get_abs_path(session, filename, extension): """Return file path drive+path.""" @@ -1215,23 +1336,43 @@ def stop_server(session): if server_alive(session): session.cmd_output_safe(stop_cmd) - try: - src_md5 = get_file_md5(s_session, s_path) - if not server_alive(s_session): - start_server(s_session) - start_client(d_session) - dst_md5 = get_file_md5(d_session, d_path) - if src_md5 != dst_md5: - err_msg = ( - "Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}" + for attempt in range(tries): + try: + s_session = remote_login( + c_type, src, s_port, s_name, s_passwd, c_prompt ) - raise UDPError(err_msg) - finally: - stop_server(s_session) - s_session.close() - d_session.close() + d_session = remote_login( + c_type, dst, s_port, d_name, d_passwd, c_prompt + ) + try: + src_md5 = get_file_md5(s_session, s_path) + if not server_alive(s_session): + start_server(s_session) + start_client(d_session) + dst_md5 = get_file_md5(d_session, d_path) + if src_md5 != dst_md5: + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) + raise UDPError(err_msg) + finally: + stop_server(s_session) + s_session.close() + d_session.close() + return # transfer is successful + except UDPError as error: + if attempt < tries - 1: + LOG.debug( + "UDP transfer failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise def login_from_session( @@ -1284,6 +1425,7 @@ def scp_to_session( log_function=None, timeout=600, interface=None, + tries=1, ): """ Secure copy a filepath (w/o wildcard) to a remote location with the same @@ -1299,6 +1441,8 @@ def scp_to_session( :param log_function: Function to perform logging :param timeout: Timeout for the scp operation :param interface: Interface used for the transfer + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. The rest of the arguments are identical to scp_to_remote(). """ @@ -1315,6 +1459,7 @@ def scp_to_session( log_function, timeout, interface, + tries, ) @@ -1328,6 +1473,7 @@ def scp_from_session( log_function=None, timeout=600, interface=None, + tries=1, ): """ Secure copy a filepath (w/o wildcard) from a remote location with the same @@ -1343,6 +1489,8 @@ def scp_from_session( :param log_function: Function to perform logging :param timeout: Timeout for the scp operation :param interface: Interface used for the transfer + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. The rest of the arguments are identical to scp_from_remote(). """ @@ -1359,6 +1507,7 @@ def scp_from_session( log_function, timeout, interface, + tries, ) @@ -1406,6 +1555,7 @@ def copy_files_to( timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument + tries=1, ): """ Copy files to a remote host (guest) using the selected client. @@ -1427,6 +1577,8 @@ def copy_files_to( :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address.) :param filesize: size of file will be transferred + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if client == "scp": scp_to_remote( @@ -1442,6 +1594,7 @@ def copy_files_to( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rsync": rsync_to_remote( @@ -1457,6 +1610,7 @@ def copy_files_to( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rss": log_func = None @@ -1464,9 +1618,23 @@ def copy_files_to( log_func = LOG.debug if interface: address = f"{address}%{interface}" - fdclient = rss_client.FileUploadClient(address, port, log_func) - fdclient.upload(local_path, remote_path, timeout) - fdclient.close() + for attempt in range(tries): + try: + fdclient = rss_client.FileUploadClient(address, port, log_func) + fdclient.upload(local_path, remote_path, timeout) + fdclient.close() + return # transfer is successful + except Exception as error: # pylint: disable=broad-except + if attempt < tries - 1: + LOG.debug( + "RSS upload failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise else: raise TransferBadClientError(client) @@ -1489,6 +1657,7 @@ def copy_files_from( timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument + tries=1, ): """ Copy files from a remote host (guest) using the selected client. @@ -1510,6 +1679,8 @@ def copy_files_from( :param interface: The interface the neighbours attach to (only use when using ipv6 linklocal address.) :param filesize: size of file will be transferred + :param tries: Number of attempts to make to deal with transient errors like + timeouts and connection issues. """ if client == "scp": scp_from_remote( @@ -1525,6 +1696,7 @@ def copy_files_from( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rsync": rsync_from_remote( @@ -1540,6 +1712,7 @@ def copy_files_from( log_function, timeout, interface=interface, + tries=tries, ) elif client == "rss": log_func = None @@ -1547,8 +1720,24 @@ def copy_files_from( log_func = LOG.debug if interface: address = f"{address}%{interface}" - fdclient = rss_client.FileDownloadClient(address, port, log_func) - fdclient.download(remote_path, local_path, timeout) - fdclient.close() + for attempt in range(tries): + try: + fdclient = rss_client.FileDownloadClient( + address, port, log_func + ) + fdclient.download(remote_path, local_path, timeout) + fdclient.close() + return # transfer is successful + except Exception as error: # pylint: disable=broad-except + if attempt < tries - 1: + LOG.debug( + "RSS download failed on attempt %d/%d, retrying: %s", + attempt + 1, + tries, + error, + ) + time.sleep(1) # small delay before retry + else: + raise else: raise TransferBadClientError(client) diff --git a/tests/test_remote.py b/tests/test_remote.py index a590360..65b8cbb 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -59,6 +59,64 @@ def test_wait_for_login(self): " -o PreferredAuthentications=password user@127.0.0.1", ) + @mock.patch("aexpect.remote._remote_copy") + def test_remote_copy(self, mock_remote_copy): + remote.remote_copy("cp a b", ["pass"], "/local/path", "/remote/path") + mock_remote_copy.assert_called_once_with( + mock.ANY, + ["pass"], + 600, + 300, + "scp", + ) + self.assertEqual( + mock_remote_copy.call_args[0][0].command, + r"cp a b", + ) + + @mock.patch("aexpect.remote._remote_copy") + def test_remote_copy_retry(self, mock_remote_copy): + remote.remote_copy( + "cp a b", + ["pass"], + "/local/path", + "/remote/path", + tries=2, + ) + mock_remote_copy.assert_called_once_with( + mock.ANY, + ["pass"], + 600, + 300, + "scp", + ) + self.assertEqual( + mock_remote_copy.call_args[0][0].command, + r"cp a b", + ) + + mock_remote_copy.reset_mock() + mock_remote_copy.side_effect = [ + remote.SCPError("Copy failed", "Connection lost"), + None, + ] + remote.remote_copy( + "cp a b", + ["pass"], + "/local/path", + "/remote/path", + tries=2, + ) + self.assertEqual(mock_remote_copy.call_count, 2) + self.assertEqual( + mock_remote_copy.call_args_list[0][0][0].command, + r"cp a b", + ) + self.assertEqual( + mock_remote_copy.call_args_list[1][0][0].command, + r"cp a b", + ) + @mock.patch("aexpect.remote._remote_copy") def test_scp_to_remote(self, mock_remote_copy): remote.scp_to_remote( @@ -85,19 +143,6 @@ def test_scp_from_remote(self, mock_remote_copy): r"scp -r -v -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o PreferredAuthentications=password -P 22 user@\[127.0.0.1\]:/remote/path /local/path", ) - @mock.patch("aexpect.remote._remote_copy") - def test_rsync_to_remote(self, mock_remote_copy): - remote.rsync_to_remote( - "127.0.0.1", 22, "user", "pass", "/local/path", "/remote/path" - ) - mock_remote_copy.assert_called_once_with( - mock.ANY, ["pass"], 600, 300, "rsync" - ) - self.assertEqual( - mock_remote_copy.call_args[0][0].command, - r"rsync -r -avz -e 'ssh -Tp 22 -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' /local/path user@127.0.0.1:/remote/path", - ) - @mock.patch("aexpect.remote._remote_copy") def test_scp_between_remotes(self, mock_remote_copy): remote.scp_between_remotes( @@ -119,6 +164,19 @@ def test_scp_between_remotes(self, mock_remote_copy): r"scp -r -v -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o PreferredAuthentications=password -P 22 src_user@\[src_host\]:/src/path dst_user@\[dst_host\]:/dst/path", ) + @mock.patch("aexpect.remote._remote_copy") + def test_rsync_to_remote(self, mock_remote_copy): + remote.rsync_to_remote( + "127.0.0.1", 22, "user", "pass", "/local/path", "/remote/path" + ) + mock_remote_copy.assert_called_once_with( + mock.ANY, ["pass"], 600, 300, "rsync" + ) + self.assertEqual( + mock_remote_copy.call_args[0][0].command, + r"rsync -r -avz -e 'ssh -Tp 22 -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no' /local/path user@127.0.0.1:/remote/path", + ) + @mock.patch("aexpect.remote._remote_copy") def test_rsync_from_remote(self, mock_remote_copy): remote.rsync_from_remote(