JGNN
Graph Neural Networks (GNNs) are an increasingly popular machine learning paradigm, for example for making predictions based on relational information or for performing inference on small datasets. JGNN is a library that provides cross-platform implementations of this paradigm without the need for dedicated hardware or firmware; it lets you create highly portable models that fit and are trained within a few megabytes of memory. While reading this guidebook, keep in mind that this is not a library for running computationally intensive workloads; it has no GPU support and we do not plan to add any (unless such support becomes integrated in the Java virtual machine). So, while the source code is highly optimized and complex architectures are supported, running them quickly on graphs with many nodes may require compromises in the number of learned parameters or in running time.
This guidebook is organized into four sections that focus on practical use. After this brief introduction and instructions for how to set things up, section 2 gives a taste of what using the library looks like. Then, section 3 describes the library's builder pattern for constructing GNN models. Model construction includes symbolic expression parsing for machine learning operations, which drastically simplifies coding. Parsed expressions follow the Neuralang scripting language for model definitions. Finally, section 4 describes interfaces for training on automatically generated or customized data and for testing. It also takes a deep dive into obtaining raw model predictions and using them in custom training and evaluation schemes.
In addition to the above-described material, JGNN's full programmatic interface is provided as Javadoc, and domain-specific examples reside in the project's GitHub repository.
1. Setup
The simplest way to set up JGNN is to download it as a JAR package from
releases
and add it to your Java project's build path. Those working with Maven
or Gradle can instead add JGNN's latest nightly release as a dependency from its JitPack
distribution. Follow the link below for full instructions.
For example, the fields in the snippet below may be added in a Maven pom.xml file to work with the latest nightly release.
<repositories>
    <repository>
        <id>jitpack.io</id>
        <url>https://jitpack.io</url>
    </repository>
</repositories>
<dependencies>
    <dependency>
        <groupId>com.github.MKLab-ITI</groupId>
        <artifactId>JGNN</artifactId>
        <version>SNAPSHOT</version>
    </dependency>
</dependencies>
2. Quickstart
Here we demonstrate usage of JGNN for node classification. This is a transductive learning task that predicts node labels given a graph's structure, node features, and some already-known labels. Classifying whole graphs is also supported, although it is a harder task to explain and set up. GNN architectures for node classification are typically written as message passing mechanisms; they diffuse node representations across edges, where node neighbors pick up, aggregate (e.g., average), and transform incoming representations to update their own. Alternatives that boast higher expressive power also exist and are supported, but simple architectures may be just as good as or better than complex alternatives at solving practical problems [Krasanakis et al., 2024]. Simple architectures also enjoy reduced resource consumption.
Our demonstration starts by loading the Cora
dataset from those shipped
with the library for out-of-the-box testing. The first time an instance of this dataset is created,
it downloads its raw data from a web resource and stores them in a local downloads/
folder. The data are then loaded into a sparse graph adjacency matrix, a dense node feature matrix,
and a dense node label matrix. Sparse and dense representations are interchangeable in terms of operations,
with the main difference being that sparse matrices are much more efficient when they contain lots of zeros.
In the loaded matrices, each row contains the corresponding node's
neighbors, features, or one-hot encoding of labels. We apply the renormalization trick and
symmetric normalization on the dataset's adjacency matrix using in-place operations for a minimal memory footprint;
the first of the two adds self-loops to all nodes to make GNN computations numerically stable,
while symmetric normalization is required by spectral-based GNNs, such as
the model we implement next.
Dataset dataset = new Cora();
dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
We now incrementally create a trainable model using symbolic expressions that resemble math
notation. The expressions are part of a scripting language, called Neuralang,
that is covered in section 3.3. However, for faster onboarding, stick to
the FastBuilder
class for creating models; this omits some of
the language's features in favor of providing programmatic shortcuts for boilerplate code. Its constructor
accepts two arguments, A
and h0
, respectively holding
the graph's adjacency matrix and node features. These are internally set as constant symbols that
parsed expressions can use. Other constants and input variables can be set too,
but more on this later. After instantiation, use the
model builder methods below to declare the model's dataflow. Some of these methods parse the aforementioned expressions.
- config - Configures hyperparameter values. These can be used in all subsequent function and layer declarations.
- function - Declares a Neuralang function, in this case with inputs A and h.
- layer - Declares a layer that can use built-in and Neuralang functions. In it, the symbols {l} and {l+1} are replaced by a layer counter.
- classify - Adds a softmax layer tailored to classification. This also silently declares an input nodes that represents a list of node indices where the outputs should be computed.
- autosize - Automatically sizes matrix and vector dimensions that were originally denoted with a question mark ?. This method requires an example input; here we provide a list of node identifiers, which we also make dataless (having only the correct dimensions without allocating memory). The method also checks for integrity errors in the declared architecture, such as computational paths that do not lead to an output.
Parsed expressions rely on the matrix and vector Neuralang functions; these inline declarations create learnable parameters with given dimensions and regularization. Access the builder's created model via modelBuilder.getModel().
long numSamples = dataset.samples().getSlice().size();
long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
.config("reg", 0.005)
.config("classes", numClasses)
.config("hidden", 64)
.function("gcnlayer", "(A,h){Adrop = dropout(A, 0.5); return Adrop@(h@matrix(?, hidden, reg))+vector(?);}")
.layer("h{l+1}=relu(gcnlayer(A, h{l}))")
.config("hidden", "classes") // reassigns the output gcnlayer's "hidden" to be equal to the number of "classes"
.layer("h{l+1}=gcnlayer(A, h{l})")
.classify()
.autosize(new EmptyTensor(numSamples));
Training epochs for the created model can be implemented
manually, by passing inputs, obtaining outputs, computing losses, and triggering backpropagation
on an optimizer. As these steps may be complicated, JGNN automates common
training patterns by extending a base ModelTraining
class with training strategies
tailored to different data formats and predictive tasks. Find these subclasses in the
adhoc.train
Javadoc. Instances of model trainers
accept a method chain notation to set their parameters. Parameters usually include training and validation data
(these should be set up first and depend on the model training class) and aspects of the training strategy, like the number of epochs, patience
for early stopping, the employed optimizer, and loss functions. An example is presented below.
Of the data needed for training, the graph adjacency matrix and node features are already declared as constants by the
FastBuilder
constructor, as node classification takes place on the same graph
with fully known node features. Thus, the input features are a column of node identifiers, which the
classify
method above uses to gather
the predictions on the respective nodes. Architecture outputs are softmax approximations of the one-hot
encodings of the respective node labels. The simplest way to handle missing labels for test data without modifying
the example is to leave their one-hot encodings as all zeros.
Additionally, this particular training strategy accepts training and validation data slices, where slices are lists
of integer entries pointing to rows of inputs and outputs (more on these later).
To finish describing the training strategy, the example selects
Adam
optimization with learning rate 0.01, and training
over many epochs with early stopping. A verbose
loss prints the progress of cross-entropy and accuracy on validation data every 10 epochs, where the
first of these two is used for the early stopping criterion.
To run a full training process, pass a strategy to a model.
In a cold start scenario, apply a parameter initializer before training is conducted.
A warm start that resumes training from some previously trained outcome would skip this step.
Selecting an initializer is not part of the training strategy
to signify its model-dependent nature; dense layers should maintain the expected
input variances in their outputs before the first epoch, and therefore the initializer depends
on the type of activation functions.
Slice nodes = dataset.samples().getSlice().shuffle(); // a permutation of node identifiers
Matrix inputFeatures = Tensor.fromRange(nodes.size()).asColumn(); // each node has its identifier as an input (equivalent to: nodes.samplesAsFeatures())
ModelTraining trainer = new SampleClassification()
// training data
.setFeatures(inputFeatures)
.setLabels(dataset.labels())
.setTrainingSamples(nodes.range(0, 0.6))
.setValidationSamples(nodes.range(0.6, 0.8))
// training strategy
.setOptimizer(new Adam(0.01))
.setEpochs(3000)
.setPatience(100)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()).setInterval(10)); // print every 10 epochs
Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer);
Trained models and their generating builders can be saved and loaded. The next snippet demonstrates how raw predictions can be made too. During this process, some matrix manipulation operations are employed to obtain transparent access to parts of input and output data of the dataset.
modelBuilder.save(Paths.get("gcn_cora.jgnn")); // needs a Path as an input
Model loadedModel = ModelBuilder.load(Paths.get("gcn_cora.jgnn")).getModel(); // loading creates a new modelbuilder from which to get the model
Matrix output = loadedModel.predict(Tensor.fromRange(0, nodes.size()).asColumn()).get(0).cast(Matrix.class);
double acc = 0;
for(Long node : nodes.range(0.8, 1)) {
    Matrix nodeLabels = dataset.labels().accessRow(node).asRow();
    Tensor nodeOutput = output.accessRow(node).asRow();
    acc += nodeOutput.argmax()==nodeLabels.argmax() ? 1 : 0;
}
System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size());
3. GNN Builders
We already touched on the subject of model builders in the previous section, where one of them was used to create a model. There exist different kinds of builders that offer different conveniences.
- ModelBuilder - Parses strings of simple Neuralang expressions.
- FastBuilder - Extends the ModelBuilder class with methods that inject boilerplate code for the inputs, outputs, and layers of node classification tasks. Prefer this builder if you want to keep track of the whole model definition in one place within Java code.
- Neuralang - Extends the ModelBuilder class so that it can parse all aspects of the Neuralang language, such as functional declarations of machine learning modules, where parts of function signatures manage configuration hyperparameters. Use this builder to maintain model definitions in one place (e.g., packed in one string variable or in one file) and avoid weaving symbolic expressions into Java code.
3.1. ModelBuilder
This is the base model builder class; it offers a wide breadth of functionalities that other builders extend.
Before looking at how to use it, though, we need to see what JGNN models look like under the hood.
Models are collections of NNOperation
instances, each representing a numerical computation with
specified inputs and outputs of
JGNN's Tensor
type. Tensors will be covered later; for now, it suffices to think of them as
numerical vectors, which are sometimes endowed with matrix dimensions. This guidebook does not list operation classes, as they are rarely used directly and can be found in the Javadoc, namely
nn.inputs,
nn.activations,
and
nn.pooling.
Create models in pure Java like below. The example computes the expression
y=log(2*x+1)
without any trainable parameters.
After defining models, run them with the method Model.predict(Tensor...).
This takes as input one or more comma-separated tensors that match the model's
inputs (in the same order) and computes a list of output tensors. If inputs are dynamically created,
an overloaded version of the same method supports an array list of input tensors,
Model.predict(ArrayList<Tensor>).
The snippet below includes a prediction for an input that consists of one tensor of one element.
Variable x = new Variable();
Constant c1 = new Constant(Tensor.fromDouble(1)); // holds the constant "1"
Constant c2 = new Constant(Tensor.fromDouble(2)); // holds the constant "2"
NNOperation mult = new Multiply()
.addInput(x)
.addInput(c2);
NNOperation add = new Add()
.addInput(mult)
.addInput(c1);
NNOperation y = new Log()
.addInput(add);
Model model = new Model()
.addInput(x)
.addOutput(y);
System.out.println(model.predict(Tensor.fromDouble(2)));
Judging by the fact that several lines of code are needed to declare even simple expressions,
pure Java code for creating full models tends to be cumbersome to read and maintain - hence the need for
builders that construct the models from concise symbolic expressions. Let us recreate the above example
with the ModelBuilder
class.
After instantiating the builder, use a method chain to declare an input variable
with the .var(String)
method, parse an expression with the
.operation(String)
method, and finally declare which symbol holds
outputs with the .out(String)
method.
The first and last of these methods can be called multiple times
to declare several inputs and outputs. Inputs must be single symbols, whereas outputs can be
whole expressions to evaluate.
ModelBuilder modelBuilder = new ModelBuilder()
.var("x")
.operation("y = log(2*x+1)")
.out("y");
System.out.println(modelBuilder.getModel().predict(Tensor.fromDouble(2)));
The operation method parses string expressions that are typically structured
as assignments to symbols; the right-hand side of assignments accepts several operators and functions that
are listed in the next table. Models allow multiple operations too, which are parsed either through multiple
method calls or by separating them with a semicolon ;
within larger string expressions.
All methods need to use previously declared symbols. For example, parsing .out("symbol")
throws an exception if no operation previously assigned to the symbol and it was not declared as an input. For logical
safety, symbols cannot be overwritten or set to updated values outside of Neuralang functions.
Finally, the base model builder
class supports a roundabout declaration of Neuralang functions with expressions like this snippet taken from the Quickstart
section:
.function("gcnlayer", "(A,h){return A@(h@matrix(?, hidden, reg))+vector(?);}")
.
In this, the first method argument is the declared function symbol's name, and the second must have the function's arguments enclosed in
parentheses and its body enclosed in curly brackets. Learn more about Neuralang functions in
section 3.3.
Model definitions have so far been too simple to be employed in practice;
we need trainable parameters, which are created inline with the matrix
and vector
functions. There is also an equivalent Java
method ModelBuilder.param(String, Tensor)
that assigns an initialized Tensor
to a variable name, but its usage is discouraged to keep model definitions simple.
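For illustration, a minimal sketch of this discouraged alternative follows; the parameter name and its initialization are made up, and only param(String, Tensor) is part of the interface described above.
// Discouraged: declare a trainable parameter directly from Java instead of via matrix(...)/vector(...)
Tensor bias = new DenseTensor(64).setToOnes(); // any initialized tensor works; the values here are illustrative
modelBuilder.param("bias", bias);
// the symbol "bias" can now be referenced by subsequently parsed expressions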
Additionally, there may be constants and configuration hyperparameters. Of these, constants reflect
untrainable tensors and are set with ModelBuilder.const(String, Tensor),
whereas configuration hyperparameters are numerical values used by the parser and
set with ModelBuilder.config(String, double), or
ModelBuilder.config(String, String) if the second argument's value
should be copied from another configuration.
Both numbers in the last snippet's symbolic definition are internally parsed into constants.
On the other hand, hyperparameters can be used as arguments to dimension sizes and regularization.
Retrieve previously set hyperparameters through double ModelBuilder.getConfig(String)
or double ModelBuilder.getConfigOrDefault(String, double);
the latter returns a default value instead of throwing an error if the configuration is not found. The usefulness of retrieving
configurations will become apparent later on.
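As a small sketch of the above, assuming a builder whose configurations were set as in the Quickstart:
ModelBuilder builder = new ModelBuilder()
    .config("hidden", 64)
    .config("reg", 0.005);
double hidden = builder.getConfig("hidden"); // 64
double patience = builder.getConfigOrDefault("patience", 100); // falls back to 100, since "patience" was never set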
Next is a table of the operations available in expressions. Standard operator priority rules and parentheses apply. Prefer using configuration hyperparameters when creating matrices and vectors, as these transfer their names to the respective dimensions for error checking - more on this in section 3.4.
Symbol | Type | Description |
---|---|---|
x = expr | Operator | Assigns to variable x the outcome of executing expression expr. This expression does not evaluate to anything. |
x + y | Operator | Element-by-element addition. |
x * y | Operator | Element-by-element multiplication. |
x - y | Operator | Element-by-element subtraction. |
x @ y | Operator | Matrix multiplication. |
x | y | Operator | Row-wise concatenation of x and y. |
x[y] | Operator | Gathers the rows of x with indexes y. Indexes are still tensors, whose elements are cast to integers during this operation. |
transpose(A) | Function | Transposes matrix A. |
log(x) | Function | Applies a logarithm on each element of tensor x. |
exp(x) | Function | Exponentiates each element of tensor x. |
nexp(x) | Function | Exponentiates each non-zero element of tensor x. Typically used for neighbor attention (see below). |
relu(x) | Function | Applies relu on each tensor element. |
tanh(x) | Function | Applies a tanh activation on each tensor element. |
sigmoid(x) | Function | Applies a sigmoid activation on each tensor element. |
dropout(x, rate) | Function | Applies training dropout on tensor x with constant dropout rate hyperparameter rate. |
drop(x, rate) | Function | Shorthand notation for dropout. |
lrelu(x, slope) | Function | Leaky relu on tensor x with constant negative slope hyperparameter slope. |
prelu(x) | Function | Leaky relu on tensor x with learnable negative slope. |
softmax(x, dim) | Function | Applies a softmax reduction on x, where dim is either dim:'row' (default) or dim:'col'. |
sum(x, dim) | Function | Applies a sum reduction on x, where dim is either dim:'row' (default) or dim:'col'. |
mean(x, dim) | Function | Applies a mean reduction on x, where dim is either dim:'row' (default) or dim:'col'. |
L1(x, dim) | Function | Applies an L1 normalization on x across dimension dim, where dim is either dim:'row' (default) or dim:'col'. |
L2(x, dim) | Function | Applies an L2 normalization on x across dimension dim, where dim is either dim:'row' (default) or dim:'col'. |
max(x, dim) | Function | Applies a max reduction on x, where dim is either dim:'row' (default) or dim:'col'. |
min(x, dim) | Function | Applies a min reduction on x, where dim is either dim:'row' (default) or dim:'col'. |
matrix(rows, cols) | Function | Generates a matrix parameter with respective hyperparameter dimensions. |
matrix(rows, cols, reg) | Function | Generates a matrix parameter with respective hyperparameter dimensions and L2 regularization hyperparameter reg. |
mat(rows, cols) | Function | Shorthand notation for matrix. |
mat(rows, cols, reg) | Function | Shorthand notation for matrix. |
vector(len) | Function | Generates a vector with size hyperparameter len. |
vector(len, reg) | Function | Generates a vector with size hyperparameter len and L2 regularization hyperparameter reg. |
vec(len) | Function | Shorthand notation for vector. |
vec(len, reg) | Function | Shorthand notation for vector. |
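As a sketch that combines several of the table's entries (the symbols and dimensions below are illustrative):
ModelBuilder builder = new ModelBuilder()
    .config("features", 7)
    .config("hidden", 64)
    .var("x") // a feature matrix input
    .var("nodes") // a tensor of row indexes
    .operation("h = relu(x@matrix(features, hidden) + vector(hidden))")
    .operation("z = L1(h[nodes], dim: 'row')") // gather rows, then row-wise L1 normalization
    .out("z");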
3.2. FastBuilder
The FastBuilder
class for building GNN architectures extends the generic
ModelBuilder
with common graph neural network operations. The main difference
is that it has two constructor arguments, namely a square matrix A
that
is typically a normalization of the (sparse) adjacency matrix,
and a feature matrix h0
.
This builder further supports the notation symbol{l}
,
where the layer counter replaces the symbol part {l}
with 0 for the first layer,
1 for the second, and so on. Prefer the notation h{l}
to refer to the node representation
matrix of the current layer; for the first layer, this is parsed as h0
, which is the constant
set by the constructor.
FastBuilder
instances also offer a FastBuilder.layer(String)
chain method to compute neural layer outputs. This is a variation of operation parsing, where the
symbol part {l+1}
is substituted with the next layer's counter,
the expression is parsed, and the layer counter is incremented by one. Example usage is shown below, where
symbolic expressions read similarly to what you would find in a paper.
FastBuilder modelBuilder = new FastBuilder(adjacency, features) // sets A, h0
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden, reg))+vector(hidden))") // parses h1 = relu(A@(h0 @ ...
.layer("h{l+1}=A@(h{l}@matrix(hidden, classes, reg))+vector(classes)"); // parses h2 = A@(h1@ ...
Before continuing, let us give some context for the above implementation.
The base operation of message passing GNNs, which are often used for node classification,
is to propagate node representations to neighbors via graph edges. Then, neighbors aggregate
the received representation, where aggregation typically consists of a weighted average per
the normalized adjacency matrix's edge weights. For symmetric normalization, this
weighted sum is compatible with spectral graph signal processing. The operation to perform
one propagation can be written as .layer("h{l+1}=A @ h{l}")
.
The propagation's outcome is typically transformed further by passing through a dense
layer.
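For instance, the two steps can be written as separate layers; this is a sketch that assumes hidden and reg configurations have already been set and that the incoming representation already has hidden columns.
modelBuilder
    .layer("h{l+1}=A @ h{l}") // propagate over edges
    .layer("h{l+1}=relu(h{l}@matrix(hidden, hidden, reg)+vector(hidden))"); // dense transformation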
Several improvements of this scheme have been proposed. However, they tend to incur marginal accuracy gains at the cost of more computation. Stay away from complex architectures when learning from large graphs, as JGNN is designed to be lightweight and does not (and is not planned to) leverage GPUs. The library still supports the improvements listed below, since they can be used when running time is not a pressing issue (e.g., for transfer or stream learning that applies updates for a few epochs), or to analyse smaller graphs:
- Edge dropout - Apply dropout on the adjacency matrix at each layer with .layer("h{l+1}=dropout(A,0.5) @ h{l}"). Usage of this operation seems innocuous, but it disables several caching optimizations that occur under the hood.
- Heterogeneity - Some recent approaches explicitly account for high-pass frequency diffusion by also considering the graph Laplacian. Insert this into the architecture as a normal constant like so: .constant("L", adjacency.negative().cast(Matrix.class).setMainDiagonal(1))
- Edge attention - Performs the dot product of each edge's node representations to create new edge weights per the formula A ⊙ (h hᵀ), where A is a sparse adjacency matrix, ⊙ denotes the Hadamard product (element-by-element multiplication), and h is a dense matrix whose rows hold respective node representations. JGNN efficiently implements this operation with the Neuralang function att(A, h). For example, weighted adjacency matrices for each layer of gated attention networks are implemented as .operation("A{l} = L1(nexp(att(A, h{l})))").
- General message passing - JGNN also supports the fully generalized message passing scheme between node neighbors for more complex relational analysis [Velickovic, 2022]. In this generalization, each edge is responsible for appropriately transforming and propagating representations to node neighbors; create message matrices whose rows correspond to edges and whose columns correspond to edge features by gathering the features of the edge source and destination nodes. Programmatically, obtain edge source indexes src=from(A) and destination indexes dst=to(A), where A is the adjacency matrix. Then use the horizontal concatenation operation | to concatenate node features. One may also concatenate edge features. Given a constructed message, define any kind of ad hoc mechanism or neural processing of messages with traditional matrix operations (take care to define correct matrix sizes for dense transformations, e.g., twice the number of columns as h{l} in the snippet below). For any kind of LayeredBuilder, don't forget that message{l} within operations is needed to obtain a message from the representations h{l} that is not accidentally shared with future layers. Receiver mechanisms need to perform some kind of reduction on messages. JGNN implements summation reduction, given that this has the same theoretical expressive power as maximum-based reduction but is easier to backpropagate through. Perform this like in the snippet below. The sum is weighted per the values of the adjacency matrix A, so perform adjacency matrix normalization only if you want such weighting to occur.
modelBuilder
    .operation("src=from(A)")
    .operation("dst=to(A)")
    .operation("message{l}=h{l}[src] | h{l}[dst]") // has two times the number of h{l}'s features
    .operation("transformed_message{l}=...") // fill in the transformation
    .operation("received{l}=reduce(transformed_message{l}, A)");
So far, we discussed the propagation mechanisms of
GNNs, which consider the features of all nodes. However,
in node classification settings, training data labels
are typically available only for certain nodes, despite
all node features being required to make any prediction.
We thus need a mechanism to retrieve the predictions of the top
layer for those nodes, for example before applying a softmax.
This is achieved in the snippet below, which uses the gather
operation through brackets. Alternatively, chain the
FastBuilder.classify()
method, which injects this exact code.
modelBuilder
.var("nodes")
.layer("h{l} = softmax(h{l})")
.operation("output = h{l}[nodes]")
.out("output");
So far we tackled only equivariant GNNs, whose outputs follow any node permutations applied on their inputs. In simple terms, if the order of node identifiers is modified (both in the graph adjacency matrix and in node feature matrices), the order of output rows is modified in the same way. Most operations described so far are equivariant (those that are not say so explicitly), so their composition is also equivariant. However, there are cases where created GNNs should be invariant, which means that they should create predictions that remain the same despite any input permutations. Invariance is the property to impose when classifying graphs, where one prediction should be made for the whole graph.
Imposing invariance is simple enough; take an equivariant
architecture and then apply an invariant operation on top.
You may want to perform further transformations (e.g., some
dense layers) afterwards, but the general idea remains
the same. JGNN offers two types of invariant operations, also
known as pooling:
reductions and sort-based pooling. Of these, reductions
are straightforward to implement
by taking a dimensionality reduction mechanism (min,
max, sum, mean) and
applying it column-wise on the output feature matrix.
Recall that each row holds the features of a different node,
so the reduction yields a one-dimensional vector that,
for each feature dimension, aggregates feature values across all nodes.
Reduction-based pooling performs a symmetric operation and therefore fails to distinguish between the structural positioning of the nodes being pooled. One computationally light alternative, which JGNN implements, is sorting nodes based on learned features before concatenating their features into one vector for each graph. This process is further simplified by keeping only a reduced number of top nodes whose features are concatenated, where the order is determined by an arbitrarily selected feature (in our implementation: the last one, with the previous feature used to break ties, and so on). The idea is that the selected feature determines important nodes whose information can be adopted by others. To implement this scheme, JGNN provides independent operations to sort nodes, gather node latent representations, and reshape matrices into row or column tensors with learnable transformations to class outputs. These components are demonstrated in the following code snippet:
long reduced = 5; // input graphs need to have at least that many nodes
long hidden = 8; // many latent dims reduce speed without GPU parallelization
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 1)
.config("classes", 2)
.config("reduced", reduced)
.config("hidden", hidden)
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden))+vector(hidden))")
.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden))+vector(hidden))")
.concat(2) // concatenates the outputs of the last 2 layers
.config("hiddenReduced", hidden*2*reduced) // 2* due to concatenation
.operation("z{l}=sort(h{l}, reduced)") // z{l} are node indexes
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, dim: 'row')")
.out("h{l}");
3.3. Neuralang
Neuralang scripts consist of functions that declare machine learning
components. Use a Rust syntax highlighter, which covers all of the language's keywords.
Functions correspond to machine learning modules and call each other.
At their end lies a return
statement, which expresses their
outcome. All arguments are passed by value, i.e., any assignments are
performed on fresh variable instances.
Before explaining how to use the Neuralang
model builder,
we present and analyse code that supports a fully functional architecture.
First, look at the classify
function, which for completeness is presented below.
This takes two tensor inputs: nodes
that correspond to identifiers
indicating which nodes should be classified (the output has a number of rows equal to the
number of identifiers), and a node feature matrix h
.
It then computes and returns a softmax for the features of the specified nodes.
Aside from main inputs, the function's
signature also has several configuration values, whose defaults
are indicated by a colon :
(only configurations have defaults, and vice versa).
The same notation is used to
set/overwrite configurations when calling functions, as we do for softmax
to apply it row-wise. Think of configurations as keyword
arguments of typical programming languages, with the difference that
they control hyperparameters, like dimension sizes or regularization.
Write exact values for configurations, as no arithmetic
is currently performed on them. For example, the configuration
patience:2*50
creates an error.
fn classify(nodes, h, epochs: !3000, patience: !100, lr: !0.01) {
return softmax(h[nodes], dim: "row");
}
Exclamation marks !
before numbers broadcast values
to all subsequent function calls that have configurations with the same
name. The broadcasted defaults overwrite already existing defaults of configurations with the same
name anywhere in the code. All defaults are replaced by values explicitly set when calling functions.
For example, take advantage of this prioritization to force output layer dimensions to match your data. Importantly,
broadcasted values are stored within JGNN's Neuralang
model
builder too; this is useful for Java integration, for example to retrieve training hyperparameters
from the model. To sum up, configuration values have the following priority, from strongest to weakest:
1. Arguments set during the function's call.
2. Broadcasted configurations (the last broadcasted value, including configurations set by Java).
3. Function signature defaults.
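The following Neuralang sketch illustrates these priorities; the function names are made up for demonstration. In top, the inner dense(h) call uses hidden=64 because of the !64 broadcast, which overrides the signature default 32, whereas the outer call uses hidden=16 because that value is set explicitly at the call site.
fn dense(h, hidden: 32, reg: 0.005) {
    return h@matrix(?, hidden, reg) + vector(hidden);
}
fn top(h, hidden: !64) {
    return dense(dense(h), hidden: 16);
}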
Next, let us look at some functions creating the main body of an architecture.
First, gcnlayer
accepts
two parameters: an adjacency matrix A
and input feature matrix h
.
The configuration hidden: 64
in the function's signature
specifies the default number of hidden units,
whereas reg: 0.005
is the L2 regularization applied
during machine learning. The question mark ?
in matrix definitions lets the autosize feature of JGNN determine
dimension sizes based on a test run - if possible.
Finally, the function returns the activated output of a
GCN layer. Similarly, look at the gcn
function. This declares
the GCN architecture and has as configuration the number of output classes.
The function basically consists of two gcnlayer
layers,
where the second's hidden units are set to the value of output classes.
The number of classes is unknown when writing the model, and thus is externally declared
with the extern
keyword to signify that this value should always be provided
by the Java side of the implementation.
fn gcnlayer(A, h, hidden: 64, reg: 0.005) {
return A@h@matrix(?, hidden, reg) + vector(hidden);
}
fn gcn(A, h, classes: extern) {
h = gcnlayer(A, h);
h = dropout(relu(h), 0.5);
return gcnlayer(A, h, hidden: classes);
}
We now move to parsing our declarations with the Neuralang
model builder and using them to create an architecture. To this end, save your code
to a file and obtain it as a path via Path architecture = Paths.get("filename.nn");
,
or avoid external files by inlining the definition within Java code through
a multiline string per String architecture = """ ... """;
.
Below, this string is parsed within a functional programming chain, where
each method call returns the modelBuilder instance to continue calling more methods.
For the model builder, the following snippet sets remaining hyperparameters
and overwrites the default value
for "hidden"
. It also specifies
that certain variables are constants, namely the adjacency matrix A
and node
representation h
, as well as that node identifiers are a variable that serves
as the architecture's input. There could be multiple inputs, so the distinction between what
is a constant and what is a variable depends mostly on which quantities change
during training and is managed only by the Java side of the code.
In the case of node classification, both the adjacency matrix and
node features remain constant, as we work on one graph. Finally, the definition
sets a Neuralang expression as the architecture's output
by calling the .out(String)
method,
and applies the .autosize(Tensor...)
method to infer hyperparameter
values denoted with ?
from an example input.
For faster completion of the model, provide a dataless list of node identifiers as input,
like below.
long numSamples = dataset.samples().getSlice().size();
long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new Neuralang()
.parse(architecture)
.constant("A", dataset.graph())
.constant("h", dataset.features())
.var("nodes")
.config("classes", numClasses)
.config("hidden", numClasses+2) // custom number of hidden dimensions
.out("classify(nodes, gcn(A,h))") // expression to parse into a value
.autosize(new EmptyTensor(numSamples));
System.out.println("Preferred learning rate: "+modelBuilder.getConfig("lr"));
3.4. Debugging
JGNN offers high-level tools for debugging
architectures. Here we cover what diagnostics to run, and how to make
sense of error messages to fix erroneous
architectures. We already mentioned that model builder
symbols should be assigned to before
subsequent use. For example, consider a FastBuilder
that
tries to parse the expression .layer("h{l+1}=relu(hl@matrix(features, 32, reg)+vector(32))")
,
where hl
is a typographical error of
h{l}
. In this case, an exception is thrown:
Exception in thread "main" java.lang.RuntimeException: Symbol hl not defined.
Internally, models are effectively directed acyclic graphs (DAGs)
that model builders create. DAGs should not be confused with the graphs
that GNN architectures analyse; they are just an organization of data flow
between NNComponent instances. During parsing, builders
may create temporary variables, which start with
the _tmp
prefix and are followed by
a number. Temporary variables often link
components to others that use them.
The easiest way to understand execution DAGs is
to look at them. The library provides two tools
for this purpose: a .print()
method
that prints built functional flows in the system
console, and a .getExecutionGraphDot()
method that returns a string holding the execution graph in
.dot format for visualization with
tools like GraphViz.
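For example, assuming a builder like the ones above, something along the following lines prints the DAG and exports it for rendering; the output file name is arbitrary.
modelBuilder.print(); // writes the built functional flow to the console
java.nio.file.Files.writeString(
    java.nio.file.Paths.get("model.dot"),
    modelBuilder.getExecutionGraphDot()); // render with GraphViz, e.g., dot -Tpng model.dot -o model.png
// (handle or declare the IOException of Files.writeString as needed)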
Another error-checking procedure consists of
an assertion that all model operations eventually affect
at least one output. Computational branches that lead nowhere mess up the
DAG traversal during backpropagation and should be checked with the
method .assertBackwardValidity()
.
The latter throws an exception if an invalid model is found.
Performing this assertion early on in
model building will likely throw exceptions that
are not logical errors, given that independent
outputs may be combined later. Backward validity errors
look like the following example. This
indicates that the component
_tmp102
does not lead to an output,
and we should look at the execution tree to
understand its role.
Exception in thread "main" java.lang.RuntimeException: The component class mklab.JGNN.nn.operations.Multiply: _tmp102 = null does not lead to an output
at mklab.JGNN.nn.ModelBuilder.assertBackwardValidity(ModelBuilder.java:504)
at nodeClassification.APPNP.main(APPNP.java:45)
Some tensor or matrix methods do not
correspond to numerical operations but
are only responsible for naming dimensions.
Functionally, such methods are largely decorative,
but they can improve debugging by throwing errors for
incompatible non-null names. For example,
adding two matrices with different dimension
names will result in an error. Likewise, the
inner dimension names during matrix
multiplication should agree.
Arithmetic operations, including
matrix multiplication and copying,
automatically infer dimension names in the
result to ensure that only compatible data types
are compared. Dimension name changes do not backtrack to the data they were derived from, even for see-through
data types, such as the outcome of
asTransposed().
Matrices effectively have three
dimension names: for their rows, columns,
and inner data as long as they are treated
as tensors.
Operation | Comments |
---|---|
Tensor setDimensionName(String name) | Names tensor dimensions (of the 1D space tensors lie in). |
Tensor setRowName(String rowName) | Names what kind of information matrix rows hold (e.g., "samples"). Defined only for matrices. |
Tensor setColName(String colName) | Names what kind of information matrix columns hold (e.g., "features"). Defined only for matrices. |
Tensor setDimensionName(String rowName, String colName) | A shorthand for calling setRowName(rowName).setColName(colName). Defined only for matrices. |
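As a sketch of how naming helps, assuming matrices created as in section 4.1 (the names below are illustrative):
Matrix adjacency = new SparseMatrix(numNodes, numNodes);
adjacency.setDimensionName("nodes", "nodes"); // rows and columns both index nodes
Matrix features = new SparseMatrix(numNodes, numFeatures);
features.setDimensionName("nodes", "features");
// adjacency.matmul(features) now checks that the inner dimension names ("nodes") agree
// and names the result's dimensions ("nodes", "features")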
There are two main mechanisms for identifying logical errors within architectures: a) mismatched dimension size, and b) mismatched dimension names. Of the two, dimension sizes are easier to comprehend since they just mean that operations are mathematically invalid. On the other hand, dimension names need to be determined for starting data, such as model inputs and parameters, and are automatically inferred from operations on such primitives. For in-line declaration of parameters in operations or layers, dimension names are copied from any hyperparameters. Therefore, for easier debugging, prefer using functional expressions that declare hyperparameters, like below.
new ModelBuilder()
.config("features", 7)
.config("hidden", 64)
.var("x")
.operation("h = x@matrix(features, hidden)");
Both mismatched dimensions and mismatched dimension names throw runtime exceptions. The beginning of their error console traces should start with something like this:
java.lang.IllegalArgumentException: Mismatched matrix sizes between SparseMatrix (3327,32) 52523/106464 entries and DenseMatrix (64, classes 6)
During the forward pass of class mklab.JGNN.nn.operations.MatMul: _tmp4 = null with the following inputs:
class mklab.JGNN.nn.activations.Relu: h1 = SparseMatrix (3327,32) 52523/106464 entries
class mklab.JGNN.nn.inputs.Parameter: _tmp5 = DenseMatrix (64, classes 6)
java.lang.IllegalArgumentException: Mismatched matrix sizes between SparseMatrix (3327,32) 52523/106464 entries and DenseMatrix (64, classes 6)
at mklab.JGNN.core.Matrix.matmul(Matrix.java:258)
at mklab.JGNN.nn.operations.MatMul.forward(MatMul.java:21)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:180)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
...
This particular stack trace tells us that the architecture
encounters mismatched matrix sizes when trying
to multiply a 3327x32 SparseMatrix with a 64x6
dense matrix. Understanding the exact error is
easy—the inner dimensions of matrix
multiplication do not agree. However, we need to
find the error within our architecture to fix
it. To do this, the error message states
that the error occurs
during the forward pass of class mklab.JGNN.nn.operations.MatMul: _tmp4 = null.
This tells us that the problem
occurs when trying to calculate
_tmp4
, which is currently assigned
a null
tensor as value. Some more
information is available to see what the
operation's inputs are like—in this case, they
coincide with the multiplication's inputs, but
this will not always be the case.
The important point is to go back to the
execution tree and see during which exact
operation this variable is defined. There, we
will undoubtedly find that some dimension had 64
instead of 32 elements or conversely.
In addition to all other debugging mechanisms,
JGNN presents a way to show when forward and
backward operations of specific code components
are executed and with what kinds of arguments.
This can be particularly useful when testing new
components in real (complex) architectures.
The practice consists of calling a
monitor(...)
function within
operations. This does not affect what
expressions do; it only prints the forward and backward
operations of the monitored
components. For example, the next snippet
monitors the outcome of matrix multiplication:
builder.operation("h = relu(monitor(x@matrix(features, 64)) + vector(64))")
4. Training
Here we describe how to train a JGNN model created with the previous section's builders. Broadly, we need to load some reference data and employ an optimization scheme to adjust trainable parameter values based on the differences between desired and current outputs. To this end, we start by describing generic patterns for creating graph and node feature data, and then move to specific data organizations for the tasks of node classification and graph classification. These tasks have helper classes that implement common training schemas (reach out with requests for helper classes for other kinds of predictive tasks in the project's GitHub issues).
4.1. Create data
JGNN contains dataset classes that automatically download and load
datasets for out-of-the-box experimentation. These datasets can be found
in the
adhoc.datasets Javadoc, and we already covered their usage patterns.
In practice, though, you will want to
use your own data. In the simplest case, both the number of nodes or data samples and
the number of feature dimensions are known beforehand. If so, create
a feature matrix with the following code, which uses a sparse representation
and the minimum memory necessary to construct it. If
features are dense (do not have a lot of zeros),
consider using the DenseMatrix
class
instead of initializing a sparse matrix; a dense variant is sketched after the snippet.
Matrix features = new SparseMatrix(numNodes, numFeatures);
for(long nodeId=0; nodeId<numNodes; nodeId++)
    for(long featureId=0; featureId<numFeatures; featureId++)
        features.put(nodeId, featureId, 1);
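If features are mostly non-zero, a dense variant is sketched below; this assumes that DenseMatrix exposes a (rows, columns) constructor analogous to SparseMatrix's.
Matrix features = new DenseMatrix(numNodes, numFeatures);
features.put(0, 0, 1); // element setters work the same way as for sparse matrices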
Sometimes, it is easier to read node or sample features line-by-line, for instance, when reading a .csv file. In this case, store each line as a separate tensor. Convert a list of tensors representing row vectors into a feature matrix like in the example below.
ArrayList<Tensor> rows = new ArrayList<Tensor>();
try(BufferedReader reader = new BufferedReader(new FileReader(file))){
    String line = reader.readLine();
    while (line != null) {
        String[] cols = line.split(",");
        Tensor features = new SparseTensor(cols.length);
        for(int col=0; col<cols.length; col++)
            features.put(col, Double.parseDouble(cols[col]));
        rows.add(features);
        line = reader.readLine();
    }
}
Matrix features = new WrapRows(rows).toSparse();
Creating adjacency matrices is similar to creating
preallocated feature matrices. When in doubt, use the sparse
format for adjacency matrices, as the allocated memory of dense
counterparts scales quadratically with the number of nodes. Note that many GNNs
consider bidirectional (i.e., non-directed) edges, in which case
both directions should be added to the adjacency. Use the following snippet as a
template. Recall that JGNN follows a function chain notation, so each modification
returns the matrix
instance.
Don't forget to normalize or apply the renormalization trick (self-edges) on matrices if these
are needed by your architecture, for instance by calling
adjacency.setMainDiagonal(1).setToSymmetricNormalization();
after matrix creation.
Matrix adjacency = new SparseMatrix(numNodes, numNodes);
for(Entry<Long, Long> edge : edges)
    adjacency
        .put(edge.getKey(), edge.getValue(), 1)
        .put(edge.getValue(), edge.getKey(), 1);
All tensor operations can be viewed in the
core.tensor
and core.matrix
Javadoc. The Matrix
class extends the concept
of tensors with additional operations, like transposition, matrix multiplication,
and row and column access. Under the
hood, matrices linearly store elements and use
computations to transform the (row, col)
position of their elements to respective
positions. The outcome of some methods inherited
from tensors may need to be typecast back into a
matrix (e.g., for all in-place operations).
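For instance, in-place operations on a matrix return a Tensor, so cast the result back when a Matrix is needed; a small sketch using methods mentioned elsewhere in this guidebook and the adjacency matrix created above:
Matrix normalizedAdjacency = adjacency
    .setToNormalized() // inherited from Tensor, so the chain now holds a Tensor
    .cast(Matrix.class); // cast back to continue using matrix operations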
Operations can be split into arithmetics that combine the values
of two tensors to create a new one (e.g., Tensor add(Tensor)
),
in-place arithmetics that alter a tensor without creating
a new one (e.g., Tensor selfAdd(Tensor)
),
summary statistics that output simple numeric values (e.g., double Tensor.sum()
),
and element getters and setters.
In-place arithmetics follow the same naming
conventions as base arithmetics, but their method names begin with a "self"
prefix for pairwise operations and a "setTo" prefix
for unary operations. Since they do not allocate new memory,
prefer them for intermediate calculation steps.
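A pairwise in-place addition then looks like the following sketch; the tensors are illustrative.
Tensor sum = new DenseTensor(10).setToOnes();
Tensor other = new DenseTensor(10).setToOnes();
sum.selfAdd(other); // modifies sum in place; no new tensor is allocated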
For example, the following code can be
used for creating and normalizing a tensor of
ones without using any additional memory.
Tensor normalized = new DenseTensor(10)
.setToOnes()
.setToNormalized();
Initialize a dense or sparse tensor (both of which represent one-dimensional vectors) with its number
of elements. If there are many zeros expected,
prefer using a sparse tensor. For example, one-hot encodings for classification
problems can be generated with the following
code. This creates a dense tensor with
numClasses
elements and puts at
element classId
the value 1:
int classId = 1;
int numClasses = 5;
Tensor oneHotEncoding = new mklab.JGNN.tensor.DenseTensor(numClasses).set(classId, 1); // creates the tensor [0,1,0,0,0]
The above snippets all make use of numerical node identifiers. To
manage these, JGNN provides an IdConverter
class;
convert hashable objects (typically strings) to identifiers by calling
IdConverter.getOrCreateId(object)
. Also use
converters to one-hot encode class labels. To search only for previously
registered identifiers, use IdConverter.get(object)
.
For example, construct a label matrix with the following snippet.
In this, nodeLabels
is a dictionary
from node identifiers to node labels that is being converted to a sparse matrix.
IdConverter nodeIds = new IdConverter();
IdConverter classIds = new IdConverter();
for(Entry<String, String> entry : nodeLabels.entrySet()) {
    nodeIds.getOrCreateId(entry.getKey());
    classIds.getOrCreateId(entry.getValue());
}
Matrix labels = new SparseMatrix(nodeIds.size(), classIds.size());
for(Entry<String, String> entry : nodeLabels.entrySet())
    labels.put(nodeIds.get(entry.getKey()), classIds.get(entry.getValue()), 1);
Reverse-search the converter to obtain the original object
behind a prediction with IdConverter.get. The following example
accesses one row of a label matrix, performs an argmax operation to find the position of the
maximum element, and reconstructs the label for the corresponding row with this reverse search.
long nodeId = nodeIds.get("nodeName");
Tensor prediction = labels.accessRow(nodeId);
long predictedClassId = prediction.argmax();
System.out.println(classIds.get(predictedClassId));
4.2. Node classification
Node classification models can be backpropagated by considering a list of node indices and desired predictions for those nodes. We first show an automation of the training process that controls it in a predictable manner.
This section is under construction.
Slice nodes = dataset.samples().getSlice().shuffle(100); // or nodes = new Slice(0, numNodes).shuffle(100);
Model model = modelBuilder
.getModel()
.init(new XavierNormal())
.train(trainer,
nodes.samplesAsFeatures(),
dataset.labels(),
nodes.range(0, trainSplit),
nodes.range(trainSplit, validationSplit));
4.3. Graph classification
Most neural network architectures are designed with the idea
of learning to classify nodes or samples. However, GNNs also
provide the capability to classify entire graphs based on
their structure. To define architectures for graph classification, we use
the generic LayeredBuilder
class. The main
difference compared to traditional neural networks is
that architecture inputs do not all exhibit the same
size (e.g., some graphs may have more nodes than others)
and therefore cannot be organized into tensors of common
dimensions. Instead, assume that training data are stored in the
following lists:
ArrayList<Matrix> adjacencyMatrices = new ArrayList<Matrix>();
ArrayList<Matrix> nodeFeatures = new ArrayList<Matrix>();
ArrayList<Tensor> graphLabels = new ArrayList<Tensor>();
The LayeredBuilder
class introduces the
input variable h0
for sample
features. We can use it to pass node features to the
architecture, so we only need to add a second input
storing the (sparse) adjacency matrix:
.var("A")
We can then proceed to define a GNN architecture, for instance as explained in previous tutorials. This time, though, we aim to classify entire graphs rather than individual nodes. For this reason, we need to pool top layer node representations, for instance by averaging them across all nodes:
.layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")
Finally, we need to set up the top layer as the built
model's output per .out("h{l}").
An example architecture following these principles is shown below:
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", nodeLabelIds.size())
.config("classes", graphLabelIds.size())
.config("hidden", 16)
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden)))")
.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, classes)))")
.layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")
.out("h{l}");
For the time being, training architectures like the above on prepared data requires manually calling the backpropagation for each epoch and each graph in the training batch. To do this, first retrieve the model and initialize its parameters:
Model model = builder.getModel()
.init(new XavierNormal());
Next, define a loss function and set up a batch
optimization strategy wrapping any base optimizer and
accumulating parameter updates until
BatchOptimizer.updateAll()
is called later
on:
Loss loss = new CategoricalCrossEntropy();
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Finally, training can be conducted by iterating through
epochs and training samples and appropriately calling
the Model.train
for combinations of node
features and graph adjacency matrix inputs and graph
label outputs. At the end of each batch (e.g., each
epoch), don't forget to call the
optimizer.updateAll()
method to apply the
accumulated gradients. This process can be realized with
the following code:
for(int epoch=0; epoch<300; epoch++) {
    for(int graphId=0; graphId<graphLabels.size(); graphId++) {
        Matrix adjacency = adjacencyMatrices.get(graphId);
        Matrix features = nodeFeatures.get(graphId);
        Tensor label = graphLabels.get(graphId);
        model.train(loss, optimizer,
            Arrays.asList(features, adjacency),
            Arrays.asList(label));
    }
    optimizer.updateAll();
}
To speed up graph classification, use JGNN's parallelization capabilities to calculate gradients across multiple threads. Parallelization for node classification holds little meaning, as the same propagation mechanism needs to be run on the same graph in parallel. However, this process yields substantial speedup for the graph classification problem. Parallelization can use JGNN's thread pooling to compute gradients, wait for the conclusion of submitted tasks, and then apply the accumulated gradient updates. This is achieved through a batch optimizer that accumulates gradients, as in the following example:
for(int epoch=0; epoch<500; epoch++) {
    // submit one gradient computation task per training graph
    for(int graphId=0; graphId<dtrain.adjucency.size(); graphId++) {
        int graphIdentifier = graphId;
        ThreadPool.getInstance().submit(new Runnable() {
            @Override
            public void run() {
                Matrix adjacency = dtrain.adjucency.get(graphIdentifier);
                Matrix features = dtrain.features.get(graphIdentifier);
                Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
                model.train(loss, optimizer,
                    Arrays.asList(features, adjacency),
                    Arrays.asList(graphLabel));
            }
        });
    }
    ThreadPool.getInstance().waitForConclusion(); // waits for all gradients to finish calculating
    optimizer.updateAll();
    // evaluate on the test data
    double acc = 0.0;
    for(int graphId=0; graphId<dtest.adjucency.size(); graphId++) {
        Matrix adjacency = dtest.adjucency.get(graphId);
        Matrix features = dtest.features.get(graphId);
        Tensor graphLabel = dtest.labels.get(graphId);
        if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
            acc += 1;
    }
    System.out.println("iter = " + epoch + " " + acc/dtest.adjucency.size());
}