From e64662d74c3f013048fd93111860cfbe940fb489 Mon Sep 17 00:00:00 2001 From: calebjubal <2301730055@krmu.edu.in> Date: Fri, 6 Feb 2026 08:40:08 +0000 Subject: [PATCH] refactor: remove multiple feature access methods and update usages to use `g.ndata.x` --- GNNGraphs/src/GNNGraphs.jl | 5 +--- GNNGraphs/src/query.jl | 32 ------------------------- GNNGraphs/test/transform.jl | 4 ++-- GNNLux/src/layers/pool.jl | 4 ++-- GraphNeuralNetworks/src/layers/basic.jl | 6 ++--- GraphNeuralNetworks/src/layers/conv.jl | 14 +++++------ GraphNeuralNetworks/src/layers/pool.jl | 6 ++--- GraphNeuralNetworks/test/layers/pool.jl | 2 +- 8 files changed, 19 insertions(+), 54 deletions(-) diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index bb35dcfcd..a02fd3023 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -24,10 +24,7 @@ include("abstracttypes.jl") export AbstractGNNGraph include("gnngraph.jl") -export GNNGraph, - node_features, - edge_features, - graph_features +export GNNGraph include("gnnheterograph/gnnheterograph.jl") export GNNHeteroGraph, diff --git a/GNNGraphs/src/query.jl b/GNNGraphs/src/query.jl index 2c2f4332a..eb33e4227 100644 --- a/GNNGraphs/src/query.jl +++ b/GNNGraphs/src/query.jl @@ -535,38 +535,6 @@ function graph_indicator(g::GNNGraph; edges = false) end end - - -function node_features(g::GNNGraph) - if isempty(g.ndata) - return nothing - elseif length(g.ndata) > 1 - @error "Multiple feature arrays, access directly through `g.ndata`" - else - return first(values(g.ndata)) - end -end - -function edge_features(g::GNNGraph) - if isempty(g.edata) - return nothing - elseif length(g.edata) > 1 - @error "Multiple feature arrays, access directly through `g.edata`" - else - return first(values(g.edata)) - end -end - -function graph_features(g::GNNGraph) - if isempty(g.gdata) - return nothing - elseif length(g.gdata) > 1 - @error "Multiple feature arrays, access directly through `g.gdata`" - else - return first(values(g.gdata)) - end -end - """ is_bidirected(g::GNNGraph) diff --git a/GNNGraphs/test/transform.jl b/GNNGraphs/test/transform.jl index fb7e95bf4..cd974fcbc 100644 --- a/GNNGraphs/test/transform.jl +++ b/GNNGraphs/test/transform.jl @@ -43,7 +43,7 @@ end s, t = edge_index(g123) @test s == [edge_index(g1)[1]; 10 .+ edge_index(g2)[1]; 14 .+ edge_index(g3)[1]] @test t == [edge_index(g1)[2]; 10 .+ edge_index(g2)[2]; 14 .+ edge_index(g3)[2]] - @test node_features(g123)[:, 11:14] ≈ node_features(g2) + @test g123.ndata.x[:, 11:14] ≈ g2.ndata.x # scalar graph features g1 = GNNGraph(g1, gdata = rand()) @@ -106,7 +106,7 @@ end s, t = edge_index(g2b) @test s == edge_index(g2)[1] @test t == edge_index(g2)[2] - @test node_features(g2b) ≈ node_features(g2) + @test g2b.ndata.x ≈ g2.ndata.x g2c = getgraph(g, 2) @test g2c isa GNNGraph{typeof(g.graph)} diff --git a/GNNLux/src/layers/pool.jl b/GNNLux/src/layers/pool.jl index 4d4b7273e..d9a9349a3 100644 --- a/GNNLux/src/layers/pool.jl +++ b/GNNLux/src/layers/pool.jl @@ -39,7 +39,7 @@ end (l::GlobalPool)(g::GNNGraph, x::AbstractArray, ps, st) = GNNlib.global_pool(l, g, x), st -(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) +(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x, ps, st)) @doc raw""" GlobalAttentionPool(fgate, ffeat=identity) @@ -106,7 +106,7 @@ function (l::GlobalAttentionPool)(g, x, ps, st) return GNNlib.global_attention_pool(m, g, x), st end -(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g), ps, st)) +(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x, ps, st)) """ TopKPool(adj, k, in_channel) diff --git a/GraphNeuralNetworks/src/layers/basic.jl b/GraphNeuralNetworks/src/layers/basic.jl index 45a1e0b05..1c87bc62f 100644 --- a/GraphNeuralNetworks/src/layers/basic.jl +++ b/GraphNeuralNetworks/src/layers/basic.jl @@ -9,7 +9,7 @@ abstract type GNNLayer end # Forward pass with graph-only input. # To be specialized by layers also needing edge features as input (e.g. NNConv). -(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) +(l::GNNLayer)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x)) """ WithGraph(model, g::GNNGraph; traingraph=false) @@ -144,12 +144,12 @@ _applylayer(l, g::GNNGraph, x) = l(x) _applylayer(l::GNNLayer, g::GNNGraph, x) = l(g, x) # input from graph -_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(node_features(g))) +_applylayer(l, g::GNNGraph) = GNNGraph(g, ndata = l(g.ndata.x)) _applylayer(l::GNNLayer, g::GNNGraph) = l(g) # # Handle Flux.Parallel function _applylayer(l::Parallel, g::GNNGraph) - GNNGraph(g, ndata = _applylayer(l, g, node_features(g))) + GNNGraph(g, ndata = _applylayer(l, g, g.ndata.x)) end function _applylayer(l::Parallel, g::GNNGraph, x::AbstractArray) diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl index e3cf30fea..3b98a4f8d 100644 --- a/GraphNeuralNetworks/src/layers/conv.jl +++ b/GraphNeuralNetworks/src/layers/conv.jl @@ -341,7 +341,7 @@ function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity; GATConv(dense_x, dense_e, b, a, σ, negative_slope, ch, heads, concat, add_self_loops, dropout) end -(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) (l::GATConv)(g, x, e = nothing) = GNNlib.gat_conv(l, g, x, e) @@ -461,7 +461,7 @@ function GATv2Conv(ch::Pair{NTuple{2, Int}, Int}, add_self_loops, dropout) end -(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) (l::GATv2Conv)(g, x, e=nothing) = GNNlib.gatv2_conv(l, g, x, e) @@ -718,7 +718,7 @@ end (l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e) -(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) function Base.show(io::IO, l::NNConv) out, in = size(l.weight) @@ -933,7 +933,7 @@ end (l::CGConv)(g, x, e = nothing) = GNNlib.cg_conv(l, g, x, e) -(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) function Base.show(io::IO, l::CGConv) print(io, "CGConv($(l.ch)") @@ -1054,7 +1054,7 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) end function (l::MEGNetConv)(g::GNNGraph) - x, e = l(g, node_features(g), edge_features(g)) + x, e = l(g, g.ndata.x, g.edata.e) return GNNGraph(g, ndata = x, edata = e) end @@ -1137,7 +1137,7 @@ end (l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e) -(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) +(l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) function Base.show(io::IO, l::GMMConv) (nin, ein), out = l.ch @@ -1538,7 +1538,7 @@ end (l::TransformerConv)(g, x, e = nothing) = GNNlib.transformer_conv(l, g, x, e) function (l::TransformerConv)(g::GNNGraph) - GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) + GNNGraph(g, ndata = l(g, g.ndata.x, g.edata.e)) end function Base.show(io::IO, l::TransformerConv) diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index 0efe3282d..15f29d33e 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -38,7 +38,7 @@ end (l::GlobalPool)(g::GNNGraph, x::AbstractArray) = GNNlib.global_pool(l, g, x) -(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) +(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x)) @doc raw""" GlobalAttentionPool(fgate, ffeat=identity) @@ -96,7 +96,7 @@ GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity) (l::GlobalAttentionPool)(g, x) = GNNlib.global_attention_pool(l, g, x) -(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) +(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x)) """ TopKPool(adj, k, in_channel) @@ -157,4 +157,4 @@ end (l::Set2Set)(g, x) = GNNlib.set2set_pool(l, g, x) -(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, node_features(g))) +(l::Set2Set)(g::GNNGraph) = GNNGraph(g, gdata = l(g, g.ndata.x)) diff --git a/GraphNeuralNetworks/test/layers/pool.jl b/GraphNeuralNetworks/test/layers/pool.jl index fa1475b20..cbb2240f4 100644 --- a/GraphNeuralNetworks/test/layers/pool.jl +++ b/GraphNeuralNetworks/test/layers/pool.jl @@ -80,7 +80,7 @@ end g = batch([rand_graph(10, 40, graph_type = GRAPH_T) for _ in 1:5]) g = GNNGraph(g, ndata = rand(Float32, n_in, g.num_nodes)) l = Set2Set(n_in, n_iters, n_layers) - y = l(g, node_features(g)) + y = l(g, g.ndata.x) @test size(y) == (2 * n_in, g.num_graphs) ## TODO the numerical gradient seems to be 3 times smaller than zygote one