import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple, Union

import pandas as pd


@dataclass
class Potential:
    """
    Unified representation of one interatomic potential entry.

    Attributes
    ----------
    year : str
        Four-digit publication year parsed from the potential name.
    year_suffix : str
        Text following the year in OpenKIM names (e.g. ``"Universal3"``);
        always ``""`` for LAMMPS/NIST names.
    authors : str
        Normalized (lower-case, letters only) surname of the first author.
    elements : set of str
        Chemical species covered by the potential.
    repo_type : str
        ``"LAMMPS"`` or ``"OpenKIM"``.
    ipr : int or None
        Implementation number of a LAMMPS entry; ``None`` for OpenKIM.
    original : str
        The unmodified potential name.
    df_index : optional
        Index label of the row this potential was read from.
    """

    year: str
    year_suffix: str
    authors: str
    elements: Set[str]
    repo_type: str
    ipr: Optional[int]
    original: str
    df_index: Optional[int] = None

    @property
    def sort_key(self) -> Tuple[int, int]:
        """Key for sorting - prefer LAMMPS, then higher ipr."""
        repo_rank = 0 if self.repo_type == "LAMMPS" else 1
        # ipr is negated so that an ascending sort puts the highest ipr first.
        return (repo_rank, -(self.ipr or 0))

    @property
    def family_id(self) -> str:
        """Return author_year[suffix] identifier."""
        return f"{self.authors}_{self.year}{self.year_suffix}"


