Source code for dpgen.dispatcher.SSHContext

#!/usr/bin/env python
# coding: utf-8

import os, sys, paramiko, json, uuid, tarfile, time, stat, shutil
from glob import glob
from dpgen import dlog

class SSHSession(object):
    """A paramiko SSH connection to one remote host, with a lazily opened SFTP channel.

    Parameters
    ----------
    jdata : dict
        Remote profile. Keys 'hostname', 'username' and 'work_path' are
        required. Keys 'port' (default 22), 'password', 'key_filename',
        'timeout' and 'passphrase' (all default None) are optional.
    """

    def __init__(self, jdata):
        self.remote_profile = jdata
        self.remote_host = self.remote_profile['hostname']
        self.remote_uname = self.remote_profile['username']
        self.remote_port = self.remote_profile.get('port', 22)
        self.remote_password = self.remote_profile.get('password', None)
        self.local_key_filename = self.remote_profile.get('key_filename', None)
        self.remote_timeout = self.remote_profile.get('timeout', None)
        self.local_key_passphrase = self.remote_profile.get('passphrase', None)
        self.remote_workpath = self.remote_profile['work_path']
        self.ssh = None
        self._sftp = None
        self._connect()

    def _connect(self):
        """(Re)open the SSH connection using the parameters stored from the profile."""
        self._setup_ssh(hostname=self.remote_host,
                        port=self.remote_port,
                        username=self.remote_uname,
                        password=self.remote_password,
                        key_filename=self.local_key_filename,
                        timeout=self.remote_timeout,
                        passphrase=self.local_key_passphrase)

    def ensure_alive(self, max_check=10, sleep_time=10):
        """Reconnect until the connection check passes.

        Parameters
        ----------
        max_check : int
            Number of failed connection checks after which to give up.
        sleep_time : int
            Seconds to sleep after each reconnection attempt.

        Raises
        ------
        RuntimeError
            If the connection check has failed `max_check` times.
        """
        count = 1
        while not self._check_alive():
            if count == max_check:
                raise RuntimeError('cannot connect ssh after %d failures at interval %d s' %
                                   (max_check, sleep_time))
            dlog.info('connection check failed, try to reconnect to ' + self.remote_host)
            self._connect()
            count += 1
            time.sleep(sleep_time)

    def _check_alive(self):
        """Return True if a packet can be sent over the current transport, else False."""
        if self.ssh is None:
            return False
        transport = self.ssh.get_transport()
        # get_transport() yields None when the client is closed or was never
        # connected; report "not alive" instead of failing on None.send_ignore().
        if transport is None:
            return False
        try:
            transport.send_ignore()
        except EOFError:
            return False
        return True

    def _setup_ssh(self,
                   hostname,
                   port=22,
                   username=None,
                   password=None,
                   key_filename=None,
                   timeout=None,
                   passphrase=None):
        """Create a fresh SSH client, connect it, and drop any cached SFTP channel.

        Raises
        ------
        RuntimeError
            If the transport is not active right after connecting.
        """
        # Release the previous client (if any) before replacing it.
        if self.ssh is not None:
            self.ssh.close()
        self.ssh = paramiko.SSHClient()
        self.ssh.set_missing_host_key_policy(paramiko.WarningPolicy)
        self.ssh.connect(hostname=hostname, port=port,
                         username=username, password=password,
                         key_filename=key_filename, timeout=timeout,
                         passphrase=passphrase)
        transport = self.ssh.get_transport()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if transport is None or not transport.is_active():
            raise RuntimeError('ssh transport to %s is not active after connecting' % hostname)
        transport.set_keepalive(60)
        # The SFTP channel belongs to the old transport; reopen it on demand.
        self._sftp = None

    def get_ssh_client(self):
        """Return the underlying paramiko SSH client."""
        return self.ssh

    def get_session_root(self):
        """Return the remote work path taken from the profile ('work_path')."""
        return self.remote_workpath

    def close(self):
        """Close the SSH client and forget the cached SFTP channel."""
        if self.ssh is not None:
            self.ssh.close()
        self._sftp = None

    def exec_command(self, cmd, retry=0):
        """Call self.ssh.exec_command, reconnecting and retrying on SSHException.

        Parameters
        ----------
        cmd : str
            Command passed to the SSH client.
        retry : int
            Number of retries already performed; at most 3 retries are made.

        Returns
        -------
        The value returned by the SSH client's exec_command.

        Raises
        ------
        RuntimeError
            If the command still raises SSHException after 3 retries.
        """
        try:
            return self.ssh.exec_command(cmd)
        except paramiko.ssh_exception.SSHException as err:
            # SSH session not active; retry for up to 3 times
            if retry < 3:
                dlog.warning("SSH session not active in calling %s, retry the command..." % cmd)
                self.ensure_alive()
                return self.exec_command(cmd, retry=retry + 1)
            raise RuntimeError("SSH session not active") from err

    @property
    def sftp(self):
        """Returns sftp. Open a new one if not existing."""
        if self._sftp is None:
            self.ensure_alive()
            self._sftp = self.ssh.open_sftp()
        return self._sftp
