Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 117 additions & 6 deletions src/abacusagent/modules/scf.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,131 @@
from pathlib import Path
from typing import Dict, Any
from typing import Dict, Any, Optional, Literal

from abacusagent.init_mcp import mcp
from abacusagent.modules.submodules.scf import abacus_calculation_scf as _abacus_calculation_scf


@mcp.tool()
def abacus_calculation_scf(
abacus_inputs_dir: Path,
# Convergence parameters
ecutwfc: Optional[float] = None,
scf_thr: Optional[float] = None,
scf_nmax: Optional[int] = None,
# Smearing parameters
smearing_method: Optional[Literal["gaussian", "fd", "fixed", "mp", "mv", "cold"]] = None,
smearing_sigma: Optional[float] = None,
# Mixing parameters
mixing_type: Optional[Literal["plain", "kerker", "pulay", "pulay-kerker", "broyden"]] = None,
mixing_beta: Optional[float] = None,
mixing_ndim: Optional[int] = None,
mixing_gg0: Optional[float] = None,
# K-point parameters
kspacing: Optional[float] = None,
gamma_only: Optional[bool] = None,
# Other parameters
symmetry: Optional[bool] = None,
out_chg: Optional[int] = None,
out_mul: Optional[bool] = None,
chg_extrap: Optional[Literal["none", "atomic", "first-order", "second-order"]] = None,
ks_solver: Optional[Literal["cg", "dav", "bpcg", "genelpa", "scalapack_gvx"]] = None,
# Audit control
save_audit_trail: bool = True,
print_audit_summary: bool = False,
) -> Dict[str, Any]:
"""
Run ABACUS SCF calculation.
Run ABACUS SCF calculation with explicit parameter control.

This function supports two modes:
1. Legacy mode: Only abacus_inputs_dir provided → uses INPUT file as-is
2. New mode: Additional parameters provided → applies parameter management with full audit trail

All parameters are optional - missing values will be filled with defaults or inferred.
Full audit trail tracks parameter provenance (user input → defaults → inference → final value).

Args:
abacusjob (str): Path to the directory containing the ABACUS input files.
abacus_inputs_dir: Path to directory containing ABACUS input files (INPUT, STRU, KPT, etc.)

Convergence parameters:
ecutwfc: Energy cutoff for wavefunctions (Ry). Range: >0. Typical: 50-150 Ry
scf_thr: SCF convergence threshold. Range: >0. Default: 1e-6
scf_nmax: Maximum SCF iterations. Range: >0. Default: 100

Smearing parameters:
smearing_method: Electronic occupation smearing method.
Options: gaussian (default), fd, fixed, mp, mv, cold
smearing_sigma: Smearing width (Ry). Range: >0. Default: 0.015 Ry (≈0.2 eV)

Mixing parameters:
mixing_type: Charge density mixing method.
Options: plain, kerker, pulay (default), pulay-kerker, broyden
mixing_beta: Mixing parameter. Range: (0,1]. Default: depends on mixing_type
mixing_ndim: Mixing history size (for pulay/broyden). Range: >0. Default: 8
mixing_gg0: Kerker screening parameter (for kerker-based). Range: ≥0. Default: 0.0

K-point parameters:
kspacing: K-point spacing for automatic mesh (2π/Bohr). Range: >0
gamma_only: Use only Gamma point. Default: False

Other parameters:
symmetry: Use crystal symmetry. Default: True
out_chg: Output charge density. Options: 0 (no), 1 (yes), -1 (auto). Default: 0
out_mul: Output Mulliken analysis. Default: False
chg_extrap: Charge extrapolation method. Options: none, atomic, first-order, second-order
ks_solver: Kohn-Sham solver. Options: cg, dav, bpcg, genelpa, scalapack_gvx

Audit control:
save_audit_trail: Save audit trail to JSON file. Default: True
print_audit_summary: Print audit summary to console. Default: False

Returns:
A dictionary containing the path to output file of ABACUS calculation, and a dictionary containing whether the SCF calculation
finished normally, the SCF is converged or not, the converged SCF energy and total time used.
Dictionary containing:
- scf_work_dir: Path to calculation directory
- normal_end: Whether calculation completed normally
- converge: Whether SCF converged
- energy: Final SCF energy (eV)
- total_time: Calculation time (s)
- audit_trail: Parameter provenance information (if save_audit_trail=True)

Examples:
# Legacy mode - use INPUT file as-is
>>> result = abacus_calculation_scf("/path/to/inputs")

# Custom convergence criteria
>>> result = abacus_calculation_scf(
... "/path/to/inputs",
... ecutwfc=120,
... scf_thr=1e-8,
... scf_nmax=200
... )

# Metal calculation with Kerker mixing
>>> result = abacus_calculation_scf(
... "/path/to/inputs",
... smearing_method="mp",
... smearing_sigma=0.02,
... mixing_type="pulay-kerker",
... mixing_gg0=1.5
... )
"""
return _abacus_calculation_scf(abacus_inputs_dir)
return _abacus_calculation_scf(
abacus_inputs_dir=abacus_inputs_dir,
ecutwfc=ecutwfc,
scf_thr=scf_thr,
scf_nmax=scf_nmax,
smearing_method=smearing_method,
smearing_sigma=smearing_sigma,
mixing_type=mixing_type,
mixing_beta=mixing_beta,
mixing_ndim=mixing_ndim,
mixing_gg0=mixing_gg0,
kspacing=kspacing,
gamma_only=gamma_only,
symmetry=symmetry,
out_chg=out_chg,
out_mul=out_mul,
chg_extrap=chg_extrap,
ks_solver=ks_solver,
save_audit_trail=save_audit_trail,
print_audit_summary=print_audit_summary,
)
11 changes: 11 additions & 0 deletions src/abacusagent/modules/submodules/band/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Band parameter management package."""
from .schema import BandParameters, SmearingMethod, MixingType
from .audit import BandAuditLogger
from .validator import BandParameterValidator, ValidationResult
from .defaults import BandDefaultsManager

__all__ = [
"BandParameters", "SmearingMethod", "MixingType",
"BandAuditLogger", "BandParameterValidator",
"ValidationResult", "BandDefaultsManager"
]
10 changes: 10 additions & 0 deletions src/abacusagent/modules/submodules/band/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Audit trail for band calculations."""
from typing import Optional
from ..common import BaseAuditLogger

