diff --git a/libsubmit/channels/ssh/ssh.py b/libsubmit/channels/ssh/ssh.py index 12dfe35..87215ed 100644 --- a/libsubmit/channels/ssh/ssh.py +++ b/libsubmit/channels/ssh/ssh.py @@ -2,6 +2,7 @@ import getpass import logging import os +from stat import S_ISDIR import paramiko from libsubmit.channels.errors import * @@ -216,7 +217,7 @@ def pull_file(self, remote_source, local_dir): if os.path.exists(local_dest): logger.exception("Remote file copy will overwrite a local file:{0}".format(local_dest)) raise FileExists(None, self.hostname, filename=local_dest) - + try: self.sftp_client.get(remote_source, local_dest) except Exception as e: @@ -227,3 +228,108 @@ def pull_file(self, remote_source, local_dir): def close(self): return self.ssh_client.close() + + def _recursive_mkdir(self, remote_path, is_filename=False): + """ + recursively create directories if they don't exist + remote_path - remote path to create. + is_filename - specifies if remote path is a filename (rather than directory) + """ + + dirs_ = [] + + if is_filename: + dir_, basename = os.path.split(remote_path) + else: + dir_ = remote_path + + while len(dir_) > 1: + dirs_.append(dir_) + dir_, _ = os.path.split(dir_) + + if len(dir_) == 1 and not dir_.startswith("/"): + dirs_.append(dir_) # For a remote_path path like y/x.txt + + while len(dirs_): + dir_ = dirs_.pop() + try: + self.sftp_client.stat(dir_) + except: + self.sftp_client.mkdir(dir_) + + def push_directory(self, local_source, remote_dir): + ''' Recursively transport directory on the remote side to a local directory + + Args: + - local_source (string): local_source + - remote_dir (string): Remote directory to copy to + + + Returns: + - str: Path to copied folder on remote machine + + Raises: + - FileExists : Name collision at local directory. + - FileCopyException : FileCopy failed. + ''' + + try: + if os.path.isdir(local_source): + _, directory_name = os.path.split(local_source) + dest = remote_dir + '/' + directory_name + + try: + self.sftp_client.stat(directory_name) + except: + self.sftp_client.mkdir(directory_name) + + file_list = os.listdir(directory_name) + for filename in file_list: + src = os.path.join(local_source, filename) + self.push_directory( src, dest ) + else: + self.push_file(local_source, remote_dir) + + except Exception as e: + logger.exception("Directory push failed") + raise FileCopyException(e, self.hostname) + + return remote_dir + + def pull_directory(self, remote_source, local_dir): + ''' Recursively transport directory on the remote side to a local directory + + Args: + - remote_source (string): remote_source + - local_dir (string): Local directory to copy to + + + Returns: + - str: Local path to folder + + Raises: + - FileExists : Name collision at local directory. + - FileCopyException : FileCopy failed. + ''' + + try: + if S_ISDIR( self.sftp_client.stat(remote_source).st_mode ): + _, directory_name = os.path.split(remote_source) + dest = os.path.join(local_dir, directory_name) + + if not os.path.exists(directory_name): + os.makedirs(directory_name) + + file_list = self.sftp_client.listdir(path=remote_source) + for filename in file_list: + src = remote_source + '/' + filename + self.pull_directory( src, dest ) + else: + self.pull_file(remote_source,local_dir) + + except Exception as e: + logger.exception("Directory pull failed") + raise FileCopyException(e, self.hostname) + + return local_dir +