Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions source/source_io/module_parameter/read_input_item_ofdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ Note: Even dimensions may cause slight errors in FFT. It should be ignorable in
item.default_value = "False";
item.unit = "";
item.availability = "Used only for KSDFT with plane wave basis";
item.check_value = [](const Input_Item& item, const Parameter& para) {
if (para.input.of_ml_gene_data
&& (para.input.esolver_type != "ksdft" || para.input.basis_type != "pw" || GlobalV::NPROC != 1))
{
ModuleBase::WARNING_QUIT(
"ReadInput",
"of_ml_gene_data is only available for KSDFT with PW basis on a single MPI rank (NPROC = 1)");
}
};
Comment thread
sunliang98 marked this conversation as resolved.
read_sync_bool(input.of_ml_gene_data);
this->add_item(item);
}
Expand Down
30 changes: 30 additions & 0 deletions source/source_io/test_serial/read_input_item_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,36 @@ TEST_F(InputTest, Item_test2)
it->second.reset_value(it->second, param);
EXPECT_EQ(param.input.of_read_kernel, false);
}
{ // of_ml_gene_data
auto it = find_label("of_ml_gene_data", readinput.input_lists);

param.input.of_ml_gene_data = true;
param.input.esolver_type = "ofdft";
param.input.basis_type = "pw";
GlobalV::NPROC = 1;
testing::internal::CaptureStdout();
EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("NOTICE"));

param.input.of_ml_gene_data = true;
param.input.esolver_type = "ksdft";
param.input.basis_type = "lcao";
GlobalV::NPROC = 1;
testing::internal::CaptureStdout();
EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("NOTICE"));

