Skip to content

Commit 58bc55b

Browse files
committed
Working JE for entropic spring, code must be cleaned
1 parent 7ab4b4c commit 58bc55b

File tree

9 files changed

+140
-78
lines changed

9 files changed

+140
-78
lines changed

cpp/common/OpenCL/OCL_MM.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,14 @@ class OCL_MM: public OCLsystem { public:
7878
int ibuff_atoms=-1,ibuff_aforces=-1,ibuff_neighs=-1,ibuff_neighCell=-1;
7979
int ibuff_avel=-1,ibuff_cvf=-1, ibuff_neighForce=-1, ibuff_bkNeighs=-1, ibuff_bkNeighs_new=-1;
8080
int ibuff_REQs=-1, ibuff_MMpars=-1, ibuff_BLs=-1,ibuff_BKs=-1,ibuff_Ksp=-1, ibuff_Kpp=-1; // MMFFf4 params
81-
int ibuff_lvecs=-1, ibuff_ilvecs=-1,ibuff_MDpars=-1,ibuff_TDrive=-1, ibuff_pbcshifts=-1;
81+
int ibuff_lvecs=-1, ibuff_ilvecs=-1,ibuff_MDpars=-1,ibuff_TDrive=-1, ibuff_pbcshifts=-1, ibuff_jeParams=-1;
8282
int ibuff_constr=-1;
8383
int ibuff_constrK=-1;
8484
int ibuff_bboxes=-1;
8585
int ibuff_sysneighs=-1;
8686
int ibuff_sysbonds=-1;
8787
int ibuff_averageForces=-1;
8888
int ibuff_work=-1;
89-
int4 jeParams{0,0,0,0};
9089

9190
int ibuff_samp_ps=-1;
9291
int ibuff_samp_fs=-1;
@@ -233,6 +232,8 @@ class OCL_MM: public OCLsystem { public:
233232

234233
// Buffer for thermodynamic integration - stores accumulated force differences
235234
ibuff_averageForces = newBuffer( "averageForces", nSystems, sizeof(float4), 0, CL_MEM_READ_WRITE );
235+
ibuff_work = newBuffer( "work", nSystems, sizeof(float), 0, CL_MEM_READ_WRITE );
236+
ibuff_jeParams = newBuffer( "jeParams", nSystems, sizeof(int4), 0, CL_MEM_READ_WRITE );
236237
ibuff_BKs = newBuffer( "BKs", nSystems*nnode, sizeof(float4), 0, CL_MEM_READ_ONLY );
237238
ibuff_Ksp = newBuffer( "Ksp", nSystems*nnode, sizeof(float4), 0, CL_MEM_READ_ONLY );
238239
ibuff_Kpp = newBuffer( "Kpp", nSystems*nnode, sizeof(float4), 0, CL_MEM_READ_ONLY );
@@ -606,7 +607,7 @@ class OCL_MM: public OCLsystem { public:
606607
err |= useArgBuff( ibuff_sysbonds ); // 14
607608
err |= useArgBuff( ibuff_averageForces);// 15
608609
err |= useArgBuff( ibuff_work ); // 16
609-
err |= _useArg( jeParams ); // 17
610+
err |= useArgBuff( ibuff_jeParams ); // 17
610611
OCL_checkError(err, "setup_updateAtomsMMFFf4");
611612
return task;
612613
// const int4 n, // 1 // (natoms,nnode) dimensions of the system

cpp/common/molecular/MMFFBuilder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4186,6 +4186,7 @@ void toMMFFsp3_loc( MMFFsp3_loc& ff, bool bRealloc=true, bool bEPairs=true, bool
41864186
const Atom& A = atoms[ia];
41874187
ff.apos [ia] = A.pos;
41884188
ff.atypes[ia] = A.type;
4189+
ff.REQs [ia] = A.REQ;
41894190
AtomType& atyp = params->atypes[A.type];
41904191

41914192
if(A.iconf>=0){

cpp/common/molecular/MolWorld_sp3_multi.h

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class MolWorld_sp3_multi : public MolWorld_sp3, public MultiSolverInterface { pu
122122
Quat4f* MDpars =0; // Molecular dynamics params
123123
Quat4f* TDrive =0; // temperature and drived dynamics
124124
Quat4f* averageForces =0; // accumulated force differences for thermodynamic integration
125+
Quat4i* jeParams = 0; // parameters for Jarzynski Equality [ nSystems ]
125126

126127
Quat4f* constr =0;
127128
Quat4f* constrK =0;
@@ -389,8 +390,8 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
389390
for(int ia=0; ia<ffls[isys].natoms; ia++){
390391
if(ffls[isys].atypes[ia]==params.getAtomType("Si")){
391392
if(si_count < nCVs){
392-
Quat4f acon = Quat4f{initial_positions[si_count].x, initial_positions[si_count].y, initial_positions[si_count].z, 1e2f};
393-
Quat4f aconK = Quat4f{final_positions[si_count].x, final_positions[si_count].y, final_positions[si_count].z, (float)nLambda};
393+
Quat4f acon = Quat4f{initial_positions[si_count].x, initial_positions[si_count].y, initial_positions[si_count].z, 5.0f};
394+
Quat4f aconK = Quat4f{final_positions[si_count].x, final_positions[si_count].y, final_positions[si_count].z, 0.0f};
394395
constr [isys*ocl.nAtoms + ia] = acon;
395396
constrK[isys*ocl.nAtoms + ia] = aconK;
396397
}
@@ -402,6 +403,7 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
402403
upload( ocl.ibuff_constrK, constrK );
403404

404405
double beta = 1.0 / (const_kB * go.T_target);
406+
double dLambda = 1.0 / (double)(nLambda - 1);
405407
nPerVFs = nPerVFs_;
406408
int nBatches = nMDsteps / (nLambda * nPerVFs);
407409
if( nBatches < 1 ) nBatches = 1;
@@ -416,20 +418,15 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
416418
printf(" Equilibrating %d steps...\n", nEQsteps);
417419
bSaveTrajectory = false;
418420
nPerVFs = nEQsteps;
419-
for(int isys=0; isys<nSystems; isys++){
420-
TDrive[isys].z = -1;
421-
}
422-
ocl.upload( ocl.ibuff_TDrive, TDrive );
421+
for(int isys=0; isys<nSystems; isys++){ jeParams[isys].x = -1; }
422+
ocl.upload( ocl.ibuff_jeParams, jeParams );
423423
run_ocl_opt( nEQsteps, Fconv );
424424

425425
// 2. Pulling
426426
bSaveTrajectory = true;
427-
nPerVFs = nPerVFs_; // High resolution pulling
428-
std::vector<bool> bExploring_old(nSystems);
429-
for(int isys=0; isys<nSystems; isys++){
430-
TDrive[isys].z = 0;
431-
}
432-
ocl.upload( ocl.ibuff_TDrive, TDrive );
427+
nPerVFs = nPerVFs_;
428+
for(int isys=0; isys<nSystems; isys++){ jeParams[isys].x = 0; jeParams[isys].y = nLambda; jeParams[isys].z = nPerVFs; jeParams[isys].w = 0; }
429+
ocl.upload( ocl.ibuff_jeParams, jeParams );
433430
// Clear work buffer
434431
float* zero_work = new float[nSystems * nLambda];
435432
for(int i=0; i<nSystems * nLambda; i++) zero_work[i] = 0;
@@ -438,16 +435,24 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
438435

439436
printf(" Pulling for %d * %d = %d steps...\n", nLambda, nPerVFs, nLambda*nPerVFs);
440437
run_ocl_opt( nLambda*nPerVFs, Fconv );
441-
for(int isys=0; isys<nSystems; isys++) gopts[isys].bExploring = bExploring_old[isys];
442438

443439
// 3. Download work and accumulate
444440
ocl.download( ocl.ibuff_work, gpu_work );
445441
ocl.finishRaw();
442+
// DEBUG: Print first few work values for system 0 in first batch
443+
if(batch==0){
444+
printf("DEBUG gpu_work[sys=0]: ");
445+
for(int i=0; i<5; i++) printf("[%d]=%g ", i, gpu_work[i]);
446+
printf(" dLambda=%g\n", dLambda);
447+
printf("DEBUG gpu_work[sys=0] dW: ");
448+
for(int i=0; i<5; i++) printf("[%d]=%g ", i, gpu_work[i]*dLambda);
449+
printf("\n");
450+
}
446451

447452
for(int isys=0; isys<nSystems; isys++){
448453
double W_traj = 0;
449454
for(int i=0; i<nLambda; i++){
450-
float dW = gpu_work[ isys * nLambda + i ];
455+
float dW = gpu_work[ isys * nLambda + i ]*dLambda;
451456
W_traj += (double)dW;
452457
//if(isys==0 && i==0) printf("batch %d, isys %d, i %d, W %f\n", batch, isys, i, W);
453458
sum_exp_W[i] += exp(-beta * W_traj);
@@ -526,6 +531,7 @@ void realloc( int nSystems_ ){
526531

527532
// Initialize averageForces buffer for thermodynamic integration
528533
_realloc0( averageForces, nSystems, Quat4fZero );
534+
_realloc0( jeParams, nSystems, Quat4iMinusOnes );
529535

530536
_realloc( pbcshifts, ocl.npbc*nSystems );
531537

@@ -1180,7 +1186,8 @@ double evalVFs( double Fconv=1e-6 ){
11801186
// TDrive[isys].z = 0;
11811187
// }
11821188
// printf("evalVFs() TDrive[isys].z = %f\n", TDrive[isys].z);
1183-
TDrive[isys].z += 1;
1189+
jeParams[isys].x += 1;
1190+
jeParams[isys].w = 0;
11841191
// printf("evalVFs() TDrive[isys].z = %f\n", TDrive[isys].z);
11851192
TDrive[isys].w = randf(-1.0,1.0);
11861193
}else{
@@ -1193,6 +1200,7 @@ double evalVFs( double Fconv=1e-6 ){
11931200
//printf( "MDpars{%g,%g,%g,%g}\n", MDpars[0].x,MDpars[0].y,MDpars[0].z,MDpars[0].w );
11941201
err |= ocl.upload( ocl.ibuff_MDpars, MDpars );
11951202
err |= ocl.upload( ocl.ibuff_TDrive, TDrive );
1203+
if(jeParams)err |= ocl.upload( ocl.ibuff_jeParams, jeParams );
11961204
err |= ocl.upload( ocl.ibuff_cvf , cvfs );
11971205
// //printf("MolWorld_sp3_multi::evalVFs() bGroupUpdate=%i \n", bGroupUpdate );
11981206
// if(bGroupUpdate){

cpp/common_resources/cl/relax_multi.cl

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ __kernel void updateAtomsMMFFf4(
915915
__global float4* sysbonds, // 14 // // contains parameters of bonds (constrains) with neighbor systems {Lmin,Lmax,Kpres,Ktens}
916916
__global float4* averageForces, // 15 // contains average forces on atoms for Thermodynamic Integration
917917
__global float* work, // 16 // contains work recorded at each step for Jarzynski Equality
918-
const int4 jeParams
918+
__global int4* jeParams // 17 // parameters for Jarzynski Equality per system
919919
){
920920
const int natoms=n.x; // number of atoms
921921
const int nnode =n.y; // number of node atoms
@@ -1004,45 +1004,59 @@ __kernel void updateAtomsMMFFf4(
10041004
float4 cons = constr[ iaa ]; // constraints (x,y,z,K)
10051005

10061006

1007-
// if(iS==0 && iG==0)printf("GPU: iS=%i iG=%i cons.w=%g TDrive.z=%g \n", iS, iG, cons.w, TDrive.z );
1008-
if( work && (cons.w > 0.f) && (TDrive.z >= 0.f) ){
1009-
float4 consEnd = constrK[ iaa ];
1010-
int nLambda = (int)consEnd.w;
1011-
float lambda = TDrive.z/(float)(nLambda-1);
1012-
float k = cons.w; // Stiffness stored in .w
1013-
1014-
// Interpolate position
1015-
float3 p0 = cons.xyz;
1016-
float3 p1 = consEnd.xyz;
1017-
float3 target = p0 + (p1 - p0) * lambda;
1018-
1019-
// Compute Force (Harmonic)
1020-
// Force on atom = k * (target - pe)
1021-
float3 fc = (target - pe.xyz) * (float3){k,k,k};
1022-
fe.xyz += fc;
1023-
1024-
// Accumulate Work
1025-
// Work done ON system = integral of (dH/dLambda) dLambda
1026-
// H_spring = 0.5 * k * (x - x0(lambda))^2
1027-
// dH/dLambda = k * (x - x0) * (-dx0/dLambda)
1028-
// = k * (x - x0) * -(p1 - p0)
1029-
// = k * (x0 - x) * (p1 - p0) = fc * (p1 - p0)
1030-
// So we accumulate dot(fc, dir).
1031-
1032-
float3 dir = p1 - p0;
1033-
float work_term = dot(fc, dir);
1034-
work_term *= 1.0f/(float)(nLambda-1);
1035-
1036-
// Record work at this step if buffer provided
1037-
{
1038-
volatile __global float* addr = &work[ nLambda * iS + (int)(TDrive.z) ];
1039-
float old_val, new_val;
1040-
do {
1041-
old_val = *addr;
1042-
new_val = old_val + work_term;
1043-
} while (atomic_cmpxchg((volatile __global int*)addr, as_int(old_val), as_int(new_val)) != as_int(old_val));
1007+
// if(iS==0 && iG==0)printf("GPU: iS=%i iG=%i jeParams(%i,%i,%i,%i) \n", iS, iG, jeParams[iS].x, jeParams[iS].y, jeParams[iS].z, jeParams[iS].w );
1008+
if( (cons.w > 0.f) && (jeParams[iS].x >= -1) ){
1009+
// Jarzynski equality / Thermodynamic integration
1010+
// We use standard "constr" for initial position and "constrK" for final position
1011+
// But we use "cons.w" as stiffness
1012+
1013+
float4 consEnd = constrK[ iaa ];
1014+
int nLambda = jeParams[iS].y;
1015+
float k = cons.w;
1016+
1017+
float lambda;
1018+
if(jeParams[iS].x < 0){
1019+
lambda = 0.0f;
1020+
}else{
1021+
lambda = (float)jeParams[iS].x/(float)(nLambda-1);
1022+
}
1023+
1024+
float3 p0 = cons.xyz;
1025+
float3 p1 = consEnd.xyz;
1026+
1027+
float3 target = p0 + (p1 - p0) * lambda;
1028+
1029+
// Compute Force (Harmonic)
1030+
// Force on atom = k * (target - pe)
1031+
float3 fc = (target - pe.xyz) * (float3){k,k,k};
1032+
fe.xyz += fc;
1033+
1034+
// Accumulate Work
1035+
// Work done ON system = integral of (dH/dLambda) dLambda
1036+
// H_spring = 0.5 * k * (x - x0(lambda))^2
1037+
// dH/dLambda = k * (x - x0) * (-dx0/dLambda)
1038+
// = k * (x - x0) * -(p1 - p0)
1039+
// = k * (x0 - x) * (p1 - p0) = fc * (p1 - p0)
1040+
// So we accumulate dot(fc, dir).
1041+
1042+
float3 dir = p1 - p0;
1043+
float work_term = dot(fc, dir);
1044+
if( (jeParams[iS].w >= jeParams[iS].z - 1) && (jeParams[iS].x >= 0) ){
1045+
// Record work at this step if buffer provided
1046+
{
1047+
volatile __global float* addr = &work[ nLambda * iS + jeParams[iS].x ];
1048+
float old_val, new_val;
1049+
do {
1050+
old_val = *addr;
1051+
new_val = old_val + work_term;
1052+
} while (atomic_cmpxchg((volatile __global int*)addr, as_int(old_val), as_int(new_val)) != as_int(old_val));
1053+
1054+
}
1055+
}
1056+
else if(iG==0){
1057+
jeParams[iS].w += 1;
10441058
}
1045-
cons.w = 0.0f; // Disable standard logic
1059+
cons.w = 0.0f; // Disable standard logic
10461060
}
10471061

10481062
if( cons.w>0.f && (cons.w<1e3f) ){ // if stiffness is positive, we have constraint

cpp/common_resources/xyz/DA.xyz

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
46
2+
*****
3+
C -4.65120 3.04710 -3.84360
4+
C -3.28410 3.28530 -3.99840
5+
N -2.40600 2.76310 -3.10520
6+
C -2.84070 2.01120 -2.06220
7+
H -1.38350 2.94170 -3.22170 +0.5
8+
C -4.20780 1.77020 -1.90350
9+
N -5.10890 2.29170 -2.79860
10+
H -2.13080 1.59950 -1.35570
11+
H -4.51570 1.16740 -1.06160
12+
H -5.34170 3.47030 -4.56440
13+
H -2.92740 3.88370 -4.82730
14+
C -6.57120 2.08830 -2.70790
15+
C -7.03990 1.23320 -1.51740
16+
H -6.91820 1.59760 -3.64470
17+
H -7.06190 3.08440 -2.63310
18+
C -8.56180 1.08550 -1.51490
19+
H -6.72840 1.71970 -0.56730
20+
H -6.58440 0.22110 -1.58710
21+
Si -9.13340 -0.00000 0.00000
22+
H -8.90030 0.59210 -2.45070
23+
H -9.04310 2.08360 -1.43570
24+
H -8.48770 -1.40530 -0.10220
25+
H -10.67720 -0.13500 -0.01720
26+
H -8.68590 0.68140 1.31820
27+
C 2.61800 2.00310 -0.26920
28+
C 4.54350 3.38240 -0.13910
29+
C 1.78880 3.12720 -0.28380
30+
C 3.74330 4.52820 -0.15140
31+
C 2.35740 4.39800 -0.22430
32+
H 4.20280 5.50830 -0.10450
33+
H 0.71370 3.00530 -0.34100
34+
N 3.98010 2.12990 -0.19780
35+
H 5.61230 3.52730 -0.08190
36+
H 2.15940 1.02180 -0.31600
37+
C 4.76620 0.87660 -0.18910
38+
C 6.29220 1.05990 -0.11040
39+
H 4.53420 0.30840 -1.11750
40+
H 4.44580 0.26500 0.68390
41+
H 6.64260 1.64330 -0.98950
42+
C 7.00740 -0.28990 -0.10880
43+
H 6.55470 1.59910 0.82570
44+
H 6.77500 -0.84990 -1.03970
45+
H 6.68820 -0.89380 0.76710
46+
Si 9.13340 -0.00000 -0.00000
47+
O 1.6481 5.3886 -0.2359
48+
E 0.6481 5.3886 -0.2359 -0.5
Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,2 @@
1-
# Constraint positions for DA (dimer) thermodynamic integration
2-
#
3-
# Format: atom_index initial_x initial_y initial_z final_x final_y final_z
4-
#
5-
# CV: Distance between the two Si atoms (indices 18 and 43 in 0-indexed)
6-
# Si atoms from DA.mol2:
7-
# Atom 19 (index 18): Si at (-12.0498, -6.8896, 1.7148)
8-
# Atom 44 (index 43): Si at (3.1486, -10.9881, 10.9822)
9-
#
10-
# Initial distance ~18Å, pulling from 1Å to 20Å on x-axis
11-
12-
18 0.5 0.0 0.0 10.0 0.0 0.0
13-
43 -0.5 0.0 0.0 -10.0 0.0 0.0
1+
18 5.5 0.0 0.0 30.0 0.0 0.0
2+
43 -5.5 0.0 0.0 -30.0 0.0 0.0

examples/tFreeEnergy_multi/run_DA.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ echo "Step 2: Running Thermodynamic Integration for DA..."
2323
echo "----------------------------------------"
2424
python3 run_ES.py \
2525
--nSys 100 \
26-
--xyz_name "../../cpp/common_resources/DA.mol2" \
26+
--xyz_name "../../cpp/common_resources/xyz/DA.xyz" \
2727
--system_name "DA" \
2828
--nLambda 100 \
2929
--nMDsteps 2000000 \

examples/tFreeEnergy_multi/run_ES.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ while [[ "$#" -gt 0 ]]; do
1818
shift
1919
done
2020

21-
N=20
21+
N=30
2222

2323
# Ensure we are in the script directory
2424
cd "$(dirname "$0")"
@@ -46,12 +46,12 @@ python3 run_ES.py \
4646
--nSys 100 \
4747
--xyz_name "../tMMFF/data/entropic_spring_$N.xyz" \
4848
--system_name "entropic_spring_$N" \
49-
--nLambda 50000 \
50-
--nMDsteps 1000000 \
51-
--nEQsteps 5000 \
49+
--nLambda 100000 \
50+
--nMDsteps 100000 \
51+
--nEQsteps 50000 \
5252
--Fconv 1e-6 \
5353
--constraints "constraints_ES.txt" \
54-
--nPerVFs 10
54+
--nPerVFs 1
5555

5656
if [ $? -ne 0 ]; then
5757
echo "ERROR: Calculation failed!"

examples/tMolGUIapp_multi/run.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ touch minima.dat
211211
#./$name -m 2000 -x common_resources/xyz/xylitol_WO_gridFF -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0 -verb 0 -perframe 100 -grid_nPBC 2,2,0 # -nogridff
212212

213213

214-
./$name -m 50 -x common_resources/DA.mol2 -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0
215-
# ./$name -m 50 -x common_resources/entropic_spring_30.xyz -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0
214+
# ./$name -m 50 -x common_resources/DA.mol2 -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0
215+
# ./$name -m 50 -x common_resources/xyz/DA.xyz -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0
216+
./$name -m 50 -x common_resources/entropic_spring_30.xyz -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0
216217
# ./$name -m 2 -x common_resources/xyz/nHexadecan.xyz -iParalel 3 -T 300 0.2 -gopt 1000,100000 0.25,1.0

0 commit comments

Comments
 (0)