def readout_features(self, graphTensor: tfgnn.GraphTensor, feature_name: str) -> tf.Tensor:
    """Read out one feature value per graph component from its root node.

    Each component names the node set holding its root node in the context
    feature ``"root_node_type"``. The root is taken to be the first node of
    that node set within the component, following the same convention as
    ``tfgnn.gather_first_node``.

    TF-GNN node set names must be static Python strings. This method therefore
    loops (in Python, at trace time) over every node set that carries
    ``feature_name``. It reads the first node's value of each component from
    that set, then uses ``tf.where`` to keep it only for the components whose
    root type matches that set.

    Args:
      graphTensor: The input graph. A batched GraphTensor (rank > 0) is first
        merged into a scalar GraphTensor whose components are the batch
        elements. This assumes each input graph has exactly one component —
        TODO confirm.
      feature_name: Name of the node feature to read, e.g.
        ``tfgnn.HIDDEN_STATE``. Every node set that has this feature must
        store it as a dense tensor with the same dtype and per-node shape.

    Returns:
      A tensor of shape ``[num_components, *feature_shape]``. A component gets
      zeros in two cases: its root type matches no node set that has
      ``feature_name``, or the matching node set has no nodes in that
      component.

    Raises:
      ValueError: If no node set has a feature named ``feature_name``.
    """
    if graphTensor.rank > 0:
        # gather-by-component logic below needs a scalar GraphTensor.
        graphTensor = graphTensor.merge_batch_to_components()

    # Root node-set name per component, flattened from [num_components, 1]
    # (as the context feature is stored) to [num_components].
    root_node_type = tf.reshape(graphTensor.context["root_node_type"], [-1])

    readout = None
    for node_set_name in sorted(graphTensor.node_sets):
        node_set = graphTensor.node_sets[node_set_name]
        if feature_name not in node_set.features:
            continue

        values = node_set[feature_name]  # [num_nodes, *feature_shape]
        sizes = node_set.sizes  # nodes of this set per component
        starts = tf.math.cumsum(sizes, exclusive=True)
        num_nodes = tf.shape(values, out_type=sizes.dtype)[0]
        # Two failure cases need guarding. A component with no node in this
        # set would otherwise read another component's node (or index past
        # the end). Also, tfgnn.gather_first_node asserts that every component
        # is non-empty. So append one zero row and point empty components at it.
        padded = tf.pad(values, [[0, 1]] + [[0, 0]] * (values.shape.rank - 1))
        safe_index = tf.where(sizes > 0, starts, num_nodes)
        first_values = tf.gather(padded, safe_index)  # [num_components, ...]

        if readout is None:
            readout = tf.zeros_like(first_values)

        # Keep this node set's value only where it is the component's root type.
        is_root_set = tf.equal(root_node_type, node_set_name)
        for _ in range(first_values.shape.rank - 1):
            is_root_set = tf.expand_dims(is_root_set, -1)
        readout = tf.where(is_root_set, first_values, readout)

    if readout is None:
        raise ValueError(
            f"No node set has a feature named {feature_name!r}; "
            f"node sets: {sorted(graphTensor.node_sets)}"
        )
    return readout
It is very inconvenient to get the embeddings of the first nodes of multiple node sets at once.