param.input.of_ml_gene_data = true;
param.input.esolver_type = "ksdft";
param.input.basis_type = "pw";
GlobalV::NPROC = 2;
testing::internal::CaptureStdout();
EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), "");
output = testing::internal::GetCapturedStdout();
EXPECT_THAT(output, testing::HasSubstr("NOTICE"));
}
{ // dft_plus_u
auto it = find_label("dft_plus_u", readinput.input_lists);
param.input.dft_plus_u = 1;
Expand Down
40 changes: 40 additions & 0 deletions tests/01_PW/103_PW_Gene_Descriptors/INPUT
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
INPUT_PARAMETERS
#Parameters (1.General)
suffix autotest
calculation scf
nbands 3
symmetry 1
pseudo_dir ../../PP_ORB
pseudo_rcut 16

#Parameters (2.Iteration)
ecutwfc 10
scf_nmax 100
dft_functional XC_GGA_X_PBE+XC_GGA_C_PBE


#Parameters (3.Basis)
basis_type pw

#Parameters (4.Smearing)
smearing_method gauss
smearing_sigma 0.0002

#Parameters (5.Mixing)
mixing_type broyden
device cpu

#Parameters (6.Gene Data)
of_full_pw 1
of_full_pw_dim 1

of_ml_gene_data 1

of_ml_nkernel 2
of_ml_kernel 1 1
of_ml_kernel_scaling 1.5 1
of_ml_chi_p 0.2
of_ml_chi_q 0.1
of_ml_chi_pnl 0.2 0.2
of_ml_chi_qnl 0.1 0.1
of_ml_chi_xi 0.8 1
4 changes: 4 additions & 0 deletions tests/01_PW/103_PW_Gene_Descriptors/KPT
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
K_POINTS
0
Gamma
20 20 20 0 0 0
1 change: 1 addition & 0 deletions tests/01_PW/103_PW_Gene_Descriptors/README
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Tests the correctness of ML functional descriptors generated by PW KSDFT; this case supports single-rank only.
18 changes: 18 additions & 0 deletions tests/01_PW/103_PW_Gene_Descriptors/STRU
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
ATOMIC_SPECIES
Al 13 al.gga.psp blps

LATTICE_CONSTANT
7.6513590200098225 // add lattice constant

LATTICE_VECTORS
0.000000000000 0.500000000000 0.500000000000
0.500000000000 0.000000000000 0.500000000000
0.500000000000 0.500000000000 0.000000000000

ATOMIC_POSITIONS
Direct

Al
0.0
1
0.000000000000 0.000000000000 0.000000000000 1 1 1
38 changes: 38 additions & 0 deletions tests/01_PW/103_PW_Gene_Descriptors/result.ref
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
etotref -57.04829815855825
etotperatomref -57.0482981586
ml_desc_mean_enhancement_npy 1.111419127928
ml_desc_mean_gamma_npy 0.995847307666
ml_desc_mean_gammanl_1_1_5_npy 0.077138936619
ml_desc_mean_gammanl_1_1_npy 0.083331919287
ml_desc_mean_nablaRhox_npy 0.003714583460
ml_desc_mean_nablaRhoy_npy 0.003714583460
ml_desc_mean_nablaRhoz_npy 0.003714583460
ml_desc_mean_p_npy 0.095710030124
ml_desc_mean_pauli_npy 0.875124481343
ml_desc_mean_pnl_1_1_5_npy 0.230824971750
ml_desc_mean_pnl_1_1_npy 0.224100011030
ml_desc_mean_q_npy 0.310335060112
ml_desc_mean_qnl_1_1_5_npy 0.598582089727
ml_desc_mean_qnl_1_1_npy 0.660815190079
ml_desc_mean_rho_npy 0.026789555290
ml_desc_mean_tanh_pnl_1_1_5_npy 0.044695787317
ml_desc_mean_tanh_pnl_1_1_npy 0.043383873620
ml_desc_mean_tanh_qnl_1_1_5_npy 0.053898994944
ml_desc_mean_tanh_qnl_1_1_npy 0.059951993401
ml_desc_mean_tanhp_npy 0.018740571716
ml_desc_mean_tanhp_nl_1_1_5_npy 0.044938608030
ml_desc_mean_tanhp_nl_1_1_npy 0.043643966443
ml_desc_mean_tanhq_npy 0.028985866312
ml_desc_mean_tanhq_nl_1_1_5_npy 0.054446092813
ml_desc_mean_tanhq_nl_1_1_npy 0.060077197039
ml_desc_mean_tanhxi_1_1_5_npy 0.065230231297
ml_desc_mean_tanhxi_1_1_npy 0.086762840891
ml_desc_mean_tanhxi_nl_1_1_5_npy 0.113754807936
ml_desc_mean_tanhxi_nl_1_1_npy 0.163063987051
ml_desc_mean_veff_npy 0.431187756019
ml_desc_mean_xi_1_1_5_npy 0.083625451851
ml_desc_mean_xi_1_1_npy 0.090202449809
pointgroupref O_h
spacegroupref O_h
nksibzref 256
totaltimeref 1.46
1 change: 1 addition & 0 deletions tests/01_PW/CASES_CPU.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ scf_out_chg_tau
100_PW_W90
101_PW_MD_1O
102_PW_MD_2O
103_PW_Gene_Descriptors
201_PW_UPF201_Ce_f
202_PW_ONCV_Libxc
204_PW_SY
Expand Down
18 changes: 15 additions & 3 deletions tests/integrate/Autotest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS
threshold=0.0000001
force_threshold=0.0001
stress_threshold=0.001
# descriptor mean threshold
descriptor_threshold=0.00001
# check accuracy
ca=8
# specify the test cases file
Expand Down Expand Up @@ -60,6 +62,7 @@ echo "Number of threads: $nt"
echo "Test accuracy totenergy: $threshold eV"
echo "Test accuracy force: $force_threshold"
echo "Test accuracy stress: $stress_threshold"
echo "Test accuracy descriptor mean: $descriptor_threshold"
echo "Check accuaracy: $ca"
echo "Test cases file: $cases_file"
echo "Test cases regex: $case"
Expand Down Expand Up @@ -89,6 +92,7 @@ check_out(){
force_thr=$3
stress_thr=$4
fatal_thr=$5
descriptor_thr=$6

#------------------------------------------------------
# outfile = result.out
Expand Down Expand Up @@ -144,7 +148,11 @@ check_out(){
fatal_case_list+=$dir'\n'
break
else
if [ $(check_deviation_pass $deviation $thr) = 0 ]; then
compare_thr=$thr
if [[ $key == ml_desc_mean_* ]]; then
compare_thr=$descriptor_thr
fi
if [ $(check_deviation_pass $deviation $compare_thr) = 0 ]; then
if [ $key == "totalforceref" ]; then
if [ $(check_deviation_pass $deviation $force_thr) = 0 ]; then
echo -e "[WARNING ] "\
Expand Down Expand Up @@ -208,7 +216,7 @@ get_threshold()
default_value=$3
if [ -e $threshold_f ]; then
threshold_value=$(awk -v tn="$threshold_name" '$1==tn {print $2}' "$threshold_f")
if [ -n "$threshold_value" ]; then
if [ -n "$threshold_value" ]; then
echo $threshold_value
else
echo $default_value
Expand Down Expand Up @@ -263,6 +271,9 @@ for dir in $testdir; do
$abacus > log.txt
elif [ "$case" = "282_NO_RPA" ]; then
mpirun -np 1 $abacus > log.txt
elif grep -qE '^[[:space:]]*of_ml_gene_data[[:space:]]+1([[:space:]]|$)' INPUT; then
# of_ml_gene_data supports single-rank only.
mpirun -np 1 $abacus > log.txt
else
mpirun -np $np $abacus > log.txt
fi
Expand All @@ -289,7 +300,8 @@ for dir in $testdir; do
my_force_threshold=$(get_threshold $threshold_file "force_threshold" $force_threshold)
my_stress_threshold=$(get_threshold $threshold_file "stress_threshold" $stress_threshold)
my_fatal_threshold=$(get_threshold $threshold_file "fatal_threshold" $fatal_threshold)
check_out result.out $my_threshold $my_force_threshold $my_stress_threshold $my_fatal_threshold
my_descriptor_threshold=$(get_threshold $threshold_file "descriptor_threshold" $descriptor_threshold)
check_out result.out $my_threshold $my_force_threshold $my_stress_threshold $my_fatal_threshold $my_descriptor_threshold
fi
else
bash -e ../../integrate/tools/catch_properties.sh result.ref
Expand Down
9 changes: 9 additions & 0 deletions tests/integrate/tools/catch_properties.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
COMPARE_SCRIPT="../../integrate/tools/CompareFile.py"
#COMPARE_SCRIPT="../../integrate/tools/compare_file.py"
SUM_CUBE_EXE="python3 ../../integrate/tools/sum_cube.py"
COLLECT_NPY_MEANS="../../integrate/tools/collect_npy_means.py"


sum_file(){
Expand Down Expand Up @@ -646,6 +647,14 @@ if [ "$need_process_cube" = true ]; then
fi
fi

#--------------------------------------------
# ML gene data descriptors (.npy)
#--------------------------------------------
descriptor_dir="OUT.autotest/MLKEDF_Descriptors"
if [ -d "$descriptor_dir" ]; then
python3 $COLLECT_NPY_MEANS "$descriptor_dir" >> "$1"
fi

#--------------------------------------------
# implicit solvation model
#--------------------------------------------
Expand Down
43 changes: 43 additions & 0 deletions tests/integrate/tools/collect_npy_means.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
import argparse
import os
import re
import sys

import numpy as np


def sanitize_name(name):
return re.sub(r"[^A-Za-z0-9_]", "_", name)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("descriptor_dir")
args = parser.parse_args()

descriptor_dir = args.descriptor_dir
if not os.path.isdir(descriptor_dir):
print(f"Descriptor dir not found: {descriptor_dir}", file=sys.stderr)
return 1

files = sorted(f for f in os.listdir(descriptor_dir) if f.endswith(".npy"))
if not files:
print(f"No .npy files in: {descriptor_dir}", file=sys.stderr)
return 1

for filename in files:
path = os.path.join(descriptor_dir, filename)
data = np.load(path, allow_pickle=False)
if data.size == 0:
mean_value = 0.0
else:
mean_value = float(np.mean(np.abs(data)))
key = f"ml_desc_mean_{sanitize_name(filename)}"
print(f"{key} {mean_value:.12f}")

return 0


if __name__ == "__main__":
raise SystemExit(main())
Loading