Source code for dpgen.simplify.simplify

"""Simplify dataset (minimize the dataset size).

Init:
pick up init data from dataset randomly

Iter:
00: train models (same as generator)
01: calculate model deviations of the rest dataset, pick up data with proper model deviaiton 
02: fp (optional, if the original dataset do not have fp data, same as generator)
"""
import logging
import warnings
import queue
import os
import json
import argparse
import pickle
import glob
import fnmatch
import dpdata
import numpy as np

from dpgen import dlog
from dpgen import SHORT_CMD
from dpgen.util import sepline, expand_sys_str
from distutils.version import LooseVersion
from dpgen.dispatcher.Dispatcher import Dispatcher, _split_tasks, make_dispatcher, make_submission
from dpgen.generator.run import make_train, run_train, post_train, run_fp, post_fp, fp_name, model_devi_name, train_name, train_task_fmt, sys_link_fp_vasp_pp, make_fp_vasp_incar, make_fp_vasp_kp, make_fp_vasp_cp_cvasp, data_system_fmt, model_devi_task_fmt, fp_task_fmt
# TODO: maybe the following functions can be moved to dpgen.util
from dpgen.generator.lib.utils import log_iter, make_iter_name, create_path, record_iter
from dpgen.generator.lib.gaussian import make_gaussian_input
from dpgen.remote.decide_machine import  convert_mdata


picked_data_name = "data.picked"
rest_data_name = "data.rest"
accurate_data_name = "data.accurate"
detail_file_name_prefix = "details"
sys_name_fmt = 'sys.' + data_system_fmt
sys_name_pattern = 'sys.[0-9]*[0-9]'