class BandAuditLogger(BaseAuditLogger):
"""Audit logger for band calculations."""
def __init__(self, calculation_id: Optional[str] = None):
super().__init__(calculation_type="band", calculation_id=calculation_id)

__all__ = ["BandAuditLogger"]
41 changes: 41 additions & 0 deletions src/abacusagent/modules/submodules/band/defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Default values for band parameters."""
from typing import Dict, Any
from copy import deepcopy
from ..common import BaseDefaultsManager
from .schema import BandParameters
from .audit import BandAuditLogger

class BandDefaultsManager(BaseDefaultsManager):
"""Defaults manager for band parameters."""

def __init__(self, audit_logger: BandAuditLogger):
super().__init__(audit_logger)

def apply_defaults_and_inferences(self, params: BandParameters, context: Dict[str, Any]) -> BandParameters:
params = deepcopy(params)
params = self._apply_convergence_defaults(params)
params = self._apply_smearing_defaults(params, context)
params = self._apply_mixing_defaults(params, context)
params = self._apply_kpoint_defaults(params, context)
params = self._apply_output_defaults(params, context)
params = self._apply_band_defaults(params)
params = self._infer_mixing_beta(params)
params = self._infer_ks_solver(params, context)
return params

def _apply_band_defaults(self, params: BandParameters) -> BandParameters:
if params.mode is None:
params.mode = "auto"
self.audit.log_default("mode", "auto", "Auto-detect band calculation mode")
if params.energy_min is None:
params.energy_min = -10.0
self.audit.log_default("energy_min", -10.0, "Standard lower energy bound")
if params.energy_max is None:
params.energy_max = 10.0
self.audit.log_default("energy_max", 10.0, "Standard upper energy bound")
if params.insert_point_nums is None:
params.insert_point_nums = 30
self.audit.log_default("insert_point_nums", 30, "Standard k-point density")
return params

__all__ = ["BandDefaultsManager"]
53 changes: 53 additions & 0 deletions src/abacusagent/modules/submodules/band/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Parameter schemas and type definitions for band calculations.

This module defines the schema-first approach for band parameters:
- Inherits common parameters from the shared framework
- Adds band-specific parameters
- Explicit type hints using Literal types
"""

from typing import Literal, Optional, List, Dict, Union
from dataclasses import dataclass
from ..common import CommonPostSCFParameters


