Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,5 @@ for epoch in 1:100
jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state))
end
```

Note: the code added in this change __will__ break the other examples.
79 changes: 79 additions & 0 deletions scripts/reverse.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Pkg
Pkg.add(["GLMakie", "ProtPlot", "ProgressBars"])

using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars

@eval Flowfusion begin
# Reverse-time generation: integrates the flow backwards from X₀, pushing every
# intermediate state onto `record` so a later forward pass can replay it.
#
# Arguments:
#   P        — tuple of processes, one per state component
#   X₀       — tuple of starting states
#   model    — callable (t, Xₜ) -> (X0hat, X1hat); predicts both endpoints
#   steps    — time grid; consecutive pairs (s₁, s₂) define the integration steps
#   record   — vector that receives (1 - s₂, Xₜ, X1hat) after every step
# Keywords:
#   tracker      — callback (1-t, Xₜ, X1hat) invoked per step (default no-op)
#   midpoint     — evaluate the model at the interval midpoint instead of s₁
#   progress_bar — wrapper applied to the step iterator (e.g. ProgressBar)
# Returns the final state tuple Xₜ.
function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity)
    Xₜ = copy.(X₀)
    push!(record, (1, X₀, nothing))
    for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
        T = eltype(s₁)
        s2 = s₂
        # Nudge s₁ off s₂ so a repeated grid time never produces a zero-length step.
        s1 = s₁ == s₂ ? s2 - T(0.001) : s₁
        # Fixed: dropped the redundant inner assignment (`… : t = s1`) from the ternary.
        t = midpoint ? (s1 + s2) / 2 : s1
        X0hat, X1hat = model(t, Xₜ)
        X0hat = resolveprediction(X0hat, Xₜ)
        X1hat = resolveprediction(X1hat, Xₜ)
        Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀)

        push!(record, (1-s₂, Xₜ, X1hat)) #records all the steps
        tracker(1-t, Xₜ, X1hat)
    end
    return Xₜ
end
# Forward generation: Euler-style integration of the flow from X₀ over `steps`.
#   model   — callable (t, Xₜ) returning a prediction resolved by resolveprediction
#   tracker — callback (t, Xₜ, hat) per step; progress_bar wraps the iterator.
# Returns the final state tuple Xₜ.
function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity)
    Xₜ = copy.(X₀)
    for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
        # Fixed: dropped the redundant inner assignment (`… : t = s₁`) from the ternary.
        t = midpoint ? (s₁ + s₂) / 2 : s₁
        hat = resolveprediction(model(t, Xₜ), Xₜ)
        Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀)
        tracker(t, Xₜ, hat)
    end
    return Xₜ
end

# Forward generation "bound" to a previously recorded reverse trajectory: each
# step pops a recorded state and overwrites the leading entries of Xₜ (the
# target region) with it, so only the remaining entries (the binder) evolve
# freely. NOTE(review): assumes `record` entries are (time, state_tuple, hat)
# tuples as pushed by reverse_gen — confirm against the caller.
function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity)
Xₜ = copy.(X₀)
for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
t = midpoint ? (s₁ + s₂) / 2 : t = s₁

hat = resolveprediction(model(t, Xₜ), Xₜ)

#Changes xt
# Pop the most recent recorded state. Note this rebinds the loop's s₁ to the
# recorded time, which then feeds the step() call below — presumably intentional
# (aligning forward step size with the recorded reverse step); verify.
s₁, old_xt, _ = record[end]
pop!(record)
# Clamp the target region (first size-of-record entries along the length dim)
# of the continuous, rotation, and discrete components to the recorded state.
tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1])
tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2])
tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices


Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀)

# On the last remaining record entry, re-clamp after stepping so the output
# ends exactly on the recorded target structure.
if length(record) == 1
tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1])
tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2])
tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices
end
tracker(t, Xₜ, hat)
end
return Xₜ
end

export bind_gen, reverse_gen
end


# --- Example driver: reverse a PDB structure into noise, then regenerate it
# --- with one extra 10-residue chain bound to the recorded trajectory.
model = load_model();

# NOTE(review): pdb2batch as defined in ChainStorm takes a ProteinChains.ProteinStructure,
# not a path string — confirm a String method exists or load the structure first.
b = ChainStorm.pdb2batch("path/to/pdb")

# Reverse pass: integrate the structure back toward t=0, recording every step.
g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), b, ChainStorm.compound_state(b), model, is_reverse = true, smooth=0, progress_bar = ProgressBar) #<- Model inference call

# Forward pass: same chain lengths plus one new length-10 chain, replaying `recorded`.
b = dummy_batch([ChainStorm.lengths_from_chainids(b.chainids); [10]])
paths = ChainStorm.Tracker()
g = flow_quickgen(ChainStorm.P, b, ChainStorm.zero_state(b), model, tracker = paths, smooth=0, record = recorded, progress_bar = ProgressBar)
# Output id: chain lengths joined by "_" plus a random 4-letter tag.
id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4))
export_pdb("$(id)_bind.pdb", g, b.chainids, b.resinds) #<- Save PDB
samp = gen2prot(g, b.chainids, b.resinds)
animate_trajectory("$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit)
22 changes: 22 additions & 0 deletions src/ChainStorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,28 @@ function load_model(; checkpoint = "ChainStormV1.jld2")
return Flux.loadmodel!(ChainStormV1(), JLD2.load(file, "model_state"))
end

# Wrap a single ProteinChains.ProteinStructure into a one-element
# DLProteinFormats batch (flatten, then batch the flat records).
function pdb2batch(struc::ProteinChains.ProteinStructure)
# Assign a cluster id before flattening — presumably a field the flat-record
# pipeline expects to be present; TODO confirm against DLProteinFormats.flatten.
struc.cluster = 1
return DLProteinFormats.batch_flatrecs([DLProteinFormats.flatten(struc),])
end

# Run-length encode `chainids`: return the length of each maximal run of equal
# consecutive ids, e.g. [1,1,2,2,2,3] -> [2, 3, 1].
# Works with any vector whose elements support `==`.
# Returns Int[] for empty input (fixed: previously returned [1]).
function lengths_from_chainids(chainids)
    isempty(chainids) && return Int[]   # fixed: guard — the old code pushed a spurious count of 1
    counts = Int[]
    current_count = 1
    # Use first/lastindex rather than 1:length so non-1-based vectors also work.
    for i in firstindex(chainids)+1:lastindex(chainids)
        if chainids[i] == chainids[i - 1]
            current_count += 1
        else
            # Run ended: record its length and start counting the next one.
            push!(counts, current_count)
            current_count = 1
        end
    end
    push!(counts, current_count)   # flush the final run
    return counts
end

# Inverse of run-length encoding: expand per-chain lengths into a flat id
# vector, e.g. [2, 3] -> [1, 1, 2, 2, 2].
chainids_from_lengths(lengths) = vcat((fill(i, l) for (i, l) in enumerate(lengths))...)
function gen2prot(samp, chainids, resnums; name = "Gen", )
d = Dict(zip(0:25,'A':'Z'))
Expand Down
121 changes: 121 additions & 0 deletions src/extras/reversal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
using Pkg
Pkg.activate(".")
using Revise
Pkg.develop(path="../")

using Pkg
Pkg.add(["GLMakie", "ProtPlot", "ProgressBars", "CUDA", "cuDNN", "Flux"])

using ChainStorm, Flowfusion, GLMakie, ProtPlot, ProgressBars, CUDA, Flux

@eval Flowfusion begin
# Reverse-time generation with optional "snapping": integrates the flow
# backwards from X₀ and records every step for later replay.
#
# Arguments: as in gen, plus
#   record    — vector receiving (1 - s₂, Xₜ, X1hat) after every step
# Keywords: as in gen, plus
#   snap_time — for t < snap_time, the recorded X1hat has its first two
#               (continuous) components replaced by the ground-truth X₀, so a
#               forward replay is pinned to the true structure early on.
# Returns the final state tuple Xₜ.
function reverse_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity, snap_time = 0)
    Xₜ = copy.(X₀)
    push!(record, (1, X₀, nothing))
    for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
        T = eltype(s₁)
        s2 = s₂
        # Nudge s₁ off s₂ so a repeated grid time never produces a zero-length step.
        s1 = s₁ == s₂ ? s2 - T(0.001) : s₁
        # Fixed: dropped the redundant inner assignment (`… : t = s1`) from the ternary.
        t = midpoint ? (s1 + s2) / 2 : s1
        X0hat, X1hat = model(t, Xₜ)
        X0hat = resolveprediction(X0hat, Xₜ)
        X1hat = resolveprediction(X1hat, Xₜ)
        Xₜ = mask(step(P, Xₜ, X0hat, s1, s2), X₀)
        if t < snap_time
            # Snap: record X₀'s continuous components in place of the model's X1hat.
            fakeX1hat = deepcopy(X1hat)
            tensor(fakeX1hat[1]) .= tensor(X₀[1])
            tensor(fakeX1hat[2]) .= tensor(X₀[2])
            push!(record, (1-s₂, deepcopy(Xₜ), fakeX1hat)) #records all the steps
        else
            push!(record, (1-s₂, deepcopy(Xₜ), deepcopy(X1hat))) #records all the steps
        end
        tracker(1-t, Xₜ, X1hat)
    end
    return Xₜ
end
# Forward generation: Euler-style integration of the flow from X₀ over `steps`.
# See reverse_gen for the reverse-time counterpart. Returns the final Xₜ.
function gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector; tracker::Function=Returns(nothing), midpoint = false, progress_bar=identity)
    Xₜ = copy.(X₀)
    for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
        # Fixed: dropped the redundant inner assignment (`… : t = s₁`) from the ternary.
        t = midpoint ? (s₁ + s₂) / 2 : s₁
        hat = resolveprediction(model(t, Xₜ), Xₜ)
        Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀)
        tracker(t, Xₜ, hat)
    end
    return Xₜ
end
# Forward generation "bound" to a previously recorded reverse trajectory: each
# step pops a recorded state and overwrites the leading (target) entries of Xₜ
# with it, so only the remaining (binder) entries evolve freely.
# NOTE(review): assumes `record` entries are (time, state_tuple, hat) tuples as
# pushed by reverse_gen — confirm against the caller.
function bind_gen(P::Tuple{Vararg{UProcess}}, X₀::Tuple{Vararg{UState}}, model, steps::AbstractVector, record; tracker::Function=Returns(nothing), midpoint = false, progress_bar = identity)
Xₜ = copy.(X₀)
for (s₁, s₂) in progress_bar(zip(steps, steps[begin+1:end]))
t = midpoint ? (s₁ + s₂) / 2 : t = s₁
hat = resolveprediction(model(t, Xₜ), Xₜ)
#Changes xt
# Pop the most recent recorded state. This rebinds the loop's s₁ to the
# recorded time, which then feeds step() below — presumably intentional; verify.
s₁, old_xt, _ = record[end]
pop!(record)
# Clamp the target region of the continuous, rotation, and discrete components.
tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1])
tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2])
tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices
Xₜ = mask(step(P, Xₜ, hat, s₁, s₂), X₀)
# On the last record entry, re-clamp after stepping so the output ends exactly
# on the recorded target structure.
if length(record) == 1
tensor(Xₜ[1])[:, :, 1:size(tensor(old_xt[1]), 3), :] .= tensor(old_xt[1])
tensor(Xₜ[2])[:, :, 1:size(tensor(old_xt[2]), 3), :] .= tensor(old_xt[2])
tensor(Xₜ[3]).indices[1:size(tensor(old_xt[3]).indices, 1), :] .= tensor(old_xt[3]).indices
end
tracker(t, Xₜ, hat)
end
return Xₜ
end
export bind_gen, reverse_gen
end


# --- Example driver: reverse a known PDB target into noise, then regenerate it
# --- together with two new binder chains replaying the recorded trajectory.
model = load_model() |> gpu

# Take chain 1 of PDB entry 7RBY (biological assembly 1) as the binding target.
struc = pdb"7RBY"1
target = ChainStorm.pdb2batch(struc[[1]])


# Reverse pass: integrate the target back toward t=0; snap_time = 0.9 pins the
# recorded predictions to the ground truth for most of the trajectory.
@time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = 0.9f0);

#proportions = [mean(unhot(recorded[i][2][3]).S.state .== 21) for i in 1:length(recorded)]
#pl = Plots.plot(proportions, xlabel = "t step", label = :none, ylabel = "P(21)")
#savefig(pl, "proportions.pdf")

# Forward batch: target chains plus two binder chains (lengths 122 and 114);
# binder residue indices restart at 1.
b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); [122, 114]])
b.resinds[1:length(target.resinds)] .= target.resinds
binder_inds = length(target.resinds)+1:length(b.resinds)
b.resinds[binder_inds] .= 1:length(binder_inds)
X0 = ChainStorm.zero_state(b)
#If you want to bias the starting location:
#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .*= 0.5f0
#Flowfusion.tensor(X0[1])[:,1,binder_inds,1] .+= [0.0f0, 0.1f0, 0.1f0]

# Forward pass replaying `recorded`; `paths` collects the trajectory for animation.
paths = ChainStorm.Tracker()
@time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0);
# Output id: chain lengths joined by "_" plus a random 4-letter tag.
id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4))
export_pdb("samples/$(id)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB

samp = gen2prot(fwd_g, b.chainids, b.resinds)
animate_trajectory("samples/$(id)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit)

# Batch sampling: 10 binder-design attempts against the same target, randomising
# the record snap time and the binder chain lengths on each round.
for _ in 1:10
# Half the runs replay the model's own X̂₁ predictions throughout (snap = 0);
# half snap the record to the ground-truth target for t < 0.9.
snap = rand([0f0, 0.9f0])
@time rev_g, recorded = flow_quickgen(ChainStorm.reverse_process(ChainStorm.P), target, ChainStorm.compound_state(target), model, is_reverse = true, smooth=0, d = gpu, steps = 0f0:0.005f0:1f0, progress_bar = ProgressBar, snap_time = snap); #, steps = 0f0:0.025f0:1f0); #<- Model inference call
# Binder lengths: two ~200-residue chains by default; with p≈1/3 two ~100s;
# with a further p≈1/3 a single 30–150 chain (later draws override earlier ones).
lens = [200+rand(1:25),200+rand(1:25)]
if rand() < 0.33
lens = [100+rand(1:25),100+rand(1:25)]
end
if rand() < 0.33
lens = [rand(30:150)]
end
# Forward batch: target chains plus the sampled binder chains, binder residue
# indices restarted at 1.
b = dummy_batch([ChainStorm.lengths_from_chainids(target.chainids); lens])
b.resinds[1:length(target.resinds)] .= target.resinds
binder_inds = length(target.resinds)+1:length(b.resinds)
b.resinds[binder_inds] .= 1:length(binder_inds)
X0 = ChainStorm.zero_state(b)
paths = ChainStorm.Tracker()
@time fwd_g = flow_quickgen(ChainStorm.P, b, X0, model, tracker = paths, smooth=0, record = deepcopy(recorded), progress_bar = ProgressBar, d = gpu, steps = 0f0:0.005f0:1f0);
# File names encode chain lengths, the snap setting, and a random tag.
id = join(string.(ChainStorm.lengths_from_chainids(b.chainids)),"_")*"-"*join(rand('A':'Z', 4))
export_pdb("samples/$(id)_$(snap)_bind.pdb", fwd_g, b.chainids, b.resinds) #<- Save PDB
samp = gen2prot(fwd_g, b.chainids, b.resinds)
animate_trajectory("samples/$(id)_$(snap)_bind.mp4", samp, first_trajectory(paths), viewmode = :fit)
end
104 changes: 98 additions & 6 deletions src/flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,32 @@ const rotM = Flowfusion.Rotations(3)
schedule_f(t) = 1-(1-t)^2
const P = (FProcess(BrownianMotion(0.2f0), schedule_f), FProcess(ManifoldProcess(0.2f0), schedule_f), NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21))

#Bringing Alexander's version in - this should be replaced by the "full" solution:
# Reverse-time analogue of NoisyInterpolatingDiscreteFlow: the interpolation
# schedule K1 is the forward schedule reflected in time (t -> 1-t), while the
# noise bump K2 = noise*sin(πt) is symmetric and therefore unchanged.
#   noise       — amplitude of the K2 noise schedule
#   K           — exponent sharpening the K1 schedule
#   dummy_token — absorbing/mask token; required when K > 1
function rev_NoisyInterpolatingDiscreteFlow(noise; K = 1, dummy_token::T = nothing) where T
    if (K > 1 && isnothing(dummy_token))
        # Fixed: the warning previously named the forward constructor.
        @warn "rev_NoisyInterpolatingDiscreteFlow: If K>1 things might break if your X0 is not the `dummy_token` (which should also be passed to rev_NoisyInterpolatingDiscreteFlow)."
    end
    return NoisyInterpolatingDiscreteFlow{T}(
        t -> oftype(t,1-(1-cos((π/2)*(1-t)))^K), #K1: time-reflected interpolation schedule
        t -> oftype(t,(noise * sin(π*t))), #K2: symmetric noise schedule
        t -> oftype(t,(K * (π/2) * cos((π/2) * t) * (1 - sin((π/2) * t))^(K - 1))), #dK1 = d/dt K1 (cos((π/2)t) = sin((π/2)(1-t)))
        t -> oftype(t,(noise*π*cos(π*t))), #dK2 = d/dt K2
        dummy_token
    )
end

# Build the time-reversed process tuple for P = (continuous, manifold, discrete).
# Schedules are reflected: F_rev(t) = 1 - F(1 - t), whose derivative w.r.t. t is
# +F'(1 - t) by the chain rule.
function reverse_process(P)
    continuous_schedule = t -> 1 - P[1].F(1 - t)
    manifold_schedule = t -> 1 - P[2].F(1 - t)
    κ₁ = t -> 1 - P[3].κ₁(1 - t)
    dκ₁ = t -> P[3].dκ₁(1 - t)
    κ₂ = t -> 1 - P[3].κ₂(1 - t)
    dκ₂ = t -> P[3].dκ₂(1 - t)   # fixed: previously called dκ₁ (copy-paste bug)
    # NOTE(review): κ₁/dκ₁/κ₂/dκ₂ are only used by the commented-out "full"
    # construction below; the hack path ignores them.
    #This is a hack fix:
    #(Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), Flowfusion.NoisyInterpolatingDiscreteFlow(κ₁, dκ₁, κ₂, dκ₂, P[3].mask_token))
    (Flowfusion.FProcess(P[1].P, continuous_schedule), Flowfusion.FProcess(P[2].P, manifold_schedule), rev_NoisyInterpolatingDiscreteFlow(0.2f0, K = 2, dummy_token = 21))
end

function compound_state(b)
L,B = size(b.aas)
cmask = b.aas .< 100
Expand Down Expand Up @@ -45,7 +71,6 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0)
prev_trans = values(translation(f))
T = eltype(prev_trans)
function m(t, Xt)
print(".")
f, aalogits = model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = f)
values(translation(f)) .= prev_trans .* T(smooth) .+ values(translation(f)) .* T(1-smooth)
prev_trans = values(translation(f))
Expand All @@ -54,17 +79,84 @@ function flowX1predictor(X0, b, model; d = identity, smooth = 0)
return m
end

# Build a reverse-time predictor closure m(rt, Xt) for use with reverse_gen.
# From the model's X̂₁ prediction it derives an X̂₀ prediction by inverting the
# linear interpolation (translations) and extrapolating along the geodesic
# (rotations); amino-acid logits are forced entirely onto token 21 (the dummy),
# so the discrete component always reverses toward the dummy token.
# Returns ((X0hat translations, X0hat rotations, dummy logits),
#          (X1hat translations, X1hat rotations, dummy logits)).
function flowX0predictor(X0, b, model, P; d = identity, smooth = 0) # Forces P to be a FProcess and doesn't work for some reason for P Deterministic
batch_dim = size(tensor(X0[1]), 4)
ff, _ = model(d(ones(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)) # ones makes it start at time = 1
# Volatility of the continuous process (0 when Deterministic, which has no .v).
if P[1].P isa Deterministic
v = 0
else
v = P[1].P.v
end
function m(rt, Xt)
# rt is reverse time; the model is queried at forward time 1 - rt, reusing the
# previous frames (sc_frames=ff) for self-conditioning.
ff, aalogits = model(d(1-rt .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames=ff)
aalogits = deepcopy(cpu(aalogits))
X1Hat = deepcopy(cpu(ff))
# Forward time under the process schedule, clipped to avoid dividing by ~0 below.
t = 1f0 .- P[1].F.(rt .+ zeros(Float32, 1, batch_dim))
t[t .>= 0.999] .= 0.999
# Invert the interpolation Xₜ ≈ (1-t+v·t)·X₀ + t·X̂₁ to estimate X₀ (translations).
values(translation(X1Hat)) .= (tensor(Xt[1]) .- values(translation(X1Hat)) .* t) ./ (1 .- t + v .* t)
M = Xt[2].S.M
p = eachslice(tensor(Xt[2]), dims=(3, 4))
# Rotations: extrapolate backwards along the geodesic from Xt toward X̂₁ by
# factor -t/(1-t), via the manifold log/exp maps.
tangent = -t ./ (1 .- t) .* log.((M,), p, eachslice(values(linear(X1Hat)), dims=(3, 4)))
X0Hat = exp.((M,), p, tangent)
values(linear(X1Hat)) .= stack(X0Hat)
# Force all probability mass onto token 21 (softmax of -Inf elsewhere).
T = eltype(aalogits)
aalogits .= T(-Inf)
aalogits[21,:,:] .= 0
return (cpu(values(translation(X1Hat))), ManifoldState(rotM, eachslice(cpu(values(linear(X1Hat))), dims=(3,4))), cpu(softmax(aalogits))), (cpu(values(translation(ff))), ManifoldState(rotM, eachslice(cpu(values(linear(ff))), dims=(3,4))), cpu(softmax(aalogits)))
end
return m
end

# Build a forward predictor closure m(t, Xt) for bind_gen: the model predicts
# X̂₁ for the whole batch, then the target region (the first `recdim` entries
# along the length dim) is overwritten with the recorded reverse-pass prediction,
# so the binder is generated against a fixed target.
# meanshift: translate by (recorded target mean − predicted target mean) to
# compensate for the model re-centering the complex.
function bind_flowX1predictor(X0, b, model, recorded; d = identity, smooth = 0, meanshift = true)
recdim = size(tensor(recorded[end][3][1]), 3)
batch_dim = size(tensor(X0[1]), 4)
# Warm-up call at t=0, then a second self-conditioned call with the target
# region of the frames clamped to the recorded prediction.
f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds)))
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1]) # Might be more sensible to do a weighted average of X̂₁ and (1-t)*(Xₜ₊Δₜ - Xₜ)/Δt
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
f, _ = cpu(model(d(zeros(Float32, 1, batch_dim)), d(X0), d(b.chainids), d(b.resinds), sc_frames=d(f)))
recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3)
forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3)
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1])
# NOTE(review): here the shift is applied to the target region (1:recdim), while
# inside m below it is applied to the binder region (recdim+1:end) — confirm
# this asymmetry is intentional.
if meanshift
values(translation(f))[:, :, 1:recdim, :] .+= forcemean .- recmean
end
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
function m(t, Xt)
f, aalogits = cpu(model(d(t .+ zeros(Float32, 1, batch_dim)), d(Xt), d(b.chainids), d(b.resinds), sc_frames = d(f)))
recmean = Flux.mean(values(translation(f))[:, :, 1:recdim, :], dims = 3)
forcemean = Flux.mean(tensor(recorded[end][3][1]), dims = 3)
# Clamp the target region to the recorded prediction for this step.
values(translation(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][1])
if meanshift #This is to shift the binder over by the amount the target would have shifted, in the other direction:
values(translation(f))[:, :, recdim+1:end, :] .+= forcemean .- recmean
end
values(linear(f))[:, :, 1:recdim, :] .= tensor(recorded[end][3][2])
return cpu(values(translation(f))), ManifoldState(rotM, eachslice(cpu(values(linear(f))), dims=(3,4))), cpu(softmax(aalogits))
end
return m
end

# Huber-style warp: quadratic up to the threshold `d`, then linear with
# matched value and slope at the junction.
function H(a; d = 2/3)
    return a > d ? d * (a - d / 2) : (a^2) / 2
end
# Warp normalised so that S(1) == 1.
S(a) = H(a) / H(1)

# Dispatch wrapper: choose plain generation, reverse generation, or bind
# generation depending on `is_reverse` and whether a `record` was supplied.
# NOTE(review): this region appears to be a unified diff with unmarked deletion
# lines interleaved (two signatures, a duplicated `if steps isa Number`, and a
# duplicated early body) — reconcile against the committed file before editing.
function flow_quickgen(b, model; steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6)
function flow_quickgen(P, b, X0, model; is_reverse = false, steps = :default, d = identity, tracker = Returns(nothing), smooth = 0.6, record = [], progress_bar=identity, snap_time = 0)
# Default step grid: 5 leading zeros, a warped dense ramp (via S), then a fine tail to 1.
stps = vcat(zeros(5),S.([0.0:0.00255:0.9975;]),[0.999, 0.9998, 1.0])
if steps isa Number

# NOTE(review): duplicated condition — likely an old/new diff-line pair.
if steps isa Number
stps = 0f0:1f0/steps:1f0
elseif steps isa AbstractVector
stps = steps
end
# NOTE(review): the next three lines look like the pre-change body (deleted lines).
X0 = zero_state(b)
X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth)
return gen(P, X0, X1pred, Float32.(stps), tracker = tracker)
# Write the starting state back into the batch so downstream consumers see it.
b.locs .= tensor(X0[1])
#b.aas .= convert(Matrix{Int64}, tensor(X0[3]).indices)
b.aas .= unhot(X0[3]).S.state
if !is_reverse && length(record) == 0
X1pred = flowX1predictor(X0, b, model, d = d, smooth = smooth)
return gen(P, X0, X1pred, Float32.(stps), tracker = tracker, progress_bar = progress_bar)
elseif is_reverse
# Reverse pass runs on the reflected step grid and also returns the record.
X0pred = flowX0predictor(X0, b, model, P, d = d, smooth = smooth)
return reverse_gen(P, X0, X0pred, Float32.(1 .- reverse(stps)), record, tracker = tracker, progress_bar = progress_bar, snap_time = snap_time), record
else
# A non-empty record with is_reverse == false means a bind (replay) pass.
X1pred = bind_flowX1predictor(X0, b, model, record, d = d, smooth = smooth)
return bind_gen(P, X0, X1pred, Float32.(stps), record, tracker = tracker, progress_bar = progress_bar)
end
end