class PotentialDeduplicator:
    """
    Deduplicate interatomic potentials from DataFrame.

    Rules:
    1. Potentials from same author+year+suffix are duplicates
    2. Within LAMMPS: prefer higher ipr
    3. Across repos: prefer LAMMPS over OpenKIM
    4. Only keep potentials containing ALL target_elements

    Potentials whose name cannot be parsed, or whose author cannot be
    determined, are never merged with anything and are always kept.
    """

    # <year>--<authors>--<elements>--LAMMPS--ipr<N>
    _LAMMPS_NAME = re.compile(
        r"(\d{4})--([^-]+(?:-[^-]+)*)--([^-]+(?:-[^-]+)*)--LAMMPS--ipr(\d+)"
    )
    # "_<year><suffix>" token inside an OpenKIM name
    _KIM_YEAR = re.compile(r"_(\d{4})([^_]*)")
    # Start of the OpenKIM model / simulator-model identifier
    _KIM_ID_MARKER = re.compile(r"__(MO_|SM_)")

    _SUMMARY_COLUMNS = ["family", "count", "repos", "has_duplicates"]

    def __init__(
        self,
        target_elements: Union[str, List[str], Set[str]] = "Ni",
        verbose: bool = False,
    ):
        """
        Parameters
        ----------
        target_elements : str, list of str, or set of str
            Element(s) to filter for. If multiple elements provided,
            potentials must contain ALL of them.
            Examples: 'Ni', ['Ni', 'Al'], {'Ni', 'Al', 'Cu'}
        verbose : bool
            Print deduplication details

        Raises
        ------
        ValueError
            If ``target_elements`` is not a str, list, tuple, set or frozenset.
        """
        self.target_elements = self._normalize_elements(target_elements)
        self.verbose = verbose
        self.last_duplicates_map: Dict[str, List[str]] = {}
        self.last_stats: Dict[str, int] = {}

    @staticmethod
    def _normalize_elements(target_elements) -> Set[str]:
        """Return a new set of element symbols built from a str or collection."""
        if isinstance(target_elements, str):
            return {target_elements}
        if isinstance(target_elements, (list, tuple, set, frozenset)):
            # Always copy so later mutation by the caller cannot leak in.
            return set(target_elements)
        raise ValueError(
            f"target_elements must be str, list, or set, got {type(target_elements)}"
        )

    @staticmethod
    def _species_to_set(species) -> Set[str]:
        """
        Convert one 'Species' cell into a set of element symbols.

        Any iterable of symbols is accepted. Plain strings and non-iterable
        cells (e.g. NaN for a missing value) yield an empty set, so such rows
        simply do not match any non-empty target.
        """
        if isinstance(species, (list, tuple, set, frozenset)):
            return set(species)
        if isinstance(species, (str, bytes)):
            # Iterating a string would split it into single characters.
            return set()
        try:
            return set(species)
        except TypeError:
            return set()

    @staticmethod
    def _family_label(metadata: Dict) -> str:
        """Build the author_year[suffix] label from parsed metadata."""
        return f"{metadata['authors']}_{metadata['year']}{metadata['year_suffix']}"

    @property
    def target_elements_str(self) -> str:
        """Human-readable string of target elements."""
        if len(self.target_elements) == 1:
            return next(iter(self.target_elements))
        return "{" + ", ".join(sorted(self.target_elements)) + "}"

    @staticmethod
    def normalize_author(author_str: str) -> str:
        """
        Extract and normalize primary author surname.

        The text before the first ``-`` or ``_`` is used. If it contains more
        than one capital letter it is treated as concatenated CamelCase
        surnames (OpenKIM style) and only the first one is kept. The result is
        lower-cased and stripped of everything except ``a-z``.
        """
        main = re.split(r"[-_]", author_str)[0]

        # Handle camelCase in OpenKIM
        if len(re.findall(r"[A-Z]", main)) > 1:
            camel_parts = [p for p in re.split(r"(?=[A-Z])", main) if p]
            if camel_parts:
                main = camel_parts[0]

        return re.sub(r"[^a-z]", "", main.lower())

    @staticmethod
    def parse_potential_metadata(name: str) -> Optional[Dict]:
        """
        Parse potential name for metadata (year, author, repo, ipr).

        Parameters
        ----------
        name : str
            Potential name in LAMMPS/NIST format
            (``<year>--<authors>--<elements>--LAMMPS--ipr<N>``) or OpenKIM
            format (``..._<Authors>_<year>[suffix]_...__MO_<id>_<ver>``).

        Returns
        -------
        dict or None
            Dict with keys ``year``, ``year_suffix``, ``authors``,
            ``repo_type`` and ``ipr``; ``None`` if the name matches neither
            format or is not a string.
        """
        if not isinstance(name, str):
            # e.g. NaN from a missing 'Name' cell
            return None

        cls = PotentialDeduplicator

        # Try LAMMPS format
        lammps = cls._LAMMPS_NAME.match(name)
        if lammps:
            year, authors, _, ipr = lammps.groups()
            return {
                "year": year,
                "year_suffix": "",
                "authors": cls.normalize_author(authors),
                "repo_type": "LAMMPS",
                "ipr": int(ipr),
            }

        # Try OpenKIM format
        marker = cls._KIM_ID_MARKER.search(name)
        if marker is None:
            return None

        # Look for the year only in front of the KIM identifier; otherwise the
        # first four digits of the 12-digit ID could be mistaken for a year.
        year_match = cls._KIM_YEAR.search(name[: marker.start()])
        if year_match is None:
            return None

        # The author token is the "_"-separated part right before the year.
        raw_author = name[: year_match.start()].rsplit("_", 1)[-1]

        return {
            "year": year_match.group(1),
            "year_suffix": year_match.group(2),
            "authors": cls.normalize_author(raw_author),
            "repo_type": "OpenKIM",
            "ipr": None,
        }

    def contains_target_elements(self, elements: Set[str]) -> bool:
        """Check if elements set contains ALL target elements."""
        return self.target_elements.issubset(elements)

    def deduplicate(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Deduplicate potentials from DataFrame.

        Parameters
        ----------
        df : DataFrame
            Must have 'Name' and 'Species' columns

        Returns
        -------
        deduplicated_df : DataFrame
            Deduplicated potentials containing all target elements. The best
            entry of every (year+suffix, author) group comes first, ordered by
            group key, followed by the rows whose name could not be parsed.

        Raises
        ------
        ValueError
            If 'Name' or 'Species' is missing from ``df``.
        """
        if "Name" not in df.columns or "Species" not in df.columns:
            raise ValueError("DataFrame must have 'Name' and 'Species' columns")

        # Rows are tracked by position, not by index label, so that a
        # non-unique index cannot bring removed rows back on selection.
        candidates: List[Tuple[int, Potential]] = []
        unparsed_positions: List[int] = []
        filtered_out = 0

        rows = zip(df.index, df["Name"], df["Species"])
        for position, (label, name, species) in enumerate(rows):
            elements = self._species_to_set(species)

            # Check if ALL target elements are present
            if not self.contains_target_elements(elements):
                filtered_out += 1
                continue

            metadata = self.parse_potential_metadata(name)
            if metadata is None:
                unparsed_positions.append(position)
                continue

            candidates.append(
                (
                    position,
                    Potential(
                        year=metadata["year"],
                        year_suffix=metadata["year_suffix"],
                        authors=metadata["authors"],
                        elements=elements,
                        repo_type=metadata["repo_type"],
                        ipr=metadata["ipr"],
                        original=name,
                        df_index=label,
                    ),
                )
            )

        self.last_stats = {
            "total": len(df),
            "filtered_out": filtered_out,
            "unparsed": len(unparsed_positions),
            "valid": len(candidates),
        }

        if self.verbose:
            print(f"Total potentials: {self.last_stats['total']}")
            print(f"Target elements: {self.target_elements_str}")
            print(
                f"Filtered out (missing target elements): {self.last_stats['filtered_out']}"
            )
            print(f"Unparsed: {self.last_stats['unparsed']}")
            print(f"Valid for deduplication: {self.last_stats['valid']}")

        # Group by (year+suffix, author). Without a known author there is no
        # evidence that two potentials are the same work, so such entries get
        # a unique third key component and end up in singleton groups.
        groups = defaultdict(list)
        for position, pot in candidates:
            unique_part = -1 if pot.authors else position
            key = (pot.year + pot.year_suffix, pot.authors, unique_part)
            groups[key].append((position, pot))

        # Keep only the best from each group
        kept_positions: List[int] = []
        self.last_duplicates_map = {}

        for (year_full, author, _), members in sorted(groups.items()):
            # Stable sort: LAMMPS first, then highest ipr, then input order.
            members.sort(key=lambda item: item[1].sort_key)
            best_position, best = members[0]
            rest = [pot for _, pot in members[1:]]

            kept_positions.append(best_position)
            if not rest:
                continue

            self.last_duplicates_map[best.original] = [p.original for p in rest]

            if self.verbose:
                print(f"\nGroup: {year_full} - {author}")
                print(f"  Kept: {best.original}")
                for dup in rest:
                    print(f"  Removed: {dup.original}")

        # Add back unparsed items
        kept_positions.extend(unparsed_positions)

        self.last_stats["kept"] = len(kept_positions)
        self.last_stats["removed_duplicates"] = sum(
            len(v) for v in self.last_duplicates_map.values()
        )

        if self.verbose:
            print(f"\nFinal count: {self.last_stats['kept']}")
            print(f"Duplicates removed: {self.last_stats['removed_duplicates']}")

        return df.iloc[kept_positions].copy()

    def get_duplicates(self) -> Dict[str, List[str]]:
        """Return a copy of the duplicates map from last deduplication."""
        return {kept: list(removed) for kept, removed in self.last_duplicates_map.items()}

    def get_stats(self) -> Dict[str, int]:
        """Return statistics from last deduplication."""
        return self.last_stats.copy()

    def get_family_id(self, potential_name: str) -> Optional[str]:
        """
        Get the family label (author_year[suffix]) for a potential.

        Returns normalized label like 'foiles_1986' or 'adams_1989Universal6',
        or None if the name cannot be parsed.
        """
        metadata = self.parse_potential_metadata(potential_name)
        if metadata is None:
            return None
        return self._family_label(metadata)

    def analyze_families(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Analyze potential families in the DataFrame.

        Only potentials that contain all target elements and whose name can
        be parsed are counted.

        Parameters
        ----------
        df : DataFrame
            Must have 'Name' and 'Species' columns

        Returns
        -------
        summary : DataFrame
            One row per family, sorted by family label, with columns
            'family', 'count', 'repos' (comma-separated repo types) and
            'has_duplicates'. The columns are present even if no family
            was found.
        """
        families = defaultdict(lambda: {"count": 0, "repos": set()})

        for name, species in zip(df["Name"], df["Species"]):
            if not self.contains_target_elements(self._species_to_set(species)):
                continue

            metadata = self.parse_potential_metadata(name)
            if metadata is None:
                continue

            info = families[self._family_label(metadata)]
            info["count"] += 1
            info["repos"].add(metadata["repo_type"])

        summary = [
            {
                "family": family_id,
                "count": info["count"],
                "repos": ", ".join(sorted(info["repos"])),
                "has_duplicates": info["count"] > 1,
            }
            for family_id, info in sorted(families.items())
        ]

        return pd.DataFrame(summary, columns=self._SUMMARY_COLUMNS)

    def filter_by_elements(
        self,
        df: pd.DataFrame,
        target_elements: Optional[Union[str, List[str], Set[str]]] = None,
    ) -> pd.DataFrame:
        """
        Filter DataFrame to only potentials containing specified elements.

        Parameters
        ----------
        df : DataFrame
            Input DataFrame with 'Species' column
        target_elements : str, list, set, or None
            Elements to filter for. If None, uses self.target_elements.
            The instance's own target elements are not modified.

        Returns
        -------
        filtered_df : DataFrame
            Filtered to potentials with all target elements, in input order
        """
        if target_elements is None:
            target_set = self.target_elements
        else:
            target_set = self._normalize_elements(target_elements)

        # Select by position so duplicate index labels cannot multiply rows.
        matching_positions = [
            position
            for position, species in enumerate(df["Species"])
            if target_set.issubset(self._species_to_set(species))
        ]

        return df.iloc[matching_positions].copy()