From e628595d83ad2432d42be054d02d89f761453dbe Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sat, 17 Jan 2026 22:03:21 -0800
Subject: [PATCH 1/2] fix: Gate empty result in GATConv
---
GNNlib/src/layers/conv.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index bd9bd18b3..abb6e0c03 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -139,11 +139,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
α = dropout(α, l.dropout)
β = α .* m.Wxj
x = aggregate_neighbors(g, +, β)
+ width = size(x, 1)
if !l.concat
x = mean(x, dims = 2)
end
- x = reshape(x, :, size(x, 3)) # return a matrix
+ x = reshape(x, width, size(x, 3)) # return a matrix
x = l.σ.(x .+ l.bias)
return x
From 9b794a9dc67fc0e2c2d2ca87c483e8d16ec5df96 Mon Sep 17 00:00:00 2001
From: Leni Aniva
Date: Sun, 18 Jan 2026 00:44:22 -0800
Subject: [PATCH 2/2] fix: Handle multiple heads
---
GNNlib/src/layers/conv.jl | 20 +++++++++++---------
1 file changed, 11 insertions(+), 9 deletions(-)
diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl
index abb6e0c03..a234a6416 100644
--- a/GNNlib/src/layers/conv.jl
+++ b/GNNlib/src/layers/conv.jl
@@ -124,13 +124,10 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
_, chout = l.channel
heads = l.heads
- Wxi = Wxj = l.dense_x(xj)
- Wxi = Wxj = reshape(Wxj, chout, heads, :)
-
- if xi !== xj
- Wxi = l.dense_x(xi)
- Wxi = reshape(Wxi, chout, heads, :)
- end
+ Wxj = l.dense_x(xj)
+ Wxj = reshape(Wxj, chout, heads, :)
+ Wxi = l.dense_x(xi)
+ Wxi = reshape(Wxi, chout, heads, :)
# a hand-written message passing
message = Fix1(gat_message, l)
@@ -139,10 +136,12 @@ function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} =
α = dropout(α, l.dropout)
β = α .* m.Wxj
x = aggregate_neighbors(g, +, β)
- width = size(x, 1)
if !l.concat
x = mean(x, dims = 2)
+ width = size(x, 1)
+ else
+ width = size(x, 1) * size(x, 2)
end
x = reshape(x, width, size(x, 3)) # return a matrix
x = l.σ.(x .+ l.bias)
@@ -195,8 +194,11 @@ function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix}
if !l.concat
x = mean(x, dims = 2)
+ width = size(x, 1)
+ else
+ width = size(x, 1) * size(x, 2)
end
- x = reshape(x, :, size(x, 3))
+ x = reshape(x, width, size(x, 3))
x = l.σ.(x .+ l.bias)
return x
end