class SSHContext(object):
    """Per-job view of a remote work directory reached through an SSH session.

    Files are exchanged between `local_root` and
    `<session root>/<job_uuid>` on the remote side, and commands are run
    from inside that remote directory.

    Parameters
    ----------
    local_root : str
        Local directory; stored as an absolute path.
    ssh_session : SSHSession
        Session providing the SSH client, the SFTP channel and the remote root.
    job_uuid : str, optional
        Name of the remote job directory. A random UUID4 string is used
        when it is not given.
    """

    def __init__(self,
                 local_root,
                 ssh_session,
                 job_uuid=None,
                 ):
        if not isinstance(local_root, str):
            raise TypeError('local_root must be a str, got %s' % type(local_root).__name__)
        self.local_root = os.path.abspath(local_root)
        self.job_uuid = job_uuid if job_uuid else str(uuid.uuid4())
        self.remote_root = os.path.join(ssh_session.get_session_root(), self.job_uuid)
        self.ssh_session = ssh_session
        self.ssh_session.ensure_alive()
        try:
            self.sftp.mkdir(self.remote_root)
        except Exception:
            # Deliberate best effort: the directory may already exist,
            # e.g. when a job_uuid is reused.
            pass

    @property
    def ssh(self):
        """The SSH client of the underlying session."""
        return self.ssh_session.get_ssh_client()

    @property
    def sftp(self):
        """The SFTP channel of the underlying session."""
        return self.ssh_session.sftp

    def close(self):
        """Close the underlying SSH session."""
        self.ssh_session.close()

    def get_job_root(self):
        """Return the remote directory of this job."""
        return self.remote_root

    def upload(self,
               job_dirs,
               local_up_files,
               dereference=True):
        """Upload `<job_dir>/<file>` for every pair of job dir and file name.

        Paths are relative to `local_root`. `dereference` is forwarded to
        the tar archive used for the transfer.
        """
        self.ssh_session.ensure_alive()
        file_list = [os.path.join(ii, jj) for ii in job_dirs for jj in local_up_files]
        cwd = os.getcwd()
        os.chdir(self.local_root)
        try:
            self._put_files(file_list, dereference=dereference)
        finally:
            # Always restore the caller's working directory, even on failure.
            os.chdir(cwd)

    def download(self,
                 job_dirs,
                 remote_down_files,
                 check_exists=False,
                 mark_failure=True,
                 back_error=False):
        """Download `<job_dir>/<file>` for every pair of job dir and file name.

        Parameters
        ----------
        job_dirs, remote_down_files : iterable of str
            Job directories and file names, relative to the job root.
        check_exists : bool
            If True, only request files that exist remotely.
        mark_failure : bool
            With `check_exists`, write an empty local file
            `tag_failure_download_<file>` in the job dir for each missing file.
        back_error : bool
            Also request the paths matching `<job_dir>/error*`.
        """
        self.ssh_session.ensure_alive()
        cwd = os.getcwd()
        os.chdir(self.local_root)
        try:
            file_list = []
            for ii in job_dirs:
                for jj in remote_down_files:
                    file_name = os.path.join(ii, jj)
                    if not check_exists or self.check_file_exists(file_name):
                        file_list.append(file_name)
                    elif mark_failure:
                        tag = os.path.join(self.local_root, ii, 'tag_failure_download_%s' % jj)
                        with open(tag, 'w'):
                            pass
                if back_error:
                    # NOTE(review): this globs the LOCAL job directory, not the
                    # remote one -- confirm that is intended.
                    file_list.extend(glob(os.path.join(ii, 'error*')))
            if len(file_list) > 0:
                self._get_files(file_list)
        finally:
            # Always restore the caller's working directory, even on failure.
            os.chdir(cwd)

    def block_checkcall(self,
                        cmd,
                        retry=0):
        """Run `cmd` in the remote job root, wait for it, and require exit status 0.

        On a non-zero exit status the command is retried after 60 s, up to
        3 times.

        Returns
        -------
        (stdin, stdout, stderr) of the successful command.

        Raises
        ------
        RuntimeError
            If the command still fails after the retries.
        """
        self.ssh_session.ensure_alive()
        stdin, stdout, stderr = self.ssh_session.exec_command(('cd %s ;' % self.remote_root) + cmd)
        exit_status = stdout.channel.recv_exit_status()
        if exit_status == 0:
            return stdin, stdout, stderr
        message = "Get error code %d in calling %s through ssh with job: %s . message: %s" % \
            (exit_status, cmd, self.job_uuid, stderr.read().decode('utf-8'))
        if retry < 3:
            dlog.warning(message)
            dlog.warning("Sleep 60 s and retry the command...")
            time.sleep(60)
            return self.block_checkcall(cmd, retry=retry + 1)
        raise RuntimeError(message)

    def block_call(self,
                   cmd):
        """Run `cmd` in the remote job root and wait for it to exit.

        Returns
        -------
        (exit_status, stdin, stdout, stderr)
        """
        self.ssh_session.ensure_alive()
        stdin, stdout, stderr = self.ssh_session.exec_command(('cd %s ;' % self.remote_root) + cmd)
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdin, stdout, stderr

    def clean(self):
        """Recursively delete the remote job root."""
        self.ssh_session.ensure_alive()
        sftp = self.ssh.open_sftp()
        try:
            self._rmtree(sftp, self.remote_root)
        finally:
            # Close the dedicated channel even if the removal fails.
            sftp.close()

    def write_file(self, fname, write_str):
        """Write `write_str` to `fname` (relative to the remote job root)."""
        self.ssh_session.ensure_alive()
        with self.sftp.open(os.path.join(self.remote_root, fname), 'w') as fp:
            fp.write(write_str)

    def read_file(self, fname):
        """Return the content of `fname` (relative to the remote job root), decoded as UTF-8."""
        self.ssh_session.ensure_alive()
        with self.sftp.open(os.path.join(self.remote_root, fname), 'r') as fp:
            return fp.read().decode('utf-8')

    def check_file_exists(self, fname):
        """Return True if `fname` (relative to the remote job root) can be stat-ed."""
        self.ssh_session.ensure_alive()
        try:
            self.sftp.stat(os.path.join(self.remote_root, fname))
        except IOError:
            return False
        return True

    def call(self, cmd):
        """Start `cmd` without waiting for it and return its pipes.

        Unlike `block_call`, the command is not prefixed with a `cd` into
        the job root.

        Returns
        -------
        dict with keys 'stdin', 'stdout' and 'stderr'.
        """
        stdin, stdout, stderr = self.ssh_session.exec_command(cmd)
        return {'stdin': stdin, 'stdout': stdout, 'stderr': stderr}

    def check_finish(self, cmd_pipes):
        """Return True if the command behind `cmd_pipes` has an exit status ready."""
        return cmd_pipes['stdout'].channel.exit_status_ready()

    def get_return(self, cmd_pipes):
        """Return (retcode, stdout, stderr), or (None, None, None) if not finished."""
        if not self.check_finish(cmd_pipes):
            return None, None, None
        retcode = cmd_pipes['stdout'].channel.recv_exit_status()
        return retcode, cmd_pipes['stdout'], cmd_pipes['stderr']

    def kill(self, cmd_pipes):
        """Not supported; always raises RuntimeError."""
        # `cmd_pipes` as built by `call` carries no 'pid' entry, so there is
        # nothing to send a signal to.
        raise RuntimeError('dose not work! we do not know how to kill proc through paramiko.SSHClient')

    def _rmtree(self, sftp, remotepath, level=0, verbose=False):
        """Recursively remove `remotepath` through `sftp`.

        `level` is the recursion depth, used to indent the log lines emitted
        when `verbose` is True.
        """
        for f in sftp.listdir_attr(remotepath):
            rpath = os.path.join(remotepath, f.filename)
            if stat.S_ISDIR(f.st_mode):
                # Propagate `verbose` so nested entries are logged as well.
                self._rmtree(sftp, rpath, level=(level + 1), verbose=verbose)
            else:
                if verbose:
                    dlog.info('removing %s%s' % (' ' * level, rpath))
                sftp.remove(rpath)
        if verbose:
            dlog.info('removing %s%s' % (' ' * level, remotepath))
        sftp.rmdir(remotepath)

    def _put_files(self,
                   files,
                   dereference=True):
        """Pack `files` (relative to `local_root`) into a tgz, upload it and extract it remotely."""
        of = self.job_uuid + '.tgz'
        from_f = os.path.join(self.local_root, of)
        to_f = os.path.join(self.remote_root, of)
        if os.path.isfile(from_f):
            os.remove(from_f)
        try:
            # local tar; absolute source paths plus `arcname` keep the member
            # names relative without changing the working directory
            with tarfile.open(from_f, "w:gz", dereference=dereference, compresslevel=6) as tar:
                for ii in files:
                    tar.add(os.path.join(self.local_root, ii), arcname=ii)
            # trans
            try:
                self.sftp.put(from_f, to_f)
            except FileNotFoundError as err:
                raise FileNotFoundError("from %s to %s Error!" % (from_f, to_f)) from err
            # remote extract
            self.block_checkcall('tar xf %s' % of)
        finally:
            # clean up the local archive, also when packing or transfer failed
            if os.path.isfile(from_f):
                os.remove(from_f)
        self.sftp.remove(to_f)

    def _get_files(self,
                   files):
        """Pack `files` remotely into a tar.gz, download it and extract it into `local_root`."""
        of = self.job_uuid + '.tar.gz'
        # remote tar
        # If the number of files are large, we may get "Argument list too long" error.
        # Thus, we may run tar commands for several times and tar only 100 files for
        # each time.
        per_nfile = 100
        # Ceiling division: an exact multiple of per_nfile must not produce
        # an extra, empty batch.
        ntar = (len(files) + per_nfile - 1) // per_nfile
        if ntar <= 1:
            self.block_checkcall('tar czfh %s %s' % (of, " ".join(files)))
        else:
            of_tar = self.job_uuid + '.tar'
            for ii in range(ntar):
                ff = files[per_nfile * ii: per_nfile * (ii + 1)]
                if ii == 0:
                    # tar cf for the first time
                    self.block_checkcall('tar cfh %s %s' % (of_tar, " ".join(ff)))
                else:
                    # append using tar rf
                    # -r, --append append files to the end of an archive
                    self.block_checkcall('tar rfh %s %s' % (of_tar, " ".join(ff)))
            # compress the tar file using gzip, and will get a tar.gz file
            # overwrite considering dpgen may stop and restart
            # -f, --force force overwrite of output file and compress links
            self.block_checkcall('gzip -f %s' % of_tar)
        # trans
        from_f = os.path.join(self.remote_root, of)
        to_f = os.path.join(self.local_root, of)
        if os.path.isfile(to_f):
            os.remove(to_f)
        self.sftp.get(from_f, to_f)
        # extract
        # NOTE(review): extractall() trusts the member paths of an archive
        # produced on the remote host -- consider an extraction filter if the
        # remote side is not fully trusted.
        with tarfile.open(to_f, "r:gz") as tar:
            tar.extractall(path=self.local_root)
        # cleanup
        os.remove(to_f)
        self.sftp.remove(from_f)