Skip to content

Commit 9caf8bc

Browse files
committed
Increment of lambda is now on GPU (it is faster, but I wanted it differently)
1 parent 10a624a commit 9caf8bc

3 files changed

Lines changed: 85 additions & 79 deletions

File tree

cpp/common/molecular/MolWorld_sp3_multi.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,12 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
393393
if(si_count < nCVs){
394394
Quat4f acon = Quat4f{initial_positions[si_count].x, initial_positions[si_count].y, initial_positions[si_count].z, JEforceconst};
395395
Quat4f aconK = Quat4f{final_positions[si_count].x, final_positions[si_count].y, final_positions[si_count].z, 0.0f};
396+
if(si_count == 0){ // for incrementing of lambda step index in the kernel
397+
aconK.w = 1.0f;
398+
}
396399
constr [isys*ocl.nAtoms + ia] = acon;
397400
constrK[isys*ocl.nAtoms + ia] = aconK;
401+
398402
}
399403
si_count++;
400404
}
@@ -406,7 +410,7 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
406410
double beta = 1.0 / (const_kB * go.T_target);
407411
double dLambda = 1.0 / (double)(nLambda - 1);
408412
nPerVFs = nPerVFs_;
409-
int nBatches = nMDsteps / (nLambda * nPerVFs * nSystems);
413+
int nBatches = nMDsteps / (nLambda * nSystems);
410414
if( nBatches < 1 ) nBatches = 1;
411415

412416
printf(" Running %d batches of %d pulling steps each...\n", nBatches, nLambda);
@@ -423,15 +427,16 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
423427
run_ocl_opt( nEQsteps, Fconv );
424428

425429
// 2. Pulling
430+
// nPerVFs = nLambda;
426431
nPerVFs = nPerVFs_;
427-
for(int isys=0; isys<nSystems; isys++){ jeParams[isys].x = 0; jeParams[isys].y = nLambda; jeParams[isys].z = nPerVFs; jeParams[isys].w = 0; }
432+
for(int isys=0; isys<nSystems; isys++){ jeParams[isys].x = 0; jeParams[isys].y = nLambda; }
428433
ocl.upload( ocl.ibuff_jeParams, jeParams );
429434

430435
for(int i=0; i<nSystems * nLambda; i++) gpu_work[i] = 0;
431436
ocl.upload( ocl.ibuff_work, gpu_work );
432437

433-
printf(" Pulling for %d * %d = %d steps...\n", nLambda, nPerVFs, nLambda*nPerVFs);
434-
run_ocl_opt( nLambda*nPerVFs, Fconv );
438+
printf(" Pulling for %d steps...\n", nLambda);
439+
run_ocl_opt( nLambda, Fconv );
435440

