Skip to content

Commit c160f95

Browse files
committed
allow to supply position and make bare_position fixed size array
1 parent ebeca70 commit c160f95

4 files changed

Lines changed: 16 additions & 3 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Carlo = "780c37f4-4e5a-43de-9e79-65c261e525a4"
99
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
1010
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
12+
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
1213
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
1314
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -20,6 +21,7 @@ Carlo = "0.3.1"
2021
ChunkSplitters = "3.1.2"
2122
DataFrames = "1.8.1"
2223
Distributions = "0.25.123"
24+
FixedSizeArrays = "1.3.0"
2325
HDF5 = "0.17.2"
2426
Printf = "1.11.0"
2527
Random = "1.11.0"

src/WaveMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ using ThreadsX
2121
using ChunkSplitters
2222
using CairoMakie
2323
using DataFrames
24+
using FixedSizeArrays
2425

2526
include("wavefunction/wavefunction.jl")
2627
include("acceptance_rate/acceptance_rate_adapter.jl")

src/WavefunctionMC.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,17 @@ function WavefunctionMC(params::AbstractDict)
5353
sigma_dist = get(params, :sigma_distribution, 0.3)
5454
distribution = get(params, :distribution, S <: Complex ? ComplexNormal(0, sigma_dist) : Normal(0, sigma_dist))
5555
dynamic_pos = get(params, :dynamic_positions, (N, 1:N))
56-
position = [i in dynamic_pos ? coordinate_proj(100 * rand(distribution)) : coordinate_proj(0 * rand(distribution)) for i in 1:N]
56+
57+
position = get(params, :position, missing)
58+
if ismissing(position)
59+
position = [i in dynamic_pos ? 100 * rand(distribution) : 0 * rand(distribution) for i in 1:N]
60+
end
61+
if missing in position || nothing in position
62+
inds = findall(x -> x === missing || x === nothing, position)
63+
position[inds] = [i in dynamic_pos ? 100 * rand(distribution) : 0 * rand(distribution) for i in inds]
64+
end
65+
position = FixedSizeArray(coordinate_proj.(position))
66+
@show typeof(position)
5767

5868
observables = get(params, :observables, NoObservables())
5969

src/state.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
mutable struct State{S, B <: Buffer, F, C <: Integer}
2-
bare_position::Vector{S}
1+
mutable struct State{S, B <: Buffer, F, C <: Integer, A<:AbstractVector{S}}
2+
bare_position::A
33
position::B
44
logdensity::F
55
num_accepts::C

0 commit comments

Comments
 (0)