@@ -60,30 +60,33 @@ function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospeciali
6060 # get probability
6161 prob_code = optimize_code (EinCode ([ixs... , iy], el), size_dict, GreedyMethod (; nrepeat= 1 ))
6262 el_prev = eliminated_variables (samples)
63+ @show el_prev=> subset (samples, el_prev)[:,1 ]
6364 xs = [eliminate_dimensions (x, ix, el_prev=> subset (samples, el_prev)[:,1 ]) for (ix, x) in zip (ixs, xs)]
6465 probabilities = einsum (prob_code, (xs... , env), size_dict)
66+ @show el
67+ @show normalize (real .(vec (probabilities)), 1 )
6568
6669 # sample from the probability tensor
6770 totalset = CartesianIndices ((map (x-> size_dict[x], el)... ,))
6871 eliminated_locs = idx4labels (samples. labels, el)
6972 config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
73+ @show eliminated_locs, config. I .- 1
7074 samples. samples[eliminated_locs, 1 ] .= config. I .- 1
7175
7276 # eliminate the sampled variables
7377 set_eliminated! (samples, el)
74- for l in el
75- size_dict[l] = 1
76- end
78+ setindex! .(Ref (size_dict), 1 , el)
7779 sub = subset (samples, el)[:, 1 ]
80+ @show ixs, el=> sub
7881 xs = [eliminate_dimensions (x, ix, el=> sub) for (ix, x) in zip (ixs, xs)]
79- env = eliminate_dimensions (env, iy, el=> sub)
8082
8183 # update environment
82- return map (1 : length (ixs)) do i
84+ envs = map (1 : length (ixs)) do i
8385 rest = setdiff (1 : length (ixs), i)
8486 code = optimize_code (EinCode ([ixs[rest]. .. , iy], ixs[i]), size_dict, GreedyMethod (; nrepeat= 1 ))
8587 einsum (code, (xs[rest]. .. , env), size_dict)
8688 end
89+ @show envs
8790end
8891
8992function eliminate_dimensions (x:: AbstractArray{T, N} , ix:: AbstractVector{L} , el:: Pair{<:AbstractVector{L}} ) where {T, N, L}
@@ -95,6 +98,7 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
9598 1 : size (x, i)
9699 end
97100 end
101+ @show idx
98102 return asarray (x[idx... ], x)
99103end
100104
@@ -185,11 +189,22 @@ function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractA
185189 return generate_samples! (se. eins, cache, env, samples, size_dict)
186190end
187191function generate_samples! (code:: NestedEinsum , cache:: CacheTree{T} , env:: AbstractArray{T} , samples:: Samples , size_dict:: Dict ) where {T}
188- if ! OMEinsum. isleaf (code)
192+ @info " @"
193+ if ! (OMEinsum. isleaf (code))
194+ @info " non-leaf node"
195+ @show env
189196 xs = ntuple (i -> cache. children[i]. content, length (cache. children))
190197 envs = backward_sampling! (code. eins, xs, env, samples, size_dict)
191- for (arg, sib, env) in zip (code. args, cache. children, envs)
192- generate_samples! (arg, sib, env, samples, size_dict)
198+ @show envs
199+ fucks = map (1 : length (code. args)) do k
200+ @info k
201+ generate_samples! (code. args[k], cache. children[k], envs[k], samples, size_dict)
202+ return " fuck"
193203 end
204+ @info fucks
205+ return
206+ else
207+ @info " leaf node"
208+ return
194209 end
195210end
0 commit comments