@dataclass
class BandParameters(CommonPostSCFParameters):
"""
Schema for band calculation parameters.

Inherits common parameters from CommonPostSCFParameters:
- Convergence, Smearing, Mixing, K-points, Output

Adds band-specific parameters:
- mode: Calculation mode (nscf, pyatb, auto)
- kpath: High symmetry k-point path
- high_symm_points: Coordinates of high symmetry points
- energy_min: Lower energy bound for plot
- energy_max: Upper energy bound for plot
- insert_point_nums: Points between high symmetry points
"""

mode: Optional[Literal["nscf", "pyatb", "auto"]] = None
"""Band calculation mode (default: auto)"""

kpath: Optional[Union[List[str], List[List[str]]]] = None
"""High symmetry k-point path"""

high_symm_points: Optional[Dict[str, List[float]]] = None
"""Coordinates of high symmetry points"""

energy_min: Optional[float] = None
"""Lower energy bound (eV, default: -10)"""

energy_max: Optional[float] = None
"""Upper energy bound (eV, default: 10)"""

insert_point_nums: Optional[int] = None
"""Points between high symmetry points (default: 30)"""


from ..common import SmearingMethod, MixingType

__all__ = ["BandParameters", "SmearingMethod", "MixingType"]
31 changes: 31 additions & 0 deletions src/abacusagent/modules/submodules/band/validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Validation logic for band parameters."""
from typing import Dict, Tuple, List, Any
from ..common import BaseParameterValidator, ValidationResult
from .schema import BandParameters

class BandParameterValidator(BaseParameterValidator):
"""Validator for band parameters."""

def validate_all(self, params: BandParameters, context: Dict[str, Any]) -> Tuple[bool, List[ValidationResult]]:
self.validation_results = []
self.warnings = []
self.errors = []

self._validate_convergence_params(params)
self._validate_smearing_params(params)
self._validate_mixing_params(params)
self._validate_kpoint_params(params)
self._validate_output_params(params, context)
self._validate_band_specific(params)

return len(self.errors) == 0, self.validation_results

def _validate_band_specific(self, params: BandParameters):
if params.insert_point_nums is not None and params.insert_point_nums <= 0:
self._add_error("insert_point_nums", f"insert_point_nums must be > 0, got {params.insert_point_nums}")

if params.energy_min is not None and params.energy_max is not None:
if params.energy_min >= params.energy_max:
self._add_error("energy_min", f"energy_min must be < energy_max")

__all__ = ["BandParameterValidator", "ValidationResult"]
79 changes: 79 additions & 0 deletions src/abacusagent/modules/submodules/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Common parameter management framework.

This package provides the foundational components for parameter management
across all calculation modules. It includes:

- Base schemas for parameters and audit trails
- Common parameter definitions (enums and groups)
- Base validator with common validation methods
- Base audit logger for provenance tracking
- Base defaults manager with inference rules

Module-specific implementations inherit from these base classes and extend
them with module-specific logic.
"""

# Base schemas
from .base_schema import (
BaseParameters,
ParameterProvenance,
ValidationResult,
AuditTrail,
)

# Shared parameters
from .shared_parameters import (
# Enums
SmearingMethod,
MixingType,
BasisType,
# Parameter groups
ConvergenceParameters,
SmearingParameters,
MixingParameters,
KPointParameters,
ForceStressParameters,
OutputParameters,
)

# Composable parameter groups
from .parameter_groups import (
CommonSCFParameters,
CommonRelaxationParameters,
CommonPostSCFParameters,
)

# Base classes
from .base_validator import BaseParameterValidator
from .base_audit import BaseAuditLogger
from .base_defaults import BaseDefaultsManager, INFERENCE_RULES

__all__ = [
# Base schemas
"BaseParameters",
"ParameterProvenance",
"ValidationResult",
"AuditTrail",
# Enums
"SmearingMethod",
"MixingType",
"BasisType",
# Parameter groups
"ConvergenceParameters",
"SmearingParameters",
"MixingParameters",
"KPointParameters",
"ForceStressParameters",
"OutputParameters",
# Composable groups
"CommonSCFParameters",
"CommonRelaxationParameters",
"CommonPostSCFParameters",
# Base classes
"BaseParameterValidator",
"BaseAuditLogger",
"BaseDefaultsManager",
# Inference rules
"INFERENCE_RULES",
]
Loading
Loading