436441
// 3. Download work and accumulate
437442
ocl.download( ocl.ibuff_work, gpu_work );
@@ -441,6 +446,9 @@ void TI_step(double lambda, double dE, double sigma, double dLambda, int nMDstep
441446
double W_traj = 0;
442447
for(int i=0; i<nLambda; i++){
443448
float dW = gpu_work[ isys * nLambda + i ]*dLambda;
449+
// if(isys == 0 && i%1000 == 0){
450+
// printf(" System %d, lambda %f, dW %f, cumulative W %f\n", isys, (float)i/(float)(nLambda-1), dW, W_traj + dW);
451+
// }
444452
W_traj += (double)dW;
445453
sum_exp_W[i] += exp(-beta * W_traj);
446454
}
@@ -1173,10 +1181,6 @@ double evalVFs( double Fconv=1e-6 ){
11731181
// TDrive[isys].z = 0;
11741182
// }
11751183
// printf("evalVFs() TDrive[isys].z = %f\n", TDrive[isys].z);
1176-
if( jeParams && jeParams[isys].x >= 0 ){
1177-
jeParams[isys].x += 1;
1178-
jeParams[isys].w = 0;
1179-
}
11801184
// printf("evalVFs() TDrive[isys].z = %f\n", TDrive[isys].z);
11811185
TDrive[isys].w = randf(-1.0,1.0);
11821186
}else{
@@ -1189,7 +1193,6 @@ double evalVFs( double Fconv=1e-6 ){
11891193
//printf( "MDpars{%g,%g,%g,%g}\n", MDpars[0].x,MDpars[0].y,MDpars[0].z,MDpars[0].w );
11901194
err |= ocl.upload( ocl.ibuff_MDpars, MDpars );
11911195
err |= ocl.upload( ocl.ibuff_TDrive, TDrive );
1192-
if( jeParams)err |= ocl.upload( ocl.ibuff_jeParams, jeParams );
11931196
err |= ocl.upload( ocl.ibuff_cvf , cvfs );
11941197
// //printf("MolWorld_sp3_multi::evalVFs() bGroupUpdate=%i \n", bGroupUpdate );
11951198
// if(bGroupUpdate){

cpp/common_resources/cl/relax_multi.cl

Lines changed: 68 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ __kernel void getMMFFf4(
342342
// --- Evaluate bond-length stretching energy and forces
343343
if(iG<ing){
344344
// Bond stretching with proper MMFF parameters from bL[i] and bK[i]
345-
E+= evalBond( h.xyz, l-bL[i], bK[i], &f1 ); fbs[i]-=f1; fa+=f1; // harmonic bond stretching, fa is force on center atom, fbs[i] is recoil force on i-th neighbor,
346-
//E+= evalBond( h.xyz, l-1.198f, 40.f, &f1 ); fbs[i]-=f1; fa+=f1; // harmonic bond stretching, fa is force on center atom, fbs[i] is recoil force on i-th neighbor,
345+
//E+= evalBond( h.xyz, l-bL[i], bK[i], &f1 ); fbs[i]-=f1; fa+=f1; // harmonic bond stretching, fa is force on center atom, fbs[i] is recoil force on i-th neighbor,
346+
E+= evalBond( h.xyz, l-1.198f, 40.f, &f1 ); fbs[i]-=f1; fa+=f1; // harmonic bond stretching, fa is force on center atom, fbs[i] is recoil force on i-th neighbor,
347347

348348
// pi-pi alignment interaction
349349
float kpp = Kppi[i];
@@ -385,26 +385,26 @@ __kernel void getMMFFf4(
385385
const int jnga = jng+i0a;
386386
const float4 hj = hs[j];
387387

388-
E += evalAngleCosHalf( hi, hj, par.xy, par.z, &f1, &f2 ); // evaluate angular force and energy using cos(angle/2) formulation
389-
fa -= f1+f2;
390-
391-
//if(bSubtractVdW)
392-
{ // Remove non-bonded interactions from atoms that are bonded to common neighbor
393-
float4 REQi=REQKs[inga]; // non-bonding parameters of i-th neighbor
394-
float4 REQj=REQKs[jnga]; // non-bonding parameters of j-th neighbor
395-
// combine non-bonding parameters of i-th and j-th neighbors using mixing rules
396-
float4 REQij;
397-
REQij.x = REQi.x + REQj.x;
398-
REQij.yz = REQi.yz * REQj.yz;
388+
// E += evalAngleCosHalf( hi, hj, par.xy, par.z, &f1, &f2 ); // evaluate angular force and energy using cos(angle/2) formulation
389+
// fa -= f1+f2;
390+
391+
// //if(bSubtractVdW)
392+
// { // Remove non-bonded interactions from atoms that are bonded to common neighbor
393+
// float4 REQi=REQKs[inga]; // non-bonding parameters of i-th neighbor
394+
// float4 REQj=REQKs[jnga]; // non-bonding parameters of j-th neighbor
395+
// // combine non-bonding parameters of i-th and j-th neighbors using mixing rules
396+
// float4 REQij;
397+
// REQij.x = REQi.x + REQj.x;
398+
// REQij.yz = REQi.yz * REQj.yz;
399399

400-
float3 dp = (hj.xyz/hj.w) - (hi.xyz/hi.w); // recover vector between i-th and j-th neighbors using stored vectos and inverse bond lengths, this should be faster than dp=apos[jngv].xyz-apos[ingv].xyz; from global memory
401-
float4 fij = getLJQH( dp, REQij, 1.0f ); // compute non-bonded interaction between i-th and j-th neighbors using Lennard-Jones and Coulomb interactions and Hydrogen bond correction
402-
f1 -= fij.xyz;
403-
f2 += fij.xyz;
404-
}
405-
406-
fbs[i]+= f1;
407-
fbs[j]+= f2;
400+
// float3 dp = (hj.xyz/hj.w) - (hi.xyz/hi.w); // recover vector between i-th and j-th neighbors using stored vectos and inverse bond lengths, this should be faster than dp=apos[jngv].xyz-apos[ingv].xyz; from global memory
401+
// float4 fij = getLJQH( dp, REQij, 1.0f ); // compute non-bonded interaction between i-th and j-th neighbors using Lennard-Jones and Coulomb interactions and Hydrogen bond correction
402+
// f1 -= fij.xyz;
403+
// f2 += fij.xyz;
404+
// }
405+
406+
// fbs[i]+= f1;
407+
// fbs[j]+= f2;
408408
}
409409
}
410410

@@ -1004,54 +1004,57 @@ __kernel void updateAtomsMMFFf4(
10041004
// ------- constrains
10051005
float4 cons = constr[ iaa ]; // constraints (x,y,z,K)
10061006

1007-
if( (cons.w > 0.f) && jeParams && (jeParams[iS].x > -1) ){
1008-
// Jarzynski equality
1007+
if( (cons.w > 0.f) && jeParams ){
1008+
// Jarzynski equality or Setup Equilibration
10091009
// We use standard "constr" for initial position and "constrK" for final position
10101010
// But we use "cons.w" as stiffness
1011-
// jeParams are (iLambda, nLambda, nPerVF, iStep) for Jarzynski Equality
10121011

10131012
float4 consEnd = constrK[ iaa ];
1014-
int nLambda = jeParams[iS].y;
1015-
float k = cons.w;
1016-
1017-
float lambda = (float)jeParams[iS].x/(float)(nLambda-1);
1018-
1019-
float3 p0 = cons.xyz;
1020-
float3 p1 = consEnd.xyz;
1021-
1022-
float3 target = p0 + (p1 - p0) * lambda;
1023-
1024-
// Compute Force (Harmonic)
1025-
// Force on atom = k * (target - pe)
1026-
float3 fc = (target - pe.xyz) * (float3){k,k,k};
1027-
fe.xyz += fc;
1028-
1029-
// Accumulate Work
1030-
// Work done ON system = integral of (dH/dLambda) dLambda
1031-
// H_spring = 0.5 * k * (x - x0(lambda))^2
1032-
// dH/dLambda = k * (x - x0) * (-dx0/dLambda)
1033-
// = k * (x - x0) * -(p1 - p0)
1034-
// = k * (x0 - x) * (p1 - p0) = fc * (p1 - p0)
1035-
// So we accumulate dot(fc, dir).
1036-
1037-
float3 dir = p1 - p0;
1038-
float work_term = dot(fc, dir);
1039-
if( (jeParams[iS].w >= jeParams[iS].z - 1) && (jeParams[iS].x >= 0) ){
1040-
// Record work at this step if buffer provided
1041-
{
1042-
volatile __global float* addr = &work[ nLambda * iS + jeParams[iS].x ];
1043-
float old_val, new_val;
1044-
do {
1045-
old_val = *addr;
1046-
new_val = old_val + work_term;
1047-
} while (atomic_cmpxchg((volatile __global int*)addr, as_int(old_val), as_int(new_val)) != as_int(old_val));
1048-
1049-
1050-
}
1051-
}
1052-
else if(iG==0){
1053-
jeParams[iS].w += 1;
1054-
}
1013+
int step = jeParams[iS].x;
1014+
if (step > -1) {
1015+
int nLambda = jeParams[iS].y;
1016+
float k = cons.w;
1017+
1018+
// RESTRICTION: Ensure step does not overrun nLambda and cause OOB memory access
1019+
if (step < nLambda) {
1020+
// Calculate lambda continuously
1021+
float lambda = min(1.0f, (float)step / (float)((nLambda - 1)));
1022+
1023+
float3 p0 = cons.xyz;
1024+
float3 p1 = consEnd.xyz;
1025+
1026+
float3 target = p0 + (p1 - p0) * lambda;
1027+
1028+
// Compute Force (Harmonic)
1029+
float3 fc = (target - pe.xyz) * (float3){k,k,k};
1030+
fe.xyz += fc;
1031+
1032+
// Accumulate Work
1033+
float3 dir = p1 - p0;
1034+
float work_term = dot(fc, dir);
1035+
1036+
// Record work at this safely bounded step
1037+
{
1038+
volatile __global float* addr = &work[ nLambda * iS + step ];
1039+
float old_val, new_val;
1040+
do {
1041+
old_val = *addr;
1042+
new_val = old_val + work_term;
1043+
} while (atomic_cmpxchg((volatile __global int*)addr, as_int(old_val), as_int(new_val)) != as_int(old_val));
1044+
}
1045+
} // End of Restriction
1046+
1047+
if(consEnd.w>0.0f && step < nLambda){
1048+
jeParams[iS].x += 1; // Increment step index for next step exactly once per system
1049+
}
1050+
} else if (step == -1) {
1051+
// Initial Equilibrium Step
1052+
float k = cons.w;
1053+
1054+
float3 target = cons.xyz;
1055+
float3 fc = (target - pe.xyz) * (float3){k,k,k};
1056+
fe.xyz += fc;
1057+
}
10551058
cons.w = 0.0f; // Disable standard logic
10561059
}
10571060

examples/tFreeEnergy_multi/run_ES.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ echo "Step 2: Running Free Energy Calculation (Mode: $MODE)..."
4545
echo "----------------------------------------"
4646
python3 run_ES.py \
4747
--mode $MODE \
48-
--nSys 10 \
48+
--nSys 20 \
4949
--xyz_name "../tMMFF/data/entropic_spring_$N.xyz" \
5050
--system_name "entropic_spring_$N" \
5151
--nLambda 100000 \
52-
--nMDsteps 10000000 \
52+
--nMDsteps 2000000 \
5353
--nEQsteps 50000 \
5454
--Fconv 1e-6 \
5555
--constraints "constraints_ES.txt" \
@@ -77,8 +77,8 @@ echo " Completed successfully!"
7777
echo "=========================================="
7878
echo ""
7979
echo "Output files:"
80-
echo " - entropic_spring_${N}_TI.dat (raw data)"
81-
echo " - entropic_spring_${N}_TI_interactive.html (interactive plot)"
80+
echo " - entropic_spring_${N}_free_energy.dat (raw data)"
81+
echo " - entropic_spring_${N}_free_energy_interactive.html (interactive plot)"
8282
echo ""
83-
echo "To view the interactive plot, open entropic_spring_${N}_TI_interactive.html in a web browser"
83+
echo "To view the interactive plot, open entropic_spring_${N}_free_energy_interactive.html in a web browser"
8484
echo ""

0 commit comments

Comments
 (0)