From e628595d83ad2432d42be054d02d89f761453dbe Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sat, 17 Jan 2026 22:03:21 -0800 Subject: [PATCH 1/2] fix: Gate empty result in GATConv --- GNNlib/src/layers/conv.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index bd9bd18b3..abb6e0c03 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -139,11 +139,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = α = dropout(α, l.dropout) β = α .* m.Wxj x = aggregate_neighbors(g, +, β) + width = size(x, 1) if !l.concat x = mean(x, dims = 2) end - x = reshape(x, :, size(x, 3)) # return a matrix + x = reshape(x, width, size(x, 3)) # return a matrix x = l.σ.(x .+ l.bias) return x From 9b794a9dc67fc0e2c2d2ca87c483e8d16ec5df96 Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Sun, 18 Jan 2026 00:44:22 -0800 Subject: [PATCH 2/2] fix: Handle multiple heads --- GNNlib/src/layers/conv.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index abb6e0c03..a234a6416 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = _, chout = l.channel heads = l.heads - Wxi = Wxj = l.dense_x(xj) - Wxi = Wxj = reshape(Wxj, chout, heads, :) - - if xi !== xj - Wxi = l.dense_x(xi) - Wxi = reshape(Wxi, chout, heads, :) - end + Wxj = l.dense_x(xj) + Wxj = reshape(Wxj, chout, heads, :) + Wxi = l.dense_x(xi) + Wxi = reshape(Wxi, chout, heads, :) # a hand-written message passing message = Fix1(gat_message, l) @@ -139,10 +136,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = α = dropout(α, l.dropout) β = α .* m.Wxj x = aggregate_neighbors(g, +, β) - width = size(x, 1) if !l.concat x = mean(x, dims = 2) + width = size(x, 1) + else + width = size(x, 1) * size(x, 2) end x = reshape(x, width, size(x, 3)) # return a matrix x = l.σ.(x .+ l.bias) @@ -195,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} if !l.concat x = mean(x, dims = 2) + width = size(x, 1) + else + width = size(x, 1) * size(x, 2) end - x = reshape(x, :, size(x, 3)) + x = reshape(x, width, size(x, 3)) x = l.σ.(x .+ l.bias) return x end