Source code for dpgen.auto_test.lib.RemoteJob

#!/usr/bin/env python3

import os, sys, paramiko, json, uuid, tarfile, time, stat
from enum import Enum

[docs]class JobStatus (Enum) : unsubmitted = 1 waiting = 2 running = 3 terminated = 4 finished = 5 unknow = 100
def _default_item(resources, key, value) : if key not in resources : resources[key] = value def _set_default_resource(res) : if res == None : res = {} _default_item(res, 'numb_node', 1) _default_item(res, 'task_per_node', 1) _default_item(res, 'numb_gpu', 0) _default_item(res, 'time_limit', '1:0:0') _default_item(res, 'mem_limit', -1) _default_item(res, 'partition', '') _default_item(res, 'account', '') _default_item(res, 'qos', '') _default_item(res, 'constraint_list', []) _default_item(res, 'license_list', []) _default_item(res, 'exclude_list', []) _default_item(res, 'module_unload_list', []) _default_item(res, 'module_list', []) _default_item(res, 'source_list', []) _default_item(res, 'envs', None) _default_item(res, 'with_mpi', False)
[docs]class SSHSession (object) : def __init__ (self, jdata) : self.remote_profile = jdata # with open(remote_profile) as fp : # self.remote_profile = json.load(fp) self.remote_host = self.remote_profile['hostname'] self.remote_port = self.remote_profile['port'] self.remote_uname = self.remote_profile['username'] self.remote_password = self.remote_profile['password'] self.remote_workpath = self.remote_profile['work_path'] self.ssh = self._setup_ssh(self.remote_host, self.remote_port, username = self.remote_uname,password=self.remote_password) def _setup_ssh(self, hostname, port, username = None, password = None): ssh_client = paramiko.SSHClient() ssh_client.load_system_host_keys() ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) ssh_client.connect(hostname, port=port, username=username, password=password) assert(ssh_client.get_transport().is_active()) return ssh_client
[docs] def get_ssh_client(self) : return self.ssh
[docs] def get_session_root(self) : return self.remote_workpath
[docs] def close(self) : self.ssh.close()
[docs]class RemoteJob (object): def __init__ (self, ssh_session, local_root ) : self.local_root = os.path.abspath(local_root) self.job_uuid = str(uuid.uuid4()) # self.job_uuid = 'a21d0017-c9f1-4d29-9a03-97df06965cef' self.remote_root = os.path.join(ssh_session.get_session_root(), self.job_uuid) print("local_root is ", local_root) print("remote_root is", self.remote_root) self.ssh = ssh_session.get_ssh_client() sftp = self.ssh.open_sftp() sftp.mkdir(self.remote_root) sftp.close() # open('job_uuid', 'w').write(self.job_uuid)
[docs] def get_job_root(self) : return self.remote_root
[docs] def upload(self, job_dirs, local_up_files, dereference = True) : cwd = os.getcwd() os.chdir(self.local_root) file_list = [] for ii in job_dirs : for jj in local_up_files : file_list.append(os.path.join(ii,jj)) self._put_files(file_list, dereference = dereference) os.chdir(cwd)
[docs] def download(self, job_dirs, remote_down_files) : cwd = os.getcwd() os.chdir(self.local_root) file_list = [] for ii in job_dirs : for jj in remote_down_files : file_list.append(os.path.join(ii,jj)) self._get_files(file_list) os.chdir(cwd)
[docs] def block_checkcall(self, cmd) : stdin, stdout, stderr = self.ssh.exec_command(('cd %s ;' % self.remote_root) + cmd) exit_status = stdout.channel.recv_exit_status() if exit_status != 0: raise RuntimeError("Get error code %d in calling through ssh with job: %s ", (exit_status, self.job_uuid)) return stdin, stdout, stderr
[docs] def block_call(self, cmd) : stdin, stdout, stderr = self.ssh.exec_command(('cd %s ;' % self.remote_root) + cmd) exit_status = stdout.channel.recv_exit_status() return exit_status, stdin, stdout, stderr
[docs] def clean(self) : sftp = self.ssh.open_sftp() self._rmtree(sftp, self.remote_root) sftp.close()
def _rmtree(self, sftp, remotepath, level=0, verbose = False): for f in sftp.listdir_attr(remotepath): rpath = os.path.join(remotepath, f.filename) if stat.S_ISDIR(f.st_mode): self._rmtree(sftp, rpath, level=(level + 1)) else: rpath = os.path.join(remotepath, f.filename) if verbose: print('removing %s%s' % (' ' * level, rpath)) sftp.remove(rpath) if verbose: print('removing %s%s' % (' ' * level, remotepath)) sftp.rmdir(remotepath) def _put_files(self, files, dereference = True) : of = self.job_uuid + '.tgz' # local tar cwd = os.getcwd() os.chdir(self.local_root) if os.path.isfile(of) : os.remove(of) with tarfile.open(of, "w:gz", dereference = dereference) as tar: for ii in files : tar.add(ii) os.chdir(cwd) # trans from_f = os.path.join(self.local_root, of) to_f = os.path.join(self.remote_root, of) sftp = self.ssh.open_sftp() sftp.put(from_f, to_f) # remote extract self.block_checkcall('tar xf %s' % of) # clean up os.remove(from_f) sftp.remove(to_f) sftp.close() def _get_files(self, files) : of = self.job_uuid + '.tgz' flist = "" for ii in files : flist += " " + ii # remote tar self.block_checkcall('tar czf %s %s' % (of, flist)) # trans from_f = os.path.join(self.remote_root, of) to_f = os.path.join(self.local_root, of) if os.path.isfile(to_f) : os.remove(to_f) sftp = self.ssh.open_sftp() sftp.get(from_f, to_f) # extract cwd = os.getcwd() os.chdir(self.local_root) with tarfile.open(of, "r:gz") as tar: tar.extractall() os.chdir(cwd) # cleanup os.remove(to_f) sftp.remove(from_f)
[docs]class CloudMachineJob (RemoteJob) :
[docs] def submit(self, job_dirs, cmd, args = None, resources = None) : #print("Current path is",os.getcwd()) #for ii in job_dirs : # if not os.path.isdir(ii) : # raise RuntimeError("cannot find dir %s" % ii) # print(self.remote_root) script_name = self._make_script(job_dirs, cmd, args, resources) self.stdin, self.stdout, self.stderr = self.ssh.exec_command(('cd %s; bash %s' % (self.remote_root, script_name)))
# print(self.stderr.read().decode('utf-8')) # print(self.stdout.read().decode('utf-8'))
[docs] def check_status(self) : if not self._check_finish(self.stdout) : return JobStatus.running elif self._get_exit_status(self.stdout) == 0 : return JobStatus.finished else : return JobStatus.terminated
def _check_finish(self, stdout) : return stdout.channel.exit_status_ready() def _get_exit_status(self, stdout) : return stdout.channel.recv_exit_status() def _make_script(self, job_dirs, cmd, args = None, resources = None) : _set_default_resource(resources) envs = resources['envs'] module_list = resources['module_list'] module_unload_list = resources['module_unload_list'] task_per_node = resources['task_per_node'] script_name = 'run.sh' if args == None : args = [] for ii in job_dirs: args.append('') script = os.path.join(self.remote_root, script_name) sftp = self.ssh.open_sftp() with sftp.open(script, 'w') as fp : fp.write('#!/bin/bash\n\n') # fp.write('set -euo pipefail\n') if envs != None : for key in envs.keys() : fp.write('export %s=%s\n' % (key, envs[key])) fp.write('\n') if module_unload_list is not None : for ii in module_unload_list : fp.write('module unload %s\n' % ii) fp.write('\n') if module_list is not None : for ii in module_list : fp.write('module load %s\n' % ii) fp.write('\n') for ii,jj in zip(job_dirs, args) : fp.write('cd %s\n' % ii) fp.write('test $? -ne 0 && exit\n') if resources['with_mpi'] == True : fp.write('mpirun -n %d %s %s\n' % (task_per_node, cmd, jj)) else : fp.write('%s %s\n' % (cmd, jj)) fp.write('test $? -ne 0 && exit\n') fp.write('cd %s\n' % self.remote_root) fp.write('test $? -ne 0 && exit\n') fp.write('\ntouch tag_finished\n') sftp.close() return script_name
[docs]class SlurmJob (RemoteJob) :
[docs] def submit(self, job_dirs, cmd, args = None, resources = None) : script_name = self._make_script(job_dirs, cmd, args, res = resources) stdin, stdout, stderr = self.block_checkcall(('cd %s; sbatch %s' % (self.remote_root, script_name))) subret = (stdout.readlines()) job_id = subret[0].split()[-1] sftp = self.ssh.open_sftp() with sftp.open(os.path.join(self.remote_root, 'job_id'), 'w') as fp: fp.write(job_id) sftp.close()
[docs] def check_status(self) : job_id = self._get_job_id() if job_id == "" : raise RuntimeError("job %s is has not been submitted" % self.remote_root) ret, stdin, stdout, stderr\ = self.block_call ("squeue --job " + job_id) err_str = stderr.read().decode('utf-8') if (ret != 0) : if str("Invalid job id specified") in err_str : if self._check_finish_tag() : return JobStatus.finished else : return JobStatus.terminated else : raise RuntimeError\ ("status command squeue fails to execute\nerror message:%s\nreturn code %d\n" % (err_str, ret)) status_line = stdout.read().decode('utf-8').split ('\n')[-2] status_word = status_line.split ()[-4] if status_word in ["PD","CF","S"] : return JobStatus.waiting elif status_word in ["R","CG"] : return JobStatus.running elif status_word in ["C","E","K","BF","CA","CD","F","NF","PR","SE","ST","TO"] : if self._check_finish_tag() : return JobStatus.finished else : return JobStatus.terminated else : return JobStatus.unknown
def _get_job_id(self) : sftp = self.ssh.open_sftp() with sftp.open(os.path.join(self.remote_root, 'job_id'), 'r') as fp: ret = fp.read().decode('utf-8') sftp.close() return ret def _check_finish_tag(self) : sftp = self.ssh.open_sftp() try: sftp.stat(os.path.join(self.remote_root, 'tag_finished')) ret = True except IOError: ret = False sftp.close() return ret def _make_script(self, job_dirs, cmd, args = None, res = None) : _set_default_resource(res) ret = '' ret += "#!/bin/bash -l\n" ret += "#SBATCH -N %d\n" % res['numb_node'] ret += "#SBATCH --ntasks-per-node %d\n" % res['task_per_node'] ret += "#SBATCH -t %s\n" % res['time_limit'] if res['mem_limit'] > 0 : ret += "#SBATCH --mem %dG \n" % res['mem_limit'] if len(res['account']) > 0 : ret += "#SBATCH --account %s \n" % res['account'] if len(res['partition']) > 0 : ret += "#SBATCH --partition %s \n" % res['partition'] if len(res['qos']) > 0 : ret += "#SBATCH --qos %s \n" % res['qos'] if res['numb_gpu'] > 0 : ret += "#SBATCH --gres=gpu:%d\n" % res['numb_gpu'] for ii in res['constraint_list'] : ret += '#SBATCH -C %s \n' % ii for ii in res['license_list'] : ret += '#SBATCH -L %s \n' % ii for ii in res['exclude_list'] : ret += '#SBATCH --exclude %s \n' % ii ret += "\n" # ret += 'set -euo pipefail\n\n' for ii in res['module_unload_list'] : ret += "module unload %s\n" % ii for ii in res['module_list'] : ret += "module load %s\n" % ii ret += "\n" for ii in res['source_list'] : ret += "source %s\n" %ii ret += "\n" envs = res['envs'] if envs != None : for key in envs.keys() : ret += 'export %s=%s\n' % (key, envs[key]) ret += '\n' if args == None : args = [] for ii in job_dirs: args.append('') for ii,jj in zip(job_dirs, args) : ret += 'cd %s\n' % ii ret += 'test $? -ne 0 && exit\n' if res['with_mpi'] : ret += 'mpirun -n %d %s %s\n' % (res['task_per_node'],cmd, jj) else : ret += '%s %s\n' % (cmd, jj) ret += 'test $? -ne 0 && exit\n' ret += 'cd %s\n' % self.remote_root ret += 'test $? -ne 0 && exit\n' ret += '\ntouch tag_finished\n' script_name = 'run.sub' script = os.path.join(self.remote_root, script_name) sftp = self.ssh.open_sftp() with sftp.open(script, 'w') as fp : fp.write(ret) sftp.close() return script_name
[docs]class PBSJob (RemoteJob) :
[docs] def submit(self, job_dirs, cmd, args = None, resources = None) : script_name = self._make_script(job_dirs, cmd, args, res = resources) stdin, stdout, stderr = self.block_checkcall(('cd %s; qsub %s' % (self.remote_root, script_name))) subret = (stdout.readlines()) job_id = subret[0].split()[0] sftp = self.ssh.open_sftp() with sftp.open(os.path.join(self.remote_root, 'job_id'), 'w') as fp: fp.write(job_id) sftp.close()
[docs] def check_status(self) : job_id = self._get_job_id() if job_id == "" : raise RuntimeError("job %s is has not been submitted" % self.remote_root) ret, stdin, stdout, stderr\ = self.block_call ("qstat " + job_id) err_str = stderr.read().decode('utf-8') if (ret != 0) : if str("qstat: Unknown Job Id") in err_str : if self._check_finish_tag() : return JobStatus.finished else : return JobStatus.terminated else : raise RuntimeError ("status command qstat fails to execute. erro info: %s return code %d" % (err_str, ret)) status_line = stdout.read().decode('utf-8').split ('\n')[-2] status_word = status_line.split ()[-2] # print (status_word) if status_word in ["Q","H"] : return JobStatus.waiting elif status_word in ["R"] : return JobStatus.running elif status_word in ["C","E","K"] : if self._check_finish_tag() : return JobStatus.finished else : return JobStatus.terminated else : return JobStatus.unknown
def _get_job_id(self) : sftp = self.ssh.open_sftp() with sftp.open(os.path.join(self.remote_root, 'job_id'), 'r') as fp: ret = fp.read().decode('utf-8') sftp.close() return ret def _check_finish_tag(self) : sftp = self.ssh.open_sftp() try: sftp.stat(os.path.join(self.remote_root, 'tag_finished')) ret = True except IOError: ret = False sftp.close() return ret def _make_script(self, job_dirs, cmd, args = None, res = None) : _set_default_resource(res) ret = '' ret += "#!/bin/bash -l\n" if res['numb_gpu'] == 0: ret += '#PBS -l nodes=%d:ppn=%d\n' % (res['numb_node'], res['task_per_node']) else : ret += '#PBS -l nodes=%d:ppn=%d:gpus=%d\n' % (res['numb_node'], res['task_per_node'], res['numb_gpu']) ret += '#PBS -l walltime=%s\n' % (res['time_limit']) if res['mem_limit'] > 0 : ret += "#PBS -l mem=%dG \n" % res['mem_limit'] ret += '#PBS -j oe\n' if len(res['partition']) > 0 : ret += '#PBS -q %s\n' % res['partition'] ret += "\n" for ii in res['module_unload_list'] : ret += "module unload %s\n" % ii for ii in res['module_list'] : ret += "module load %s\n" % ii ret += "\n" for ii in res['source_list'] : ret += "source %s\n" %ii ret += "\n" envs = res['envs'] if envs != None : for key in envs.keys() : ret += 'export %s=%s\n' % (key, envs[key]) ret += '\n' ret += 'cd $PBS_O_WORKDIR\n\n' if args == None : args = [] for ii in job_dirs: args.append('') for ii,jj in zip(job_dirs, args) : ret += 'cd %s\n' % ii ret += 'test $? -ne 0 && exit\n' if res['with_mpi'] : ret += 'mpirun -machinefile $PBS_NODEFILE -n %d %s %s\n' % (res['numb_node'] * res['task_per_node'], cmd, jj) else : ret += '%s %s\n' % (cmd, jj) ret += 'test $? -ne 0 && exit\n' ret += 'cd %s\n' % self.remote_root ret += 'test $? -ne 0 && exit\n' ret += '\ntouch tag_finished\n' script_name = 'run.sub' script = os.path.join(self.remote_root, script_name) sftp = self.ssh.open_sftp() with sftp.open(script, 'w') as fp : fp.write(ret) sftp.close() return script_name
# ssh_session = SSHSession('localhost.json') # rjob = CloudMachineJob(ssh_session, '.') # # can upload dirs and normal files # rjob.upload(['job0', 'job1'], ['batch_exec.py', 'test']) # rjob.submit(['job0', 'job1'], 'touch a; sleep 2') # while rjob.check_status() == JobStatus.running : # print('checked') # time.sleep(2) # print(rjob.check_status()) # # can download dirs and normal files # rjob.download(['job0', 'job1'], ['a']) # # rjob.clean()