From 6d4daa7e0628019d9621dfdec33b77962c747602 Mon Sep 17 00:00:00 2001 From: Noam Bernstein Date: Mon, 26 Aug 2024 16:27:20 -0400 Subject: [PATCH] Refined strategy for selecting configs for pretrained multihead --- mace/cli/fine_tuning_select.py | 229 ++++++++++++++++++--------------- mace/tools/arg_parser.py | 6 + mace/tools/multihead_tools.py | 4 +- 3 files changed, 131 insertions(+), 108 deletions(-) diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py index f3b7462f..d1933724 100644 --- a/mace/cli/fine_tuning_select.py +++ b/mace/cli/fine_tuning_select.py @@ -6,6 +6,8 @@ import logging from typing import List +from tqdm import tqdm + import ase.data import ase.io import numpy as np @@ -82,8 +84,8 @@ def parse_args() -> argparse.Namespace: "--filtering_type", help="filtering type", type=str, - choices=[None, "combinations", "exclusive", "inclusive"], - default=None, + choices=["none", "subset", "exact", "superset", "any_overlap"], + default="subset", ) parser.add_argument( "--weight_ft", @@ -114,40 +116,36 @@ def calculate_descriptors(atoms: List[ase.Atoms], calc: MACECalculator) -> None: def filter_atoms( - atoms: ase.Atoms, element_subset: List[str], filtering_type: str + atoms: ase.Atoms, selected_elements: List[str], filtering_type: str ) -> bool: """ Filters atoms based on the provided filtering type and element subset. Parameters: atoms (ase.Atoms): The atoms object to filter. - element_subset (list): The list of elements to consider during filtering. + selected_elements (list): The list of elements to consider during filtering. filtering_type (str): The type of filtering to apply. Can be 'none', 'exclusive', or 'inclusive'. 'none' - No filtering is applied. - 'combinations' - Return true if `atoms` is composed of combinations of elements in the subset, false otherwise. I.e. does not require all of the specified elements to be present. - 'exclusive' - Return true if `atoms` contains *only* elements in the subset, false otherwise. 
- 'inclusive' - Return true if `atoms` contains all elements in the subset, false otherwise. I.e. allows additional elements. + 'exact' - Return true if `atoms` is composed of exactly the same elements as the `selected_elements`, false otherwise + 'subset' - Return true if `atoms` is composed of a subset of elements in `selected_elements`, false otherwise + 'superset' - Return true if `atoms` is composed of a superset of elements in `selected_elements`, false otherwise + 'any_overlap' - Return true if `atoms` contains any of the elements in `selected_elements` Returns: bool: True if the atoms pass the filter, False otherwise. """ if filtering_type == "none": return True - if filtering_type == "combinations": - atom_symbols = np.unique(atoms.symbols) - return all( - x in element_subset for x in atom_symbols - ) # atoms must *only* contain elements in the subset - if filtering_type == "exclusive": - atom_symbols = set(list(atoms.symbols)) - return atom_symbols == set(element_subset) - if filtering_type == "inclusive": - atom_symbols = np.unique(atoms.symbols) - return all( - x in atom_symbols for x in element_subset - ) # atoms must *at least* contain elements in the subset + if filtering_type == "exact": + return set(atoms.symbols) == set(selected_elements) + if filtering_type == "subset": + return set(atoms.symbols).issubset(selected_elements) + if filtering_type == "superset": + return set(selected_elements).issubset(atoms.symbols) + if filtering_type == "any_overlap": + return len(set(selected_elements) & set(atoms.symbols)) >= 1 raise ValueError( - f"Filtering type {filtering_type} not recognised. Must be one of 'none', 'exclusive', or 'inclusive'." + f"Filtering type {filtering_type} not recognised. 
Must be one of 'none', 'subset', 'exact', 'superset', or 'any_overlap'" ) @@ -204,108 +202,127 @@ def assemble_descriptors(self) -> np.ndarray: def select_samples( args: argparse.Namespace, ) -> None: + # setup np.random.seed(args.seed) torch.manual_seed(args.seed) - if args.model in ["small", "medium", "large"]: - calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) - else: - calc = MACECalculator( - model_paths=args.model, device=args.device, default_dtype=args.default_dtype - ) + + # read finetuning set if isinstance(args.configs_ft, str): - atoms_list_ft = ase.io.read(args.configs_ft, index=":") + atoms_list_ft = list(tqdm(ase.io.iread(args.configs_ft, index=":"), desc=f"reading configs_ft {args.configs_ft}")) else: atoms_list_ft = [] for path in args.configs_ft: - atoms_list_ft += ase.io.read(path, index=":") + atoms_list_ft += list(tqdm(ase.io.iread(path, index=":"), desc=f"reading configs_ft item {path}")) + + # read pretrained set + atoms_list_pt = list(tqdm(ase.io.iread(args.configs_pt, index=":"), desc="reading configs_pt")) - if args.filtering_type is not None: - all_species_ft = np.unique([x.symbol for atoms in atoms_list_ft for x in atoms]) + indices_pt_filtered = [] + atoms_list_pt_filtered = [] + + # do filtering by elements + if args.filtering_type != "none": + all_species_ft = {atom.symbol for atoms in atoms_list_ft for atom in atoms} logging.info( "Filtering configurations based on the finetuning set, " - f"filtering type: combinations, elements: {all_species_ft}" + f"filtering type: {args.filtering_type}, elements: {all_species_ft}" ) - if args.descriptors is not None: - logging.info("Loading descriptors") - descriptors = np.load(args.descriptors, allow_pickle=True) - atoms_list_pt = ase.io.read(args.configs_pt, index=":") - for i, atoms in enumerate(atoms_list_pt): - atoms.info["mace_descriptors"] = descriptors[i] - atoms_list_pt_filtered = [ - x - for x in atoms_list_pt - if filter_atoms(x, all_species_ft, 
"combinations") - ] - else: - atoms_list_pt = ase.io.read(args.configs_pt, index=":") - atoms_list_pt_filtered = [ - x - for x in atoms_list_pt - if filter_atoms(x, all_species_ft, "combinations") - ] - if len(atoms_list_pt_filtered) <= args.num_samples: - logging.info( - f"Number of configurations after filtering {len(atoms_list_pt_filtered)} " - f"is less than the number of samples {args.num_samples}, " - "selecting random configurations for the rest." - ) - atoms_list_pt_minus_filtered = [ - x for x in atoms_list_pt if x not in atoms_list_pt_filtered - ] - atoms_list_pt_random_inds = np.random.choice( - list(range(len(atoms_list_pt_minus_filtered))), - args.num_samples - len(atoms_list_pt_filtered), - replace=False, - ) - atoms_list_pt = atoms_list_pt_filtered + [ - atoms_list_pt_minus_filtered[ind] for ind in atoms_list_pt_random_inds - ] + + # select by requested strategy + pt_filter = [filter_atoms(atoms, all_species_ft, args.filtering_type) for atoms in atoms_list_pt] + if sum(pt_filter) <= args.num_samples: + # few enough to include all, will be supplemented by FPS/random later + logging.info(f"Found few enough to include all {sum(pt_filter)} filtered by elements") + indices_pt_filtered = np.where(pt_filter)[0] else: - atoms_list_pt = atoms_list_pt_filtered + # too many, select by increasingly generous strategy and within each one, match in composition + # [NB should we allow setting of exponential base relating overlap and probability, currently 10.0 ?]] + logging.info(f"Found too many filtered by elements {sum(pt_filter)}, choosing based on composition match") + # try increasingly generous matching strategies + indices_pt_filtered_orig = set(np.where(pt_filter)[0]) + indices_pt_filtered = set() + for strategy in ("exact", "subset", "any_overlap"): + strategy_filter = [filter_atoms(atoms, all_species_ft, strategy) for atoms in atoms_list_pt] + if sum(strategy_filter) == 0: + logging.info(f"Nothing selected by {strategy}") + continue + indices_pt_strategy = 
set(np.where(strategy_filter)[0]) & indices_pt_filtered_orig + indices_pt_strategy -= indices_pt_filtered + if len(indices_pt_filtered) + len(indices_pt_strategy) <= args.num_samples: + # can include all of these + indices_pt_filtered |= indices_pt_strategy + logging.info(f"Adding all {len(indices_pt_strategy)} selected by {strategy}") + else: + # pick a subset with weights, penalizing missing and extra elements + # first term is number of elements that are missing from each config + # second term is number of elements that are extra in each config + # + # for exact distances should all be 0 + # for subset should only have missing elements, no extra (first term only) + # for any_overlap could have either/both, add them up (both terms) + indices_pt_strategy = list(indices_pt_strategy) + d = np.asarray([len(all_species_ft - set(atoms_list_pt[ind].symbols)) + + len(set(atoms_list_pt[ind].symbols) - all_species_ft) for ind in indices_pt_strategy]) + p = 10.0 ** (-d) + p /= np.sum(p) + inds = np.random.choice(len(indices_pt_strategy), args.num_samples - len(indices_pt_filtered), replace=False, p=p) + logging.info(f"Adding subset len {len(inds)} randomly chosen from those selected by {strategy}") + indices_pt_filtered |= {indices_pt_strategy[ind] for ind in inds} + # we already had too many, don't check more generous strategies + break - else: - atoms_list_pt = ase.io.read(args.configs_pt, index=":") - if args.descriptors is not None: - logging.info( - f"Loading descriptors for the pretraining set from {args.descriptors}" - ) - descriptors = np.load(args.descriptors, allow_pickle=True) - for i, atoms in enumerate(atoms_list_pt): - atoms.info["mace_descriptors"] = descriptors[i] + # actually do filtering by composition done so far + atoms_list_pt_filtered = [atoms_list_pt[ind] for ind in indices_pt_filtered] + + # get additional configs from across DB + # [NB: should we be able to control this size separately from size set chosen by filtering?] 
+ atoms_list_pt_extra = [] + if len(atoms_list_pt_filtered) < args.num_samples: + logging.info( + f"Number of configurations after filtering {len(atoms_list_pt_filtered)} " + f"< {args.num_samples} number of samples, " + f"selecting the rest with {args.subselect}" + ) + + indices_pt_avail = set(list(range(len(atoms_list_pt)))) - set(indices_pt_filtered) + atoms_list_pt_avail = [atoms_list_pt[ind] for ind in indices_pt_avail] - if args.num_samples is not None and args.num_samples < len(atoms_list_pt): - if args.subselect == "fps": - if args.descriptors is None: + if args.subselect == "random": + logging.info("Selecting configurations randomly") + idx_pt = np.random.choice(len(atoms_list_pt_avail), args.num_samples - len(atoms_list_pt_filtered), replace=False) + elif args.subselect == "fps": + if args.descriptors is not None: + logging.info(f"Loading descriptors from {args.descriptors}") + descriptors = np.load(args.descriptors, allow_pickle=True) + for descriptor, atoms in zip(descriptors, atoms_list_pt): + atoms.info["mace_descriptors"] = descriptor + else: logging.info("Calculating descriptors for the pretraining set") + # [NB Not great that this parsing of args.model happens here as well as other places. Refactor?] 
+ if args.model in ["small", "medium", "large"]: + calc = mace_mp(args.model, device=args.device, default_dtype=args.default_dtype) + else: + calc = MACECalculator( + model_paths=args.model, device=args.device, default_dtype=args.default_dtype + ) calculate_descriptors(atoms_list_pt, calc) - descriptors_list = [ - atoms.info["mace_descriptors"] for atoms in atoms_list_pt - ] - logging.info( - f"Saving descriptors at {args.output.replace('.xyz', '_descriptors.npy')}" - ) - np.save( - args.output.replace(".xyz", "_descriptors.npy"), descriptors_list - ) + descriptors_list = [atoms.info["mace_descriptors"] for atoms in atoms_list_pt] + + descriptors_file = args.output.replace(".xyz", "descriptors.npy") + logging.info(f"Saving descriptors at {descriptors_file}") + np.save(descriptors_file, descriptors_list) + logging.info("Selecting configurations using Farthest Point Sampling") - try: - fps_pt = FPS(atoms_list_pt, args.num_samples) - idx_pt = fps_pt.run() - logging.info(f"Selected {len(idx_pt)} configurations") - except Exception as e: # pylint: disable=W0703 - logging.error( - f"FPS failed, selecting random configurations instead: {e}" - ) - idx_pt = np.random.choice( - list(range(len(atoms_list_pt))), args.num_samples, replace=False - ) - atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + fps_pt = FPS(atoms_list_pt_avail, args.num_samples - len(atoms_list_pt_filtered)) + idx_pt = fps_pt.run() else: - logging.info("Selecting random configurations") - idx_pt = np.random.choice( - list(range(len(atoms_list_pt))), args.num_samples, replace=False - ) - atoms_list_pt = [atoms_list_pt[i] for i in idx_pt] + raise ValueError(f"subselect type {args.subselect} not 'random' or 'fps'") + + logging.info(f"Selected {len(idx_pt)} configurations") + atoms_list_pt_extra = [atoms_list_pt_avail[i] for i in idx_pt] + + atoms_list_pt = atoms_list_pt_filtered + atoms_list_pt_extra + for atoms in atoms_list_pt: # del atoms.info["mace_descriptors"] atoms.info["pretrained"] = True diff 
--git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 4cdfb0c3..d62b9b8b 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -356,6 +356,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=int, default=1000, ) + parser.add_argument( + "--filtering_type_pt", + help="strategy for filtering of configurations for pretrained head", + choices=["none", "subset", "exact", "superset", "any_overlap"], + default="subset" + ) parser.add_argument( "--subselect_pt", help="Method to subselect the configurations of the pretraining set", diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index 1e190da2..0c94aa2d 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -153,9 +153,9 @@ def assemble_mp_data( "head_ft": "Default", "weight_pt": args.weight_pt_head, "weight_ft": 1.0, - "filtering_type": "combination", "output": f"mp_finetuning-{tag}.xyz", "descriptors": descriptors_mp, + "filtering_type": args.filtering_type_pt, "subselect": args.subselect_pt, "device": args.device, "default_dtype": args.default_dtype, @@ -179,4 +179,4 @@ def assemble_mp_data( ) return collections_mp except Exception as exc: - raise RuntimeError("Model download failed and no local model found") from exc + raise RuntimeError("Failed to assemble pretrained data") from exc