Skip to content

Commit 35901a4

Browse files
committed
Fix up initial velocities, deployment script and add rattling option for optimisation
1 parent 74a565f commit 35901a4

3 files changed

Lines changed: 112 additions & 17 deletions

File tree

.github/workflows/website.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ jobs:
6464
working-directory: website
6565
run: npm run build
6666

67-
- name: Convert PET-MAD model
68-
run: |
69-
uv run scripts/convert_pet_mad.py --output website/dist/pet-mad.gguf
70-
7167
- name: Setup Pages
7268
uses: actions/configure-pages@v4
7369

website/src/components/MolecularDynamics.tsx

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,7 @@ export default function MolecularDynamics() {
297297
const [mode, setMode] = useState<'md' | 'optimize'>('md')
298298
const [maxOptSteps, setMaxOptSteps] = useState(100)
299299
const [forceThreshold, setForceThreshold] = useState(0.05)
300+
const [rattleAmount, setRattleAmount] = useState(0.1) // Angstroms
300301

301302
// Initialize NGL Stage
302303
useEffect(() => {
@@ -408,6 +409,11 @@ export default function MolecularDynamics() {
408409
setState(s => ({ ...s, isRunning: false }))
409410
break
410411

412+
case 'rattled':
413+
// Update visualization with rattled positions
414+
updateVisualization(msg.positions)
415+
break
416+
411417
case 'error':
412418
setState(s => ({ ...s, error: msg.message, isRunning: false }))
413419
break
@@ -657,13 +663,21 @@ export default function MolecularDynamics() {
657663
setEnergyHistory([])
658664
setState(s => ({ ...s, step: 0, optimizationConverged: false }))
659665
lastStepTimeRef.current = 0
660-
workerRef.current?.postMessage({ type: 'start', stepsPerFrame: 1, mode })
666+
workerRef.current?.postMessage({
667+
type: 'start',
668+
stepsPerFrame: 1,
669+
mode,
670+
})
661671
}
662672

663673
const stopMD = () => {
664674
workerRef.current?.postMessage({ type: 'stop' })
665675
}
666676

677+
const rattleStructure = () => {
678+
workerRef.current?.postMessage({ type: 'rattle', amount: rattleAmount })
679+
}
680+
667681
return (
668682
<div className="md-simulation">
669683
{/* Left panel - Structure and parameters */}
@@ -789,8 +803,32 @@ export default function MolecularDynamics() {
789803
/>
790804
</div>
791805
</div>
806+
<div className="params-row">
807+
<div className="control-group">
808+
<label>Rattle (Å)</label>
809+
<input
810+
type="number"
811+
value={rattleAmount}
812+
onChange={e => setRattleAmount(Number(e.target.value))}
813+
min={0}
814+
max={1.0}
815+
step={0.05}
816+
className="number-input"
817+
/>
818+
</div>
819+
<div className="control-group">
820+
<label>&nbsp;</label>
821+
<button
822+
onClick={rattleStructure}
823+
className="control-button"
824+
disabled={!state.isModelLoaded || atomicNumbersRef.current.length === 0 || state.isRunning}
825+
>
826+
Rattle
827+
</button>
828+
</div>
829+
</div>
792830
<p className="nc-note">
793-
FIRE geometry optimization using gradient forces.
831+
FIRE optimization. Rattle perturbs atom positions.
794832
</p>
795833
</>
796834
)}
@@ -867,15 +905,16 @@ export default function MolecularDynamics() {
867905
stroke="rgba(59, 130, 246, 0.9)"
868906
strokeWidth="2.5"
869907
/>
870-
{/* Show dot at current position */}
871-
{dots.length > 0 && (
908+
{/* Show small dots at each data point */}
909+
{dots.map((dot, i) => (
872910
<circle
873-
cx={dots[dots.length - 1].x}
874-
cy={dots[dots.length - 1].y}
875-
r="8"
911+
key={i}
912+
cx={dot.x}
913+
cy={dot.y}
914+
r={i === dots.length - 1 ? 6 : 3}
876915
fill="rgba(59, 130, 246, 1)"
877916
/>
878-
)}
917+
))}
879918
<text x="10" y="18" className="energy-label">
880919
{mode === 'md' ? 'Total E' : 'E'} (eV)
881920
</text>

website/src/workers/mdWorker.ts

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,15 +389,42 @@ const FIRE_F_DEC = 0.5
389389
const FIRE_N_MIN = 5
390390
const FIRE_DT_MAX = 1.0 // fs
391391

392-
// Reset FIRE optimizer state
392+
// Reset FIRE optimizer state and initialize velocities along force direction
393393
function resetFIRE(): void {
394394
state.fireAlpha = FIRE_ALPHA_START
395395
state.fireNpos = 0
396396
state.fireDt = 0.1 // Start with small timestep
397397
state.optStep = 0
398-
// Zero velocities
399-
if (state.velocities) {
400-
state.velocities.fill(0)
398+
399+
// Initialize velocities along force direction for faster startup
400+
if (state.module && state.model && state.positions && state.velocities && state.masses) {
401+
// Get initial forces
402+
const system = state.module.AtomicSystem.create(
403+
state.positions,
404+
state.atomicNumbers!,
405+
state.cell,
406+
state.isPeriodic
407+
)
408+
const result = state.model.predictWithOptions(system, false)
409+
const forces = new Float64Array(result.forces)
410+
411+
// Calculate force magnitude
412+
let fNorm = 0
413+
for (let i = 0; i < forces.length; i++) {
414+
fNorm += forces[i] * forces[i]
415+
}
416+
fNorm = Math.sqrt(fNorm)
417+
418+
// Set initial velocity along force direction with small magnitude
419+
// v = dt * F / |F| gives unit velocity in force direction scaled by timestep
420+
if (fNorm > 1e-10) {
421+
const vScale = state.fireDt * 0.1 // Small initial velocity
422+
for (let i = 0; i < state.velocities.length; i++) {
423+
state.velocities[i] = vScale * forces[i] / fNorm
424+
}
425+
} else {
426+
state.velocities.fill(0)
427+
}
401428
}
402429
}
403430

@@ -607,7 +634,17 @@ function runMDStep(): void {
607634
}
608635
}
609636

610-
function handleStart(data: { stepsPerFrame?: number, mode?: 'md' | 'optimize' }): void {
637+
// Apply random perturbations to positions
638+
function rattlePositions(amount: number): void {
639+
if (!state.positions || amount <= 0) return
640+
641+
for (let i = 0; i < state.positions.length; i++) {
642+
// Uniform random in [-amount, +amount]
643+
state.positions[i] += (Math.random() * 2 - 1) * amount
644+
}
645+
}
646+
647+
function handleStart(data: { stepsPerFrame?: number, mode?: 'md' | 'optimize', rattleAmount?: number }): void {
611648
if (state.isRunning) return
612649

613650
// Update mode if provided
@@ -621,6 +658,11 @@ function handleStart(data: { stepsPerFrame?: number, mode?: 'md' | 'optimize' })
621658
// Reset FIRE state for new optimization
622659
resetFIRE()
623660

661+
// Apply rattle if requested
662+
if (data.rattleAmount && data.rattleAmount > 0) {
663+
rattlePositions(data.rattleAmount)
664+
}
665+
624666
// Run optimization steps at ~30 fps
625667
mdInterval = setInterval(() => {
626668
const done = runFIREStep()
@@ -656,6 +698,21 @@ function handleStep(): void {
656698
runMDStep()
657699
}
658700

701+
function handleRattle(data: { amount: number }): void {
702+
if (!state.positions) {
703+
self.postMessage({ type: 'error', message: 'No system loaded' })
704+
return
705+
}
706+
707+
rattlePositions(data.amount)
708+
709+
// Send back the new positions so visualization can update
710+
self.postMessage({
711+
type: 'rattled',
712+
positions: Array.from(state.positions),
713+
})
714+
}
715+
659716
// Message router
660717
self.onmessage = async (e: MessageEvent) => {
661718
const { type, ...data } = e.data
@@ -685,6 +742,9 @@ self.onmessage = async (e: MessageEvent) => {
685742
case 'step':
686743
handleStep()
687744
break
745+
case 'rattle':
746+
handleRattle(data)
747+
break
688748
default:
689749
self.postMessage({ type: 'error', message: `Unknown message type: ${type}` })
690750
}

0 commit comments

Comments
 (0)