Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ include("abstracttypes.jl")
export AbstractGNNGraph

include("gnngraph.jl")
export GNNGraph,
node_features,
edge_features,
graph_features
export GNNGraph

include("gnnheterograph/gnnheterograph.jl")
export GNNHeteroGraph,
Expand Down
32 changes: 0 additions & 32 deletions GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -535,38 +535,6 @@ function graph_indicator(g::GNNGraph; edges = false)
end
end



"""
    node_features(g::GNNGraph)

Return the single node feature array stored in `g.ndata`, or `nothing` when
`g.ndata` is empty. When more than one feature array is present the choice is
ambiguous, so an error is logged and `nothing` is returned.
"""
function node_features(g::GNNGraph)
    feats = g.ndata
    isempty(feats) && return nothing
    if length(feats) > 1
        @error "Multiple feature arrays, access directly through `g.ndata`"
        return nothing
    end
    return first(values(feats))
end

"""
    edge_features(g::GNNGraph)

Return the single edge feature array stored in `g.edata`, or `nothing` when
`g.edata` is empty. When more than one feature array is present the choice is
ambiguous, so an error is logged and `nothing` is returned.
"""
function edge_features(g::GNNGraph)
    feats = g.edata
    isempty(feats) && return nothing
    if length(feats) > 1
        @error "Multiple feature arrays, access directly through `g.edata`"
        return nothing
    end
    return first(values(feats))
end

"""
    graph_features(g::GNNGraph)

Return the single graph-level feature array stored in `g.gdata`, or `nothing`
when `g.gdata` is empty. When more than one feature array is present the choice
is ambiguous, so an error is logged and `nothing` is returned.
"""
function graph_features(g::GNNGraph)
    feats = g.gdata
    isempty(feats) && return nothing
    if length(feats) > 1
        @error "Multiple feature arrays, access directly through `g.gdata`"
        return nothing
    end
    return first(values(feats))
end

"""
is_bidirected(g::GNNGraph)

Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
s, t = edge_index(g123)
@test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]]
@test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]]
@test node_features(g123)[:, 11:14] ≈ node_features(g2)
@test g123.ndata.x[:, 11:14] ≈ g2.ndata.x

# scalar graph features
g1 = GNNGraph(g1, gdata = rand())
Expand Down Expand Up @@ -106,7 +106,7 @@ end
s, t = edge_index(g2b)
@test s == edge_index(g2)[1]
@test t == edge_index(g2)[2]
@test node_features(g2b) ≈ node_features(g2)
@test g2b.ndata.x ≈ g2.ndata.x

g2c = getgraph(g, 2)
@test g2c isa GNNGraph{typeof(g.graph)}
Expand Down
4 changes: 2 additions & 2 deletions GNNLux/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
# NOTE(review): `ps` and `st` are not defined in this one-argument method's
# scope (they only exist in the 4-arg method above) — this looks like a latent
# UndefVarError carried over unchanged from the pre-diff line; confirm the
# intended signature, e.g. `(l::GlobalPool)(g::GNNGraph, ps, st)`.
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x, ps, st))

@doc raw"""
GlobalAttentionPool(fgate, ffeat=identity)
Expand Down Expand Up @@ -106,7 +106,7 @@ function (l::GlobalAttentionPool)(g, x, ps, st)
return GNNlib.global_attention_pool(m, g, x), st
end

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st))
# NOTE(review): `ps` and `st` are not defined in this one-argument method's
# scope (they only exist in the 4-arg method above) — likely an UndefVarError
# at call time; confirm the intended signature, e.g.
# `(l::GlobalAttentionPool)(g::GNNGraph, ps, st)`.
(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x, ps, st))

"""
TopKPool(adj, k, in_channel)
Expand Down
6 changes: 3 additions & 3 deletions GraphNeuralNetworks/src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ abstract type GNNLayer end

# Forward pass with graph-only input.
# To be specialized by layers also needing edge features as input (e.g. NNConv).
(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x))

"""
WithGraph(model, g::GNNGraph; traingraph=false)
Expand Down Expand Up @@ -144,12 +144,12 @@ _applylayer(l, g::GNNGraph, x) = l(x)
_applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x)

# input from graph
_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(node_features(g)))
_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(g.ndata.x))
_applylayer(l::GNNLayer, g::GNNGraph) = l(g)

# # Handle Flux.Parallel
function _applylayer(l::Parallel, g::GNNGraph)
GNNGraph(g, ndata = _applylayer(l, g, node_features(g)))
GNNGraph(g, ndata = _applylayer(l, g, g.ndata.x))
end

function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray)
Expand Down
14 changes: 7 additions & 7 deletions GraphNeuralNetworks/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end

(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))

(l::GATConv)(g, x, e = nothing) = GNNlib.gat_conv(l, g, x, e)

Expand Down Expand Up @@ -461,7 +461,7 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
add_self_loops, dropout)
end

(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))

(l::GATv2Conv)(g, x, e=nothing) = GNNlib.gatv2_conv(l, g, x, e)

Expand Down Expand Up @@ -718,7 +718,7 @@ end

(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e)

(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))

function Base.show(io::IO, l::NNConv)
out, in = size(l.weight)
Expand Down Expand Up @@ -933,7 +933,7 @@ end
(l::CGConv)(g, x, e = nothing) = GNNlib.cg_conv(l, g, x, e)


(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))

function Base.show(io::IO, l::CGConv)
print(io, "CGConv($(l.ch)")
Expand Down Expand Up @@ -1054,7 +1054,7 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
end

function (l::MEGNetConv)(g::GNNGraph)
x, e = l(g, node_features(g), edge_features(g))
x, e = l(g, g.ndata.x, g.edata.e)
return GNNGraph(g, ndata = x, edata = e)
end

Expand Down Expand Up @@ -1137,7 +1137,7 @@ end

(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e)

(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))

function Base.show(io::IO, l::GMMConv)
(nin, ein), out = l.ch
Expand Down Expand Up @@ -1538,7 +1538,7 @@ end
(l::TransformerConv)(g, x, e = nothing) = GNNlib.transformer_conv(l, g, x, e)

function (l::TransformerConv)(g::GNNGraph)
GNNGraph(g, ndata = l(g, node_features(g), edge_features(g)))
GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e))
end

function Base.show(io::IO, l::TransformerConv)
Expand Down
6 changes: 3 additions & 3 deletions GraphNeuralNetworks/src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end

(l::GlobalPool)(g::GNNGraph, x::AbstractArray) = GNNlib.global_pool(l, g, x)

(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x))

@doc raw"""
GlobalAttentionPool(fgate, ffeat=identity)
Expand Down Expand Up @@ -96,7 +96,7 @@ GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

(l::GlobalAttentionPool)(g, x) = GNNlib.global_attention_pool(l, g, x)

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x))

"""
TopKPool(adj, k, in_channel)
Expand Down Expand Up @@ -157,4 +157,4 @@ end

(l::Set2Set)(g, x) = GNNlib.set2set_pool(l, g, x)

(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g)))
(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x))
2 changes: 1 addition & 1 deletion GraphNeuralNetworks/test/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end
g = batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5])
g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes))
l = Set2Set(n_in, n_iters, n_layers)
y = l(g, node_features(g))
y = l(g, g.ndata.x)
@test size(y) == (2 * n_in, g.num_graphs)

## TODO the numerical gradient seems to be 3 times smaller than zygote one
Expand Down
Loading