Commit 6170325

committed: save
1 parent 8c3c7e5 commit 6170325

File tree: 3 files changed (+41 / -19 lines)


src/Core.jl

Lines changed: 2 additions & 1 deletion
@@ -204,7 +204,8 @@ Get the cardinalities of variables in this tensor network.
 """
 function get_cards(tn::TensorNetworkModel; fixedisone = false)::Vector
     vars = get_vars(tn)
-    [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : length(tn.tensors[k]) for k in eachindex(vars)]
+    size_dict = OMEinsum.get_size_dict(getixsv(tn.code), tn.tensors)
+    [fixedisone && haskey(tn.evidence, vars[k]) ? 1 : size_dict[vars[k]] for k in eachindex(vars)]
 end

 chevidence(tn::TensorNetworkModel, evidence) = TensorNetworkModel(tn.vars, tn.code, tn.tensors, evidence)
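
The dropped line derived a variable's cardinality from length(tn.tensors[k]), the total element count of the k-th tensor, which equals the cardinality only when that tensor is a vector indexed by variable k alone. The replacement builds a label-to-size map once and looks each label up. A minimal sketch of the idiom, with illustrative labels and shapes:

    using OMEinsum

    # Two tensors sharing label 2: a 2×3 matrix over labels (1, 2)
    # and a 3×4 matrix over labels (2, 3).
    ixs = [[1, 2], [2, 3]]
    tensors = [randn(2, 3), randn(3, 4)]

    # Map each label to the extent of the axes carrying it; tensors
    # sharing a label must agree on that extent.
    size_dict = OMEinsum.get_size_dict(ixs, tensors)
    @assert size_dict == Dict(1 => 2, 2 => 3, 3 => 4)

    # The per-tensor lookup would report length(tensors[2]) == 12 for
    # variable 2, an element count rather than the cardinality 3.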

src/sampling.jl

Lines changed: 23 additions & 8 deletions
@@ -60,30 +60,33 @@ function backward_sampling!(code::EinCode, @nospecialize(xs::Tuple), @nospeciali
     # get probability
     prob_code = optimize_code(EinCode([ixs..., iy], el), size_dict, GreedyMethod(; nrepeat=1))
     el_prev = eliminated_variables(samples)
+    @show el_prev=>subset(samples, el_prev)[:,1]
     xs = [eliminate_dimensions(x, ix, el_prev=>subset(samples, el_prev)[:,1]) for (ix, x) in zip(ixs, xs)]
     probabilities = einsum(prob_code, (xs..., env), size_dict)
+    @show el
+    @show normalize(real.(vec(probabilities)), 1)

     # sample from the probability tensor
     totalset = CartesianIndices((map(x->size_dict[x], el)...,))
     eliminated_locs = idx4labels(samples.labels, el)
     config = StatsBase.sample(totalset, _Weights(vec(probabilities)))
+    @show eliminated_locs, config.I .- 1
     samples.samples[eliminated_locs, 1] .= config.I .- 1

     # eliminate the sampled variables
     set_eliminated!(samples, el)
-    for l in el
-        size_dict[l] = 1
-    end
+    setindex!.(Ref(size_dict), 1, el)
     sub = subset(samples, el)[:, 1]
+    @show ixs, el=>sub
     xs = [eliminate_dimensions(x, ix, el=>sub) for (ix, x) in zip(ixs, xs)]
-    env = eliminate_dimensions(env, iy, el=>sub)

     # update environment
-    return map(1:length(ixs)) do i
+    envs = map(1:length(ixs)) do i
         rest = setdiff(1:length(ixs), i)
         code = optimize_code(EinCode([ixs[rest]..., iy], ixs[i]), size_dict, GreedyMethod(; nrepeat=1))
         einsum(code, (xs[rest]..., env), size_dict)
     end
+    @show envs
 end

 function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}}) where {T, N, L}
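
The loop-to-broadcast change in this hunk is behavior-preserving: wrapping the dictionary in Ref makes it a broadcast scalar, so setindex! fires once per label in el. A standalone illustration:

    size_dict = Dict(:a => 4, :b => 4, :c => 2)
    el = [:a, :b]

    # Equivalent to `for l in el; size_dict[l] = 1; end`; Ref shields
    # the Dict from being iterated, so broadcasting runs over `el` only.
    setindex!.(Ref(size_dict), 1, el)
    @assert size_dict == Dict(:a => 1, :b => 1, :c => 2)

Setting the eliminated labels' sizes to 1 keeps them legal in later einsum calls while recording that each has been pinned to a single sampled value.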
@@ -95,6 +98,7 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
             1:size(x, i)
         end
     end
+    @show idx
     return asarray(x[idx...], x)
 end
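
Only the tail of eliminate_dimensions is visible in this hunk. For context, a self-contained sketch of the indexing pattern the function appears to implement: each axis whose label occurs in el is pinned to its sampled value (offset by 1, since samples are stored 0-based), and every other axis is taken whole. The helper name and the exact body below are assumptions, not the committed code:

    # Slice x down to the configurations fixed by `labels => values`.
    function eliminate_dimensions_sketch(x::AbstractArray, ix::AbstractVector, el::Pair)
        labels, values = el
        idx = map(1:ndims(x)) do i
            k = findfirst(==(ix[i]), labels)
            # Pinned axes keep length 1 so downstream einsum codes stay valid.
            k === nothing ? (1:size(x, i)) : (values[k]+1:values[k]+1)
        end
        return x[idx...]
    end

    x = reshape(collect(1.0:8.0), 2, 2, 2)
    y = eliminate_dimensions_sketch(x, [:i, :j, :k], [:j] => [0])
    @assert size(y) == (2, 1, 2)   # axis :j pinned to sampled value 0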

@@ -185,11 +189,22 @@ function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, env::AbstractA
     return generate_samples!(se.eins, cache, env, samples, size_dict)
 end
 function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::AbstractArray{T}, samples::Samples, size_dict::Dict) where {T}
-    if !OMEinsum.isleaf(code)
+    @info "@"
+    if !(OMEinsum.isleaf(code))
+        @info "non-leaf node"
+        @show env
         xs = ntuple(i -> cache.children[i].content, length(cache.children))
         envs = backward_sampling!(code.eins, xs, env, samples, size_dict)
-        for (arg, sib, env) in zip(code.args, cache.children, envs)
-            generate_samples!(arg, sib, env, samples, size_dict)
+        @show envs
+        fucks = map(1:length(code.args)) do k
+            @info k
+            generate_samples!(code.args[k], cache.children[k], envs[k], samples, size_dict)
+            return "fuck"
         end
+        @info fucks
+        return
+    else
+        @info "leaf node"
+        return
     end
 end
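
Stripped of the @info/@show scaffolding and the placeholder return values, the control flow this rewrite aims at is: at each non-leaf node of the NestedEinsum contraction tree, sample the variables eliminated there via backward_sampling!, then recurse into every child with the environment tensor computed for that child. A cleaned-up sketch of the same structure (names ending in _sketch are illustrative):

    function generate_samples_sketch!(code, cache, env, samples, size_dict)
        OMEinsum.isleaf(code) && return   # leaves eliminate no variables
        xs = ntuple(i -> cache.children[i].content, length(cache.children))
        # One environment tensor per child, computed by the parent.
        envs = backward_sampling!(code.eins, xs, env, samples, size_dict)
        for k in 1:length(code.args)
            generate_samples_sketch!(code.args[k], cache.children[k], envs[k], samples, size_dict)
        end
    end

The committed version swaps the earlier zip loop for an explicit map, apparently to collect the per-child returns for inspection; the child/environment pairing itself is unchanged.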

test/sampling.jl

Lines changed: 16 additions & 10 deletions
@@ -1,5 +1,5 @@
-using TensorInference, Test
-using StatsBase: kldivergence
+using TensorInference, Test, LinearAlgebra
+import StatsBase
 using OMEinsum

 @testset "sampling" begin
@@ -77,16 +77,22 @@ end
         optimize_code(DynamicEinCode(ixs, Int[]), OMEinsum.get_size_dict(ixs, tensors), GreedyMethod()),
         tensors,
         Dict{Int, Int}(),
-        [[i] for i=5:10]
+        Vector{Int}[]
     )
-    num_samples = 1
-    samples = sample(mps, num_samples; queryvars=[1, 2, 3, 4])
+    num_samples = 1000
+    samples = map(1:num_samples) do i
+        sample(mps, 1; queryvars=[1, 2, 3, 4]).samples[:, 1]
+    end
     indices = map(samples) do sample
         sum(i->sample[i] * 2^(i-1), 1:4) + 1
     end
-    @show samples
-    @show indices
-    probs = vec(DynamicEinCode(ixs, collect(1:4))(tensors...))
-    negative_loglikelyhood(samples, probs) = -sum(log.(probs[indices]))
-    @test negative_loglikelyhood(samples, probs)
+    distribution = map(1:16) do i
+        count(j->j==i, indices) / num_samples
+    end
+    probs = normalize!(real.(vec(DynamicEinCode(ixs, collect(1:4))(tensors...))), 1)
+    #indices = StatsBase.sample(1:16, StatsBase.Weights(probs), 1000)
+    negative_loglikelyhood(probs, samples) = -sum(log.(probs[samples]))/length(samples)
+    entropy(probs) = -sum(probs .* log.(probs))
+    @show distribution, probs
+    @test negative_loglikelyhood(probs, indices) ≈ entropy(probs) atol=1e-1
 end
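
The rewritten test is a statistical consistency check rather than an exact one: for i.i.d. draws from a distribution p, the average negative log-likelihood -(1/n) ∑ log p(xᵢ) converges to the entropy H(p) = -∑ p(x) log p(x), since the cross-entropy of p with itself is its entropy. With 1000 samples over 16 states, atol=1e-1 leaves comfortable room for sampling noise. A toy version of the same check, independent of the tensor-network machinery:

    using LinearAlgebra, Random

    Random.seed!(42)
    probs = normalize!(rand(16), 1)          # an arbitrary 16-state distribution
    entropy(p) = -sum(p .* log.(p))

    # Inverse-transform sampling of 1000 indices from `probs`.
    cdf = cumsum(probs)
    draws = [min(searchsortedfirst(cdf, rand()), 16) for _ in 1:1000]

    nll = -sum(log.(probs[draws])) / length(draws)
    @assert isapprox(nll, entropy(probs); atol=0.1)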
