diff --git a/docs/src/models/deeponet.md b/docs/src/models/deeponet.md index 3f5a186..3c650f2 100644 --- a/docs/src/models/deeponet.md +++ b/docs/src/models/deeponet.md @@ -49,7 +49,7 @@ Random.seed!(rng, 1234) xdev = reactant_device() -eval_points = 1 +eval_points = 17 batch_size = 64 dim_y = 1 m = 32 @@ -58,7 +58,10 @@ xrange = range(0, 2π; length=m) .|> Float32 α = 0.5f0 .+ 0.5f0 .* rand(Float32, batch_size) u_data = zeros(Float32, m, batch_size) -y_data = rand(rng, Float32, 1, eval_points) .* Float32(2π) +y_data = rand(rng, Float32, dim_y, eval_points) .* Float32(2π) +# for plotting, we want to evaluate points in order +rightorder = sortperm(vec(y_data)) + v_data = zeros(Float32, eval_points, batch_size) for i in 1:batch_size @@ -67,8 +70,8 @@ for i in 1:batch_size end deeponet = DeepONet( - Chain(Dense(m => 8, σ), Dense(8 => 8, σ), Dense(8 => 8, σ)), - Chain(Dense(1 => 4, σ), Dense(4 => 8, σ)) + Chain(Dense(m => 64, tanh), Dense(64 => 64, tanh), Dense(64 => 64, tanh)), + Chain(Dense(1 => 16, tanh), Dense(16 => 64, tanh)) ) ps, st = Lux.setup(rng, deeponet) |> xdev; @@ -90,7 +93,7 @@ function train!(model, ps, st, data; epochs=10) return losses end -losses = train!(deeponet, ps, st, data; epochs=1000) +losses = train!(deeponet, ps, st, data; epochs=20000) draw( AoG.data((; losses, iteration=1:length(losses))) * @@ -99,4 +102,23 @@ draw( axis=(; yscale=log10), figure=(; title="Using DeepONet to learn the anti-derivative operator") ) + +# plot the prediction for a new function +# that's not part of the training set +αₜ = 0.75 +input_data = sin.(αₜ .* xrange) |> xdev +output_data, st = @jit Lux.apply(deeponet, (input_data, y_data), ps, st) +output_x = vec(cdev(y_data))[rightorder] +pred_y = vec(cdev(output_data))[rightorder] +true_y = -inv(αₜ) .* cos.(αₜ .* y_data[1, rightorder]) +p = lines(Array(xrange), Array(input_data); label="u") +lines!(a, Array(output_x), Array(pred_y); label="Predicted") +lines!(a, Array(output_x), Array(true_y); label="Expected") +axislegend(a) +# Compute the absolute error and plot that, too +absolute_error = abs.(Array(pred_y) .- Array(true_y)) +a2, p2 = lines(f[2, 1], Array(output_x), absolute_error; axis=(; ylabel="Error")) +rowsize!(f.layout, 2, Aspect(1, 1 / 8)) +linkxaxes!(a, a2) +f ```