[docs]def get_system_cls(jdata): if jdata.get("labeled", False): return dpdata.LabeledSystem return dpdata.System
[docs]def get_multi_system(path, jdata): system = get_system_cls(jdata) system_paths = expand_sys_str(path) systems = dpdata.MultiSystems( *[system(s, fmt='deepmd/npy') for s in system_paths]) return systems
[docs]def init_model(iter_index, jdata, mdata): training_init_model = jdata.get('training_init_model', False) if not training_init_model: return iter0_models = [] training_iter0_model = jdata.get('training_iter0_model_path', []) if type(training_iter0_model) == str: training_iter0_model = [training_iter0_model] for ii in training_iter0_model: model_is = glob.glob(ii) model_is.sort() iter0_models += [os.path.abspath(ii) for ii in model_is] numb_models = jdata['numb_models'] assert(numb_models == len(iter0_models)), "training_iter0_model_path should be provided, and the number of models should be equal to %d" % numb_models work_path = os.path.join(make_iter_name(iter_index), train_name) create_path(work_path) cwd = os.getcwd() for ii in range(len(iter0_models)): train_path = os.path.join(work_path, train_task_fmt % ii) create_path(train_path) os.chdir(train_path) ckpt_files = glob.glob(os.path.join(iter0_models[ii], 'model.ckpt*')) for jj in ckpt_files: os.symlink(jj, os.path.basename(jj)) os.chdir(cwd)
[docs]def init_pick(iter_index, jdata, mdata): """pick up init data from dataset randomly""" pick_data = jdata['pick_data'] init_pick_number = jdata['init_pick_number'] # use MultiSystems with System # TODO: support System and LabeledSystem # TODO: support other format systems = get_multi_system(pick_data, jdata) # label the system labels = [] items = systems.systems.items() for key, system in items: labels.extend([(key, j) for j in range(len(system))]) # random pick iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, model_devi_name) create_path(work_path) idx = np.arange(len(labels)) np.random.shuffle(idx) pick_idx = idx[:init_pick_number] rest_idx = idx[init_pick_number:] # dump the init data sys_data_path = os.path.join(work_path, picked_data_name) _init_dump_selected_frames(systems, labels, pick_idx, sys_data_path, jdata) # dump the rest data sys_data_path = os.path.join(work_path, rest_data_name) _init_dump_selected_frames(systems, labels, rest_idx, sys_data_path, jdata)
def _init_dump_selected_frames(systems, labels, selc_idx, sys_data_path, jdata): selc_systems = dpdata.MultiSystems() for j in selc_idx: sys_name, sys_id = labels[j] selc_systems.append(systems[sys_name][sys_id]) selc_systems.to_deepmd_raw(sys_data_path) selc_systems.to_deepmd_npy(sys_data_path, set_size=selc_idx.size)
[docs]def make_model_devi(iter_index, jdata, mdata): """calculate the model deviation of the rest idx""" pick_data = jdata['pick_data'] iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, model_devi_name) create_path(work_path) # link the model train_path = os.path.join(iter_name, train_name) train_path = os.path.abspath(train_path) models = glob.glob(os.path.join(train_path, "graph*pb")) for mm in models: model_name = os.path.basename(mm) os.symlink(mm, os.path.join(work_path, model_name)) # link the last rest data last_iter_name = make_iter_name(iter_index-1) rest_data_path = os.path.join(last_iter_name, model_devi_name, rest_data_name) if not os.path.exists(rest_data_path): return False os.symlink(os.path.abspath(rest_data_path), os.path.join(work_path, rest_data_name + ".old")) return True
[docs]def run_model_devi(iter_index, jdata, mdata): """submit dp test tasks""" iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, model_devi_name) # generate command commands = [] run_tasks = ["."] # get models models = glob.glob(os.path.join(work_path, "graph*pb")) model_names = [os.path.basename(ii) for ii in models] task_model_list = [] for ii in model_names: task_model_list.append(os.path.join('.', ii)) # models commands = [] detail_file_name = detail_file_name_prefix command = "{dp} model-devi -m {model} -s {system} -o {detail_file}".format( dp=mdata.get('model_devi_command', 'dp'), model=" ".join(task_model_list), system=rest_data_name + ".old", detail_file=detail_file_name, ) commands = [command] # submit model_devi_group_size = mdata.get('model_devi_group_size', 1) forward_files = [rest_data_name + ".old"] backward_files = [detail_file_name] api_version = mdata.get('api_version', '0.9') if LooseVersion(api_version) < LooseVersion('1.0'): warnings.warn(f"the dpdispatcher will be updated to new version." f"And the interface may be changed. Please check the documents for more details") dispatcher = make_dispatcher(mdata['model_devi_machine'], mdata['model_devi_resources'], work_path, run_tasks, model_devi_group_size) dispatcher.run_jobs(mdata['model_devi_resources'], commands, work_path, run_tasks, model_devi_group_size, model_names, forward_files, backward_files, outlog = 'model_devi.log', errlog = 'model_devi.log') elif LooseVersion(api_version) >= LooseVersion('1.0'): submission = make_submission( mdata['model_devi_machine'], mdata['model_devi_resources'], commands=commands, work_path=work_path, run_tasks=run_tasks, group_size=model_devi_group_size, forward_common_files=model_names, forward_files=forward_files, backward_files=backward_files, outlog = 'model_devi.log', errlog = 'model_devi.log') submission.run_submission()
[docs]def post_model_devi(iter_index, jdata, mdata): """calculate the model deviation""" iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, model_devi_name) f_trust_lo = jdata['model_devi_f_trust_lo'] f_trust_hi = jdata['model_devi_f_trust_hi'] sys_accurate = dpdata.MultiSystems() sys_candinate = dpdata.MultiSystems() sys_failed = dpdata.MultiSystems() labeled = jdata.get("labeled", False) type_map = jdata.get("type_map", []) sys_entire = dpdata.MultiSystems(type_map = type_map).from_deepmd_npy(os.path.join(work_path, rest_data_name + ".old"), labeled=labeled) detail_file_name = detail_file_name_prefix with open(os.path.join(work_path, detail_file_name)) as f: for line in f: if line.startswith("# data.rest.old"): name = (line.split()[1]).split("/")[-1] elif line.startswith("#"): pass else: idx = int(line.split()[0]) f_devi = float(line.split()[4]) subsys = sys_entire[name][idx] if f_trust_lo <= f_devi < f_trust_hi: sys_candinate.append(subsys) elif f_devi >= f_trust_hi: sys_failed.append(subsys) elif f_devi < f_trust_lo: sys_accurate.append(subsys) else: raise RuntimeError('reach a place that should NOT be reached...') counter = {"candidate": sys_candinate.get_nframes(), "accurate": sys_accurate.get_nframes(), "failed": sys_failed.get_nframes()} fp_sum = sum(counter.values()) for cc_key, cc_value in counter.items(): dlog.info("{0:9s} : {1:6d} in {2:6d} {3:6.2f} %".format(cc_key, cc_value, fp_sum, cc_value/fp_sum*100)) if counter['candidate'] == 0 and counter['failed'] > 0: raise RuntimeError('no candidate but still have failed cases, stop. You may want to refine the training or to increase the trust level hi') # label the candidate system labels = [] items = sys_candinate.systems.items() for key, system in items: labels.extend([(key, j) for j in range(len(system))]) # candinate: pick up randomly iter_pick_number = jdata['iter_pick_number'] idx = np.arange(counter['candidate']) assert(len(idx) == len(labels)) np.random.shuffle(idx) pick_idx = idx[:iter_pick_number] rest_idx = idx[iter_pick_number:] if(counter['candidate'] == 0) : dlog.info("no candidate") else : dlog.info("total candidate {0:6d} picked {1:6d} ({2:6.2f} %) rest {3:6d} ({4:6.2f} % )".format\ (counter['candidate'], len(pick_idx), float(len(pick_idx))/counter['candidate']*100., len(rest_idx), float(len(rest_idx))/counter['candidate']*100.)) # dump the picked candinate data picked_systems = dpdata.MultiSystems() for j in pick_idx: sys_name, sys_id = labels[j] picked_systems.append(sys_candinate[sys_name][sys_id]) sys_data_path = os.path.join(work_path, picked_data_name) picked_systems.to_deepmd_raw(sys_data_path) picked_systems.to_deepmd_npy(sys_data_path, set_size=iter_pick_number) # dump the rest data (not picked candinate data and failed data) rest_systems = dpdata.MultiSystems() for j in rest_idx: sys_name, sys_id = labels[j] rest_systems.append(sys_candinate[sys_name][sys_id]) rest_systems += sys_failed sys_data_path = os.path.join(work_path, rest_data_name) rest_systems.to_deepmd_raw(sys_data_path) if rest_idx.size: rest_systems.to_deepmd_npy(sys_data_path, set_size=rest_idx.size) # dump the accurate data -- to another directory sys_data_path = os.path.join(work_path, accurate_data_name) sys_accurate.to_deepmd_raw(sys_data_path) sys_accurate.to_deepmd_npy(sys_data_path, set_size=sys_accurate.get_nframes())
[docs]def make_fp_labeled(iter_index, jdata): dlog.info("already labeled, skip make_fp and link data directly") pick_data = jdata['pick_data'] iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, fp_name) create_path(work_path) picked_data_path = os.path.join(iter_name, model_devi_name, picked_data_name) os.symlink(os.path.abspath(picked_data_path), os.path.abspath( os.path.join(work_path, "task." + data_system_fmt % 0))) os.symlink(os.path.abspath(picked_data_path), os.path.abspath( os.path.join(work_path, "data." + data_system_fmt % 0)))
[docs]def make_fp_configs(iter_index, jdata): pick_data = jdata['pick_data'] iter_name = make_iter_name(iter_index) work_path = os.path.join(iter_name, fp_name) create_path(work_path) picked_data_path = os.path.join(iter_name, model_devi_name, picked_data_name) systems = get_multi_system(picked_data_path, jdata) ii = 0 jj = 0 for system in systems: for subsys in system: task_name = "task." + fp_task_fmt % (ii, jj) task_path = os.path.join(work_path, task_name) create_path(task_path) subsys.to('vasp/poscar', os.path.join(task_path, 'POSCAR')) jj += 1 ii += 1
[docs]def make_fp_gaussian(iter_index, jdata): work_path = os.path.join(make_iter_name(iter_index), fp_name) fp_tasks = glob.glob(os.path.join(work_path, 'task.*')) cwd = os.getcwd() if 'user_fp_params' in jdata.keys() : fp_params = jdata['user_fp_params'] else: fp_params = jdata['fp_params'] cwd = os.getcwd() for ii in fp_tasks: os.chdir(ii) sys_data = dpdata.System('POSCAR').data ret = make_gaussian_input(sys_data, fp_params) with open('input', 'w') as fp: fp.write(ret) os.chdir(cwd)
[docs]def make_fp_vasp(iter_index, jdata): # abs path for fp_incar if it exists if 'fp_incar' in jdata: jdata['fp_incar'] = os.path.abspath(jdata['fp_incar']) # get nbands esti if it exists if 'fp_nbands_esti_data' in jdata: nbe = NBandsEsti(jdata['fp_nbands_esti_data']) else: nbe = None # order is critical! # 1, create potcar sys_link_fp_vasp_pp(iter_index, jdata) # 2, create incar make_fp_vasp_incar(iter_index, jdata, nbands_esti = nbe) # 3, create kpoints make_fp_vasp_kp(iter_index, jdata) # 4, copy cvasp make_fp_vasp_cp_cvasp(iter_index,jdata)
[docs]def make_fp_calculation(iter_index, jdata): fp_style = jdata['fp_style'] if fp_style == 'vasp': make_fp_vasp(iter_index, jdata) elif fp_style == 'gaussian': make_fp_gaussian(iter_index, jdata) else : raise RuntimeError('unsupported fp_style ' + fp_style)
[docs]def make_fp(iter_index, jdata, mdata): labeled = jdata.get("labeled", False) if labeled: make_fp_labeled(iter_index, jdata) else: make_fp_configs(iter_index, jdata) make_fp_calculation(iter_index, jdata)
[docs]def run_iter(param_file, machine_file): """ init (iter 0): init_pick tasks (iter > 0): 00 make_train (same as generator) 01 run_train (same as generator) 02 post_train (same as generator) 03 make_model_devi 04 run_model_devi 05 post_model_devi 06 make_fp 07 run_fp (same as generator) 08 post_fp (same as generator) """ # TODO: function of handling input json should be combined as one function try: import ruamel from monty.serialization import loadfn, dumpfn warnings.simplefilter( 'ignore', ruamel.yaml.error.MantissaNoDotYAML1_1Warning) jdata = loadfn(param_file) mdata = loadfn(machine_file) except Exception: with open(param_file, 'r') as fp: jdata = json.load(fp) with open(machine_file, 'r') as fp: mdata = json.load(fp) if jdata.get('pretty_print', False): fparam = SHORT_CMD+'_' + \ param_file.split('.')[0]+'.'+jdata.get('pretty_format', 'json') dumpfn(jdata, fparam, indent=4) fmachine = SHORT_CMD+'_' + \ machine_file.split('.')[0]+'.'+jdata.get('pretty_format', 'json') dumpfn(mdata, fmachine, indent=4) if mdata.get('handlers', None): if mdata['handlers'].get('smtp', None): que = queue.Queue(-1) queue_handler = logging.handlers.QueueHandler(que) smtp_handler = logging.handlers.SMTPHandler( **mdata['handlers']['smtp']) listener = logging.handlers.QueueListener(que, smtp_handler) dlog.addHandler(queue_handler) listener.start() mdata = convert_mdata(mdata) max_tasks = 10000 numb_task = 9 record = "record.dpgen" iter_rec = [0, -1] if os.path.isfile(record): with open(record) as frec: for line in frec: iter_rec = [int(x) for x in line.split()] dlog.info("continue from iter %03d task %02d" % (iter_rec[0], iter_rec[1])) cont = True ii = -1 while cont: ii += 1 iter_name = make_iter_name(ii) sepline(iter_name, '=') for jj in range(numb_task): if ii * max_tasks + jj <= iter_rec[0] * max_tasks + iter_rec[1]: continue task_name = "task %02d" % jj sepline("{} {}".format(iter_name, task_name), '-') jdata['model_devi_jobs'] = [{} for _ in range(ii+1)] if ii == 0 and jj < 6: if jj == 0: log_iter("init_pick", ii, jj) init_model(ii, jdata, mdata) init_pick(ii, jdata, mdata) dlog.info("first iter, skip step 1-5") elif jj == 0: log_iter("make_train", ii, jj) make_train(ii, jdata, mdata) elif jj == 1: log_iter("run_train", ii, jj) #disp = make_dispatcher(mdata['train_machine']) run_train(ii, jdata, mdata) elif jj == 2: log_iter("post_train", ii, jj) post_train(ii, jdata, mdata) elif jj == 3: log_iter("make_model_devi", ii, jj) cont = make_model_devi(ii, jdata, mdata) if not cont or ii >= jdata.get("stop_iter", ii+1): break elif jj == 4: log_iter("run_model_devi", ii, jj) #disp = make_dispatcher(mdata['model_devi_machine']) run_model_devi(ii, jdata, mdata) elif jj == 5: log_iter("post_model_devi", ii, jj) post_model_devi(ii, jdata, mdata) elif jj == 6: log_iter("make_fp", ii, jj) make_fp(ii, jdata, mdata) elif jj == 7: log_iter("run_fp", ii, jj) if jdata.get("labeled", False): dlog.info("already have labeled data, skip run_fp") else: #disp = make_dispatcher(mdata['fp_machine']) run_fp(ii, jdata, mdata) elif jj == 8: log_iter("post_fp", ii, jj) if jdata.get("labeled", False): dlog.info("already have labeled data, skip post_fp") else: post_fp(ii, jdata) else: raise RuntimeError("unknown task %d, something wrong" % jj) record_iter(record, ii, jj)
[docs]def gen_simplify(args): if args.PARAM and args.MACHINE: if args.debug: dlog.setLevel(logging.DEBUG) dlog.info("start simplifying") run_iter(args.PARAM, args.MACHINE) dlog.info("finished")