diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index bd9bd18b3..a234a6416 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = _, chout = l.channel heads = l.heads - Wxi = Wxj = l.dense_x(xj) - Wxi = Wxj = reshape(Wxj, chout, heads, :) - - if xi !== xj - Wxi = l.dense_x(xi) - Wxi = reshape(Wxi, chout, heads, :) - end + Wxj = l.dense_x(xj) + Wxj = reshape(Wxj, chout, heads, :) + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) # a hand-written message passing message = Fix1(gat_message, l) @@ -142,8 +139,11 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = if !l.concat x = mean(x, dims = 2) + width = size(x, 1) + else + width = size(x, 1) * size(x, 2) end - x = reshape(x, :, size(x, 3)) # return a matrix + x = reshape(x, width, size(x, 3)) # return a matrix x = l.σ.(x .+ l.bias) return x @@ -194,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} if !l.concat x = mean(x, dims = 2) + width = size(x, 1) + else + width = size(x, 1) * size(x, 2) end - x = reshape(x, :, size(x, 3)) + x = reshape(x, width, size(x, 3)) x = l.σ.(x .+ l.bias) return x end