diff --git a/source/source_io/module_parameter/read_input_item_ofdft.cpp b/source/source_io/module_parameter/read_input_item_ofdft.cpp index 31e3d2aeeb..46647a7bed 100644 --- a/source/source_io/module_parameter/read_input_item_ofdft.cpp +++ b/source/source_io/module_parameter/read_input_item_ofdft.cpp @@ -382,6 +382,15 @@ Note: Even dimensions may cause slight errors in FFT. It should be ignorable in item.default_value = "False"; item.unit = ""; item.availability = "Used only for KSDFT with plane wave basis"; + item.check_value = [](const Input_Item& item, const Parameter& para) { + if (para.input.of_ml_gene_data + && (para.input.esolver_type != "ksdft" || para.input.basis_type != "pw" || GlobalV::NPROC != 1)) + { + ModuleBase::WARNING_QUIT( + "ReadInput", + "of_ml_gene_data is only available for KSDFT with PW basis on a single MPI rank (NPROC = 1)"); + } + }; read_sync_bool(input.of_ml_gene_data); this->add_item(item); } diff --git a/source/source_io/test_serial/read_input_item_test.cpp b/source/source_io/test_serial/read_input_item_test.cpp index 41e8ba55c8..eeb5246666 100644 --- a/source/source_io/test_serial/read_input_item_test.cpp +++ b/source/source_io/test_serial/read_input_item_test.cpp @@ -1528,6 +1528,36 @@ TEST_F(InputTest, Item_test2) it->second.reset_value(it->second, param); EXPECT_EQ(param.input.of_read_kernel, false); } + { // of_ml_gene_data + auto it = find_label("of_ml_gene_data", readinput.input_lists); + + param.input.of_ml_gene_data = true; + param.input.esolver_type = "ofdft"; + param.input.basis_type = "pw"; + GlobalV::NPROC = 1; + testing::internal::CaptureStdout(); + EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("NOTICE")); + + param.input.of_ml_gene_data = true; + param.input.esolver_type = "ksdft"; + param.input.basis_type = "lcao"; + GlobalV::NPROC = 1; + testing::internal::CaptureStdout(); + EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("NOTICE")); + + param.input.of_ml_gene_data = true; + param.input.esolver_type = "ksdft"; + param.input.basis_type = "pw"; + GlobalV::NPROC = 2; + testing::internal::CaptureStdout(); + EXPECT_EXIT(it->second.check_value(it->second, param), ::testing::ExitedWithCode(1), ""); + output = testing::internal::GetCapturedStdout(); + EXPECT_THAT(output, testing::HasSubstr("NOTICE")); + } { // dft_plus_u auto it = find_label("dft_plus_u", readinput.input_lists); param.input.dft_plus_u = 1; diff --git a/tests/01_PW/103_PW_Gene_Descriptors/INPUT b/tests/01_PW/103_PW_Gene_Descriptors/INPUT new file mode 100644 index 0000000000..784dd3143f --- /dev/null +++ b/tests/01_PW/103_PW_Gene_Descriptors/INPUT @@ -0,0 +1,40 @@ +INPUT_PARAMETERS +#Parameters (1.General) +suffix autotest +calculation scf +nbands 3 +symmetry 1 +pseudo_dir ../../PP_ORB +pseudo_rcut 16 + +#Parameters (2.Iteration) +ecutwfc 10 +scf_nmax 100 +dft_functional XC_GGA_X_PBE+XC_GGA_C_PBE + + +#Parameters (3.Basis) +basis_type pw + +#Parameters (4.Smearing) +smearing_method gauss +smearing_sigma 0.0002 + +#Parameters (5.Mixing) +mixing_type broyden +device cpu + +#Parameters (6.Gene Data) +of_full_pw 1 +of_full_pw_dim 1 + +of_ml_gene_data 1 + +of_ml_nkernel 2 +of_ml_kernel 1 1 +of_ml_kernel_scaling 1.5 1 +of_ml_chi_p 0.2 +of_ml_chi_q 0.1 +of_ml_chi_pnl 0.2 0.2 +of_ml_chi_qnl 0.1 0.1 +of_ml_chi_xi 0.8 1 \ No newline at end of file diff --git a/tests/01_PW/103_PW_Gene_Descriptors/KPT b/tests/01_PW/103_PW_Gene_Descriptors/KPT new file mode 100644 index 0000000000..a5e2161899 --- /dev/null +++ b/tests/01_PW/103_PW_Gene_Descriptors/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +20 20 20 0 0 0 diff --git a/tests/01_PW/103_PW_Gene_Descriptors/README b/tests/01_PW/103_PW_Gene_Descriptors/README new file mode 100644 index 0000000000..afb8c92d10 --- /dev/null +++ b/tests/01_PW/103_PW_Gene_Descriptors/README @@ -0,0 +1 @@ +Tests the correctness of ML functional descriptors generated by PW KSDFT; this case supports single-rank only. diff --git a/tests/01_PW/103_PW_Gene_Descriptors/STRU b/tests/01_PW/103_PW_Gene_Descriptors/STRU new file mode 100644 index 0000000000..18e837d78e --- /dev/null +++ b/tests/01_PW/103_PW_Gene_Descriptors/STRU @@ -0,0 +1,18 @@ +ATOMIC_SPECIES +Al 13 al.gga.psp blps + +LATTICE_CONSTANT +7.6513590200098225 // add lattice constant + +LATTICE_VECTORS +0.000000000000 0.500000000000 0.500000000000 +0.500000000000 0.000000000000 0.500000000000 +0.500000000000 0.500000000000 0.000000000000 + +ATOMIC_POSITIONS +Direct + +Al +0.0 +1 + 0.000000000000 0.000000000000 0.000000000000 1 1 1 diff --git a/tests/01_PW/103_PW_Gene_Descriptors/result.ref b/tests/01_PW/103_PW_Gene_Descriptors/result.ref new file mode 100644 index 0000000000..1989ab7796 --- /dev/null +++ b/tests/01_PW/103_PW_Gene_Descriptors/result.ref @@ -0,0 +1,38 @@ +etotref -57.04829815855825 +etotperatomref -57.0482981586 +ml_desc_mean_enhancement_npy 1.111419127928 +ml_desc_mean_gamma_npy 0.995847307666 +ml_desc_mean_gammanl_1_1_5_npy 0.077138936619 +ml_desc_mean_gammanl_1_1_npy 0.083331919287 +ml_desc_mean_nablaRhox_npy 0.003714583460 +ml_desc_mean_nablaRhoy_npy 0.003714583460 +ml_desc_mean_nablaRhoz_npy 0.003714583460 +ml_desc_mean_p_npy 0.095710030124 +ml_desc_mean_pauli_npy 0.875124481343 +ml_desc_mean_pnl_1_1_5_npy 0.230824971750 +ml_desc_mean_pnl_1_1_npy 0.224100011030 +ml_desc_mean_q_npy 0.310335060112 +ml_desc_mean_qnl_1_1_5_npy 0.598582089727 +ml_desc_mean_qnl_1_1_npy 0.660815190079 +ml_desc_mean_rho_npy 0.026789555290 +ml_desc_mean_tanh_pnl_1_1_5_npy 0.044695787317 +ml_desc_mean_tanh_pnl_1_1_npy 0.043383873620 +ml_desc_mean_tanh_qnl_1_1_5_npy 0.053898994944 +ml_desc_mean_tanh_qnl_1_1_npy 0.059951993401 +ml_desc_mean_tanhp_npy 0.018740571716 +ml_desc_mean_tanhp_nl_1_1_5_npy 0.044938608030 +ml_desc_mean_tanhp_nl_1_1_npy 0.043643966443 +ml_desc_mean_tanhq_npy 0.028985866312 +ml_desc_mean_tanhq_nl_1_1_5_npy 0.054446092813 +ml_desc_mean_tanhq_nl_1_1_npy 0.060077197039 +ml_desc_mean_tanhxi_1_1_5_npy 0.065230231297 +ml_desc_mean_tanhxi_1_1_npy 0.086762840891 +ml_desc_mean_tanhxi_nl_1_1_5_npy 0.113754807936 +ml_desc_mean_tanhxi_nl_1_1_npy 0.163063987051 +ml_desc_mean_veff_npy 0.431187756019 +ml_desc_mean_xi_1_1_5_npy 0.083625451851 +ml_desc_mean_xi_1_1_npy 0.090202449809 +pointgroupref O_h +spacegroupref O_h +nksibzref 256 +totaltimeref 1.46 diff --git a/tests/01_PW/CASES_CPU.txt b/tests/01_PW/CASES_CPU.txt index 5c0ef88066..0ffdf7c077 100644 --- a/tests/01_PW/CASES_CPU.txt +++ b/tests/01_PW/CASES_CPU.txt @@ -103,6 +103,7 @@ scf_out_chg_tau 100_PW_W90 101_PW_MD_1O 102_PW_MD_2O +103_PW_Gene_Descriptors 201_PW_UPF201_Ce_f 202_PW_ONCV_Libxc 204_PW_SY diff --git a/tests/integrate/Autotest.sh b/tests/integrate/Autotest.sh index 9e5250b511..6dd4b5b382 100755 --- a/tests/integrate/Autotest.sh +++ b/tests/integrate/Autotest.sh @@ -9,6 +9,8 @@ nt=$OMP_NUM_THREADS # number of OpenMP threads, default is $OMP_NUM_THREADS threshold=0.0000001 force_threshold=0.0001 stress_threshold=0.001 +# descriptor mean threshold +descriptor_threshold=0.00001 # check accuracy ca=8 # specify the test cases file @@ -60,6 +62,7 @@ echo "Number of threads: $nt" echo "Test accuracy totenergy: $threshold eV" echo "Test accuracy force: $force_threshold" echo "Test accuracy stress: $stress_threshold" +echo "Test accuracy descriptor mean: $descriptor_threshold" echo "Check accuaracy: $ca" echo "Test cases file: $cases_file" echo "Test cases regex: $case" @@ -89,6 +92,7 @@ check_out(){ force_thr=$3 stress_thr=$4 fatal_thr=$5 + descriptor_thr=$6 #------------------------------------------------------ # outfile = result.out @@ -144,7 +148,11 @@ check_out(){ fatal_case_list+=$dir'\n' break else - if [ $(check_deviation_pass $deviation $thr) = 0 ]; then + compare_thr=$thr + if [[ $key == ml_desc_mean_* ]]; then + compare_thr=$descriptor_thr + fi + if [ $(check_deviation_pass $deviation $compare_thr) = 0 ]; then if [ $key == "totalforceref" ]; then if [ $(check_deviation_pass $deviation $force_thr) = 0 ]; then echo -e "[WARNING ] "\ @@ -208,7 +216,7 @@ get_threshold() default_value=$3 if [ -e $threshold_f ]; then threshold_value=$(awk -v tn="$threshold_name" '$1==tn {print $2}' "$threshold_f") - if [ -n "$threshold_value" ]; then + if [ -n "$threshold_value" ]; then echo $threshold_value else echo $default_value @@ -263,6 +271,9 @@ for dir in $testdir; do $abacus > log.txt elif [ "$case" = "282_NO_RPA" ]; then mpirun -np 1 $abacus > log.txt + elif grep -qE '^[[:space:]]*of_ml_gene_data[[:space:]]+1([[:space:]]|$)' INPUT; then + # of_ml_gene_data supports single-rank only. + mpirun -np 1 $abacus > log.txt else mpirun -np $np $abacus > log.txt fi @@ -289,7 +300,8 @@ for dir in $testdir; do my_force_threshold=$(get_threshold $threshold_file "force_threshold" $force_threshold) my_stress_threshold=$(get_threshold $threshold_file "stress_threshold" $stress_threshold) my_fatal_threshold=$(get_threshold $threshold_file "fatal_threshold" $fatal_threshold) - check_out result.out $my_threshold $my_force_threshold $my_stress_threshold $my_fatal_threshold + my_descriptor_threshold=$(get_threshold $threshold_file "descriptor_threshold" $descriptor_threshold) + check_out result.out $my_threshold $my_force_threshold $my_stress_threshold $my_fatal_threshold $my_descriptor_threshold fi else bash -e ../../integrate/tools/catch_properties.sh result.ref diff --git a/tests/integrate/tools/catch_properties.sh b/tests/integrate/tools/catch_properties.sh index c6070a8fc4..21d3bb70cc 100755 --- a/tests/integrate/tools/catch_properties.sh +++ b/tests/integrate/tools/catch_properties.sh @@ -5,6 +5,7 @@ COMPARE_SCRIPT="../../integrate/tools/CompareFile.py" #COMPARE_SCRIPT="../../integrate/tools/compare_file.py" SUM_CUBE_EXE="python3 ../../integrate/tools/sum_cube.py" +COLLECT_NPY_MEANS="../../integrate/tools/collect_npy_means.py" sum_file(){ @@ -646,6 +647,14 @@ if [ "$need_process_cube" = true ]; then fi fi +#-------------------------------------------- +# ML gene data descriptors (.npy) +#-------------------------------------------- +descriptor_dir="OUT.autotest/MLKEDF_Descriptors" +if [ -d "$descriptor_dir" ]; then + python3 $COLLECT_NPY_MEANS "$descriptor_dir" >> "$1" +fi + #-------------------------------------------- # implicit solvation model #-------------------------------------------- diff --git a/tests/integrate/tools/collect_npy_means.py b/tests/integrate/tools/collect_npy_means.py new file mode 100644 index 0000000000..4d138be611 --- /dev/null +++ b/tests/integrate/tools/collect_npy_means.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +import argparse +import os +import re +import sys + +import numpy as np + + +def sanitize_name(name): + return re.sub(r"[^A-Za-z0-9_]", "_", name) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("descriptor_dir") + args = parser.parse_args() + + descriptor_dir = args.descriptor_dir + if not os.path.isdir(descriptor_dir): + print(f"Descriptor dir not found: {descriptor_dir}", file=sys.stderr) + return 1 + + files = sorted(f for f in os.listdir(descriptor_dir) if f.endswith(".npy")) + if not files: + print(f"No .npy files in: {descriptor_dir}", file=sys.stderr) + return 1 + + for filename in files: + path = os.path.join(descriptor_dir, filename) + data = np.load(path, allow_pickle=False) + if data.size == 0: + mean_value = 0.0 + else: + mean_value = float(np.mean(np.abs(data))) + key = f"ml_desc_mean_{sanitize_name(filename)}" + print(f"{key} {mean_value:.12f}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())