JGNN
Resource efficient machine learning and graph neural networks in native Java.
Graph Neural Networks (GNNs) are increasingly popular, for example for making predictions based on relational information or for performing inference on small datasets. JGNN is a library that provides cross-platform implementations of this paradigm, as well as of traditional neural networks, without the need for dedicated hardware or firmware; create highly portable models that fit and are trained within a few megabytes of memory.
Keep in mind that this is not a library for computationally intensive architectures; it supports multiple CPU cores and contains highly optimized code, but has no GPU support and we do not plan to add any (unless such support gets integrated into the Java virtual machine). So, while complex architectures are supported and scale to graphs with many nodes, running them fast requires compromises in the number of learned parameters or in running time.
1. Setup
The simplest way to set up JGNN is to download it as a JAR package from
releases
and add it to your Java project's build path. Those working with Maven
or Gradle can instead add JGNN's latest nightly release as a dependency from its JitPack
distribution. Follow the link below for full instructions.
For example, the fields in the snippet below may be added to a Maven pom.xml file to work with the latest nightly release.
<repositories>
    <repository>
        <id>jitpack.io</id>
        <url>https://jitpack.io</url>
    </repository>
</repositories>
<dependencies>
    <dependency>
        <groupId>com.github.MKLab-ITI</groupId>
        <artifactId>JGNN</artifactId>
        <version>SNAPSHOT</version>
    </dependency>
</dependencies>
2. Quickstart
Here we demonstrate JGNN for node classification. This is a transductive learning task that fills in missing node labels given a graph's structure, node features, and some already known labels. Classifying entire graphs is also supported, although its details are harder to cover here.
GNN architectures are typically written as message passing mechanisms; they diffuse node representations across edges, where node neighbors pick up, aggregate (e.g., average), and transform incoming representations to update theirs. Alternatives that boast higher expressive power also exist and are supported, but simple architectures may be just as good or better than complex alternatives in solving practical problems [Krasanakis et al., 2024]. Simple architectures also enjoy reduced resource consumption.
Start with the following commands that load the Cora
dataset from those shipped
with the library for out-of-the-box testing. The first time an instance of this dataset is created,
it downloads raw data from a web resource and stores them in a local downloads/
folder. The data are then loaded into a sparse graph adjacency matrix, a dense node feature matrix,
and a dense node label matrix.
Each row in those matrices contains the corresponding node's
neighbors, features, or one-hot encoding of labels.
The second command applies the renormalization trick and symmetric normalization on the adjacency matrix; these respectively add self-loops to all nodes, which keeps GNN computations numerically stable, and enable properties from graph spectral theory. The in-place variants of these operations are used for a minimal memory footprint.
Dataset dataset = new Cora();
dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
Sparse and dense representations are interoperable. Sparse storage is more efficient when there are many zeros, so internal computations automatically select appropriate matrix types for their outcomes.
We now incrementally create a trainable model using symbolic expressions that resemble math notation. The expressions are part of a scripting language called Neuralang. For fast onboarding, stick to the FastBuilder class for creating models, which omits some of the language's features in favor of providing programmatic shortcuts for boilerplate. Its constructor accepts two arguments, A and h0, respectively holding the graph's adjacency matrix and node features. These are set as constant symbols that can be used in expressions; other constants and input variables can be set too, but more on this later.
JGNN promotes functional programming method chains, where the builder's instance is returned by each of
its methods. Below we use this pattern to implement the Graph Convolutional Network (GCN)
architecture [Kipf and Welling, 2017].
For the time being, notice the matrix and vector functions in scripted expressions; these are inline declarations of learnable parameters with given dimensions and regularization.
long numSamples = dataset.samples().getSlice().size();
long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new FastBuilder(dataset.graph(), dataset.features())
.config("reg", 0.005)
.config("classes", numClasses)
.config("hidden", 64)
.function("gcnlayer", "(A,h){Adrop = dropout(A, 0.5); return Adrop@(h@matrix(?, hidden, reg))+vector(?);}")
.layer("h{l+1}=relu(gcnlayer(A, h{l}))")
.config("hidden", "classes") // reassigns the output gcnlayer's "hidden" to be equal to the number of "classes"
.layer("h{l+1}=gcnlayer(A, h{l})")
.classify()
.autosize(new EmptyTensor(numSamples));
Commonly used methods of FastBuilder:
- config - Configures hyperparameter values. These can be used in all subsequent function and layer declarations.
- function - Declares a Neuralang function, in this case with inputs A and h.
- layer - Declares a layer that can use built-in and Neuralang functions. In this, the symbols {l} and {l+1} specifically are replaced by a layer counter.
- classify - Adds a softmax layer tailored to classification. This also silently declares an input nodes that represents a list of node indices where the outputs should be computed.
- autosize - Automatically sizes matrix and vector dimensions that were originally denoted with a question mark ?. This method requires some example input; here we provide a list of node identifiers, which we also make dataless (it has only the correct dimensions without allocating memory). This method also checks for integrity errors in the declared architecture, such as computational paths that do not lead to an output.
Training can be implemented
manually, by using inputs to compute outputs on the built model, computing losses, and triggering backpropagation
on an optimizer. JGNN automates common
patterns by extending a base ModelTraining
class with strategies
tailored to different data formats and predictive tasks. Find these subclasses in the
adhoc.train
Javadoc. Instances of model trainers
accept chain notation to set parameters like training and validation data
(these should be created first and depend on the model training class) and aspects of the learning strategy like the number of epochs, patience
for early stopping, the employed optimizer, and loss functions. An example is presented below.
Slice nodes = dataset.samples().getSlice().shuffle(); // permutes node int ids
Matrix inputFeatures = Tensor.fromRange(nodes.size()).asColumn();
Loss loggingLoss = new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy())
.setInterval(10);
ModelTraining trainer = new SampleClassification()
// data
.setFeatures(inputFeatures)
.setLabels(dataset.labels())
.setTrainingSamples(nodes.range(0, 0.6))
.setValidationSamples(nodes.range(0.6, 0.8))
// learning strategy
.setOptimizer(new Adam(0.01))
.setEpochs(3000)
.setPatience(100)
.setLoss(new CategoricalCrossEntropy())
.setValidationLoss(loggingLoss);
Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer);
About the example training strategy.
Of the data needed for training, the graph adjacency matrix and node features are already declared as constants by the FastBuilder constructor, since node classification takes place on the same graph with fully known node features. Thus, the input features are a column of node identifiers, which the classify method uses to gather predictions for the respective nodes. Architecture outputs are softmax approximations of the one-hot encodings of the respective node labels. The simplest way to handle missing labels for test data without modifying the example is to leave their one-hot encodings as all zeros. Additionally, this particular training strategy accepts training and validation data slices, where slices are lists of integer entries pointing to rows of inputs and outputs - more on this later.
The example further selects
Adam
optimization with learning rate 0.01, and training
over many epochs with early stopping. A verbose
loss prints every 10 epochs the progress of cross entropy and accuracy on validation data, where the
first of these two is used for the early stopping criterion.
To run a full training process, pass a strategy to a model.
In a cold start scenario, apply a parameter initializer first before training is conducted.
A warm start that resumes training from some previously trained outcomes would skip this step.
Selecting an initializer is not part of the learning strategy
to signify its model-dependent nature; dense layers should maintain the expected
input variances in the output before the first epoch, and therefore the initializer depends
on the type of activation functions.
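As a minimal sketch of this distinction, reusing the modelBuilder and trainer from the example above (the variable names here are illustrative):
Model coldStarted = modelBuilder.getModel()
    .init(new XavierNormal()) // cold start: initialize parameters before the first training run
    .train(trainer);
coldStarted.train(trainer); // warm start: resume training later without re-initializing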
Trained models and their generating builders can be saved and loaded. The next snippet demonstrates how raw predictions can be made too. During this process, some matrix manipulation operations are employed to obtain transparent access to parts of input and output data of the dataset without copying memory.
modelBuilder.save(Paths.get("gcn_cora.jgnn")); // needs a Path as an input
Model loadedModel = ModelBuilder.load(Paths.get("gcn_cora.jgnn")).getModel();
Matrix output = loadedModel
.predict(Tensor.fromRange(0, nodes.size()).asColumn())
.get(0) // get our one output from the list of outputs
.cast(Matrix.class); // functional cast
double acc = 0;
for(Long node : nodes.range(0.8, 1)) {
Matrix nodeLabels = dataset.labels().accessRow(node).asRow();
Tensor nodeOutput = output.accessRow(node).asRow();
acc += nodeOutput.argmax()==nodeLabels.argmax()?1:0;
}
System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size());
3. GNN Builders
We already touched on the subject of model builders in the previous section, where one of them was used to create a model. There exist different kinds of builders that offer different conveniences. First is a base class for parsing simple Neuralang expressions. Second is a fast builder that contains GNN declaration boilerplate. Last is a full parser for Neuralang that maintains model definitions in a string or file while interacting with Java code. Here the three corresponding builder classes are covered. We also summarize debugging mechanisms for checking the integrity of constructed models, visualizing their data flow, and monitoring specific data at runtime.
3.1. ModelBuilder
This is the base model builder class; it offers a wide breadth of functionalities that other builders extend. Models take tensors as inputs and outputs. Tensors will be covered later; for now, it suffices to think of them as numerical vectors, which are sometimes endowed with matrix dimensions. The models themselves are built from Java classes that indicate sub-operations. However, this can be too verbose; many lines of code are needed to declare even simple expressions, making models cumbersome to read and maintain - hence the need for builders that construct models from concise symbolic expressions.
To create a model with the ModelBuilder class, instantiate the builder, use a method chain to declare an input variable with the .var(String) method, parse an expression with the .operation(String) method, and finally declare which symbol holds outputs with the .out(String) method. The first and last of these methods can be called multiple times to declare several inputs and outputs. Inputs need to be single symbols, but outputs can be whole expressions to evaluate. Obtain the created model's instance with the .getModel() method.
After defining models, use them to make predictions like below. The prediction method takes as input one or more comma-separated tensors that match the model's inputs (in the same order) and computes a list of output tensors. If inputs are dynamically created, an overloaded version of the same method supports passing an array list of input tensors instead.
ModelBuilder modelBuilder = new ModelBuilder()
.var("x")
.operation("y = log(2*x+1)")
.out("y");
Model model = modelBuilder.getModel();
System.out.println(model.predict(Tensor.fromDouble(2)));
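For dynamically assembled inputs, the list-based overload mentioned above can be used instead; the following sketch is equivalent to the one-tensor call in the previous snippet (java.util.Arrays is assumed to be imported):
System.out.println(model.predict(Arrays.asList(Tensor.fromDouble(2)))); // same output as above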
Equivalent implementation without the builder.
Under the hood, JGNN models are collections of NNOperation instances, each representing a numerical computation whose inputs and outputs are instances of JGNN's base Tensor type. Computations and the state of operations are thread-safe, so the same model can run and be trained in multiple threads simultaneously. This guidebook does not list operation subclasses, as they are rarely used directly and can be found
in the Javadoc under the modules
nn.inputs,
nn.activations,
and
nn.pooling.
Create models in pure Java like the example below, which does not have any trainable parameters.
Variable x = new Variable();
Constant c1 = new Constant(Tensor.fromDouble(1)); // holds the constant "1"
Constant c2 = new Constant(Tensor.fromDouble(2)); // holds the constant "2"
NNOperation mult = new Multiply()
.addInput(x)
.addInput(c2);
NNOperation add = new Add()
.addInput(mult)
.addInput(c1);
NNOperation y = new Log()
.addInput(add);
Model model = new Model()
.addInput(x)
.addOutput(y);
System.out.println(model.predict(Tensor.fromDouble(2))); // one-element input
Differences between expression parsing and Neuralang.
The operation
method parses string expressions that are typically structured
as assignments to symbols; the right-hand side of assignments accepts several operators and functions that
are listed in the next table. Models allow multiple operations too, which are parsed through either multiple
method calls or by being separated with a semicolon ;
within larger string expressions.
All methods need to use previously declared symbols. For example, parsing .out("symbol") throws an exception if no operation previously assigned to the symbol or declared it as an input. As a safety mechanism against consequences of the parsing engine that conflict with typical programming logic, symbols cannot be overwritten or set to updated values outside of Neuralang functions.
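For instance, the following sketch parses two assignments in a single operation call, separated by a semicolon; it is equivalent to calling .operation twice and mirrors the earlier log example:
ModelBuilder twoSteps = new ModelBuilder()
    .var("x")
    .operation("a = 2*x+1; y = log(a)")
    .out("y");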
The base model builder
class does support a roundabout declaration of Neuralang functions with expressions like the following snippet taken from the Quickstart:
.function("gcnlayer", "(A,h){return A@(h@matrix(?, hidden, reg))+vector(?);}")
.
The first method argument is the declared function's name, and the second must have the function's arguments enclosed in parentheses and its body enclosed in braces. Learn more about Neuralang functions in the language's description in the namesake section.
Table of parsable expressions.
Here is a table of available operations that you can use in expressions. Standard operator priority rules apply, and parentheses can adjust evaluation order. Prefer using configuration hyperparameters when creating matrices and vectors, as these transfer their names to the respective dimensions for error checking - more on this later.
Symbol | Type | Description
---|---|---
x = expr | Operator | Assign to variable x the outcome of executing expression expr. This expression does not evaluate to anything.
x + y | Operator | Element-by-element addition.
x * y | Operator | Element-by-element multiplication.
x - y | Operator | Element-by-element subtraction.
x @ y | Operator | Matrix multiplication.
x \| y | Operator | Row-wise concatenation of x and y.
x[y] | Operator | Gathers the rows of x with indexes y. Indexes are still tensors, whose elements are cast to integers during this operation.
transpose(A) | Function | Transposes matrix A.
log(x) | Function | Applies a logarithm on each element of tensor x.
exp(x) | Function | Exponentiates each element of tensor x.
nexp(x) | Function | Exponentiates each non-zero element of tensor x. Typically used for neighbor attention (see below).
relu(x) | Function | Applies relu on each tensor element.
tanh(x) | Function | Applies a tanh activation on each tensor element.
sigmoid(x) | Function | Applies a sigmoid activation on each tensor element.
dropout(x, rate) | Function | Applies training dropout on tensor x with constant dropout rate hyperparameter rate.
drop(x, rate) | Function | Shorthand notation for dropout.
lrelu(x, slope) | Function | Leaky relu on tensor x with constant negative slope hyperparameter slope.
prelu(x) | Function | Leaky relu on tensor x with learnable negative slope.
softmax(x, dim) | Function | Applies a softmax reduction on x, where dim is either dim:'row' (default) or dim:'col'.
sum(x, dim) | Function | Applies a sum reduction on x, where dim is either dim:'row' (default) or dim:'col'.
mean(x, dim) | Function | Applies a mean reduction on x, where dim is either dim:'row' (default) or dim:'col'.
L1(x, dim) | Function | Applies an L1 normalization on x across dimension dim, where dim is either dim:'row' (default) or dim:'col'.
L2(x, dim) | Function | Applies an L2 normalization on x across dimension dim, where dim is either dim:'row' (default) or dim:'col'.
max(x, dim) | Function | Applies a max reduction on x, where dim is either dim:'row' (default) or dim:'col'.
min(x, dim) | Function | Applies a min reduction on x, where dim is either dim:'row' (default) or dim:'col'.
matrix(rows, cols) | Function | Generates a matrix parameter with respective hyperparameter dimensions.
matrix(rows, cols, reg) | Function | Generates a matrix parameter with respective hyperparameter dimensions and L2 regularization hyperparameter reg.
mat(rows, cols) | Function | Shorthand notation for matrix.
mat(rows, cols, reg) | Function | Shorthand notation for matrix.
vector(len) | Function | Generates a vector with size hyperparameter len.
vector(len, reg) | Function | Generates a vector with size hyperparameter len and L2 regularization hyperparameter reg.
vec(len) | Function | Shorthand notation for vector.
vec(len, reg) | Function | Shorthand notation for vector.
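To illustrate a few of the table's entries (gathering rows, row-wise concatenation, and a row-wise softmax), here is a small sketch; the symbols x and indexes are illustrative inputs, and no trainable parameters are involved yet:
ModelBuilder gatherExample = new ModelBuilder()
    .var("x")       // a feature matrix
    .var("indexes") // a tensor of row indexes (cast to integers by the gather)
    .operation("rows = x[indexes]")
    .operation("joined = rows | rows")
    .operation("y = softmax(joined, dim: 'row')")
    .out("y");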
Model definitions have so far been too simple to be employed in practice; we need trainable parameters, which are created inline with the matrix and vector Neuralang functions. Do not use equivalent Java code, because it is better to keep model definitions simple. Additionally, there may be constants and configuration hyperparameters. Of these, constants reflect untrainable tensors and are set in a builder with .constant(String, Tensor). Both numbers in the last snippet's symbolic definition are internally parsed into constants. On the other hand, configuration hyperparameters are numerical values used by the parser and are set for a builder with .config(String, double); provide another configuration's name as the second argument to copy its value. Hyperparameters can be used as arguments for dimension sizes and regularization. Retrieve previously set builder hyperparameters through .getConfigOrDefault(String, double), where the second argument may be omitted. This is mostly useful for bringing into code hyperparameters declared in Neuralang scripts.
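Putting these together, the sketch below declares hyperparameters, uses them for an inline parameter declaration, and then reads one of them back into Java; names and values are illustrative:
ModelBuilder configured = new ModelBuilder()
    .config("hidden", 64)
    .config("reg", 0.005)
    .var("x")
    .operation("h = relu(x@matrix(hidden, hidden, reg) + vector(hidden))")
    .out("h");
double hidden = configured.getConfigOrDefault("hidden", 32); // returns 64; the default applies only if unset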
3.2. FastBuilder
The FastBuilder class extends the base model builder with boilerplate methods for setting inputs, outputs, and layers tailored to specific downstream tasks. Prefer this builder to keep track of the whole model definition in one place within Java code. The main difference compared to the base ModelBuilder is that there are now two constructor arguments, namely a square matrix A that is typically a normalization of the (sparse) adjacency matrix, and a feature matrix h0. This builder further supports the notation symbol{l}, where the layer counter replaces the symbol part {l} with 0 for the first layer, 1 for the second, and so on. Prefer the notation h{l} to refer to the node representation matrix of the current layer; for the first layer, this is parsed as h0, which is the constant set by the constructor. This builder also offers a .layer(String) method that is a variation of operation parsing where the symbol segment {l+1} is substituted with the next layer's counter, the expression is parsed, and the layer counter is incremented by one. Example usage is shown below, where symbolic expressions read similarly to what you would find in a paper.
The base operation of message passing GNNs, which are often used for node classification,
is to propagate node representations to neighbors via graph edges. Then, neighbors aggregate
the received representation, where aggregation typically consists of a weighted average per
the normalized adjacency matrix's edge weights. For symmetric normalization, the
weighted sum is compatible with spectral graph signal processing. The operation to perform
one propagation can be written as .layer("h{l+1}=A @ h{l}")
.
The propagation's outcome is typically transformed further by passing through a dense
layer.
FastBuilder modelBuilder = new FastBuilder(adjacency, features) // sets A, h0
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden, reg))+vector(hidden))") // h1 = ...
.layer("h{l+1}=A@(h{l}@matrix(hidden, classes, reg))+vector(classes)"); // h2 = ...
Improving representation diffusion at the cost of resources.
Several improvements of the representation diffusion scheme demonstrated above have been proposed. However, these tend to incur marginal accuracy improvements at the cost of a lot of compute and memory.
JGNN still supports the improvements listed below, since they could be used when running time is not a pressing issue (e.g., for transfer or stream learning that applies updates for a few epochs), or to analyse smaller graphs:
- Edge dropout - Applies dropout on the adjacency matrix at each layer with .layer("h{l+1}=dropout(A,0.5) @ h{l}"). Usage of this operation seems innocuous, but it disables a number of caching optimizations that occur under the hood.
- Heterogeneity - Some recent approaches explicitly account for high-pass frequency diffusion by also accounting for the graph Laplacian. Insert this into the architecture as a normal constant like so: .constant("L", adjacency.negative().cast(Matrix.class).setMainDiagonal(1))
- Edge attention - Performs the dot product of edge nodes to create new edge weights per the mathematical formula A.(hhᵀ), where A is a sparse adjacency matrix, the dot . represents the Hadamard product (element-by-element multiplication), and h is a dense matrix whose rows hold the respective node representations. JGNN efficiently implements this operation with the Neuralang function att(A, h). For example, weighted adjacency matrices for each layer of gated attention networks are implemented as .operation("A{l} = L1(nexp(att(A, h{l})))").
- General message passing - JGNN also supports the fully generalized message passing scheme between node neighbors for more complex relational analysis [Velickovic, 2022]. In this generalization, each edge is responsible for appropriately transforming and propagating representations to node neighbors; create message matrices whose rows correspond to edges and columns to edge features by gathering the features of the edge source and destination nodes. Programmatically, obtain edge source indexes src=from(A) and destination indexes dst=to(A), where A is the adjacency matrix. Then use the horizontal concatenation operation | to concatenate node features. One may also concatenate edge features. Given a constructed message, define any kind of ad-hoc mechanism or neural processing of messages with traditional matrix operations (take care to define correct matrix sizes for dense transformations, e.g., twice the number of columns of h{l} in the snippet below). For any kind of LayeredBuilder, don't forget that message{l} is needed within operations to obtain a message from the representations h{l} that is not accidentally shared with future layers. Receiver mechanisms need to perform some kind of reduction on messages. JGNN implements summation reduction, given that this has the same theoretical expressive power as maximum-based reduction but is easier to backpropagate through. Perform this like in the snippet below; the sum is weighted per the values of the adjacency matrix A, so perform adjacency matrix normalization only if you want such weighting to occur.
modelBuilder
    .operation("src=from(A)")
    .operation("dst=to(A)")
    .operation("message{l}=h{l}[src] | h{l}[dst]") // has two times the number of h{l}'s features
    .operation("transformed_message{l}=...") // fill in the transformation
    .operation("received{l}=reduce(transformed_message{l}, A)");
So far, we discussed the propagation mechanisms of
GNNs, which consider the features of all nodes. However,
in node classification settings, training data labels
are typically available only for certain nodes, despite
all node features being required to make any prediction.
We thus need a mechanism to retrieve the predictions of the top
layer for those nodes, for example before applying a softmax.
This is achieved in the snippet below, which uses the gather
operations through brackets. Alternatively, chain the
FastBuilder.classify()
method, which injects this exact code.
modelBuilder
.var("nodes")
.layer("h{l} = softmax(h{l})")
.operation("output = h{l}[nodes]")
.out("output");
It is often the case that GNNs need to make predictions for the whole graph - a task that falls under the more general category of learning invariant GNNs. In the simplest case, node representations would be reduced for each input graph to obtain one representation. However, this is a symmetric operation and therefore fails to distinguish between the structural positioning of nodes to be pooled, which is often important. One computationally light alternative, which you can use in your models, is sorting nodes based on learned features before concatenating their features into one vector for each graph.
This process is further simplified by keeping only a reduced number of top nodes whose features are concatenated, where the order is determined by an arbitrarily selected feature (in our implementation, sorting is based on the last feature, with the previous feature used to break ties, and so on). The idea is that the selected feature determines important nodes whose information can be adopted by others. To implement this scheme, JGNN provides independent operations to sort nodes, gather node latent representations, and reshape matrices into row or column tensors with learnable transformations to class outputs. These components are demonstrated in the following code snippet:
long reduced = 5; // input graphs need to have at least that many nodes
long hidden = 8; // many latent dims reduce speed without GPU parallelization
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 1)
.config("classes", 2)
.config("reduced", reduced)
.config("hidden", hidden)
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden))+vector(hidden))")
.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden))+vector(hidden))")
.concat(2) // concatenates the outputs of the last 2 layers
.config("hiddenReduced", hidden*2*reduced) // 2* due to concatenation
.operation("z{l}=sort(h{l}, reduced)") // z{l} are node indexes
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)")
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, dim: 'row')")
.out("h{l}");
Invariant vs equivariant tasks.
Tasks like node classification consider equivariant GNNs, whose outputs follow any node permutations applied to their inputs. In simple terms, if the order of node identifiers is modified (both in the graph adjacency and node feature matrices), the order of output rows is modified in the same way. Most JGNN operations are equivariant, and therefore their composition is also equivariant.
However, there are cases where created GNNs should be invariant, which means that they should create predictions that remain the same despite any input permutations. Invariance is the property to impose when classifying graphs, where one prediction should be made for the whole graph.
Imposing invariance is simple enough: take an equivariant architecture and then apply an invariant operation on top. You may want to perform further transformations (e.g., some dense layers) afterwards, but the general idea remains the same. JGNN offers two types of invariant operations: reductions, and the sort-based pooling covered above. Reductions are straightforward to implement by taking a dimensionality reduction mechanism (min, max, sum, mean) and applying it column-wise on the output feature matrix. Recall that each row holds the features of a different node, so the reduction yields a one-dimensional vector that, for each feature dimension, aggregates feature values across all nodes.
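As a minimal sketch of this recipe (mirroring the graph classification example of section 4.3, with illustrative dimension values), an equivariant layer can be followed by a mean reduction across node rows and a softmax:
ModelBuilder invariant = new LayeredBuilder()
    .var("A")
    .config("features", 7)
    .config("classes", 2)
    .layer("h{l+1}=relu(A@(h{l}@matrix(features, classes)))") // equivariant body
    .layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")          // invariant pooling across nodes
    .out("h{l}");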
3.3. Neuralang
Neuralang scripts consist of functions that declare machine learning components. Use a Rust syntax highlighter, which covers all of the language's keywords.
Functions correspond to machine learning modules and call each other.
At their end lies a return
statement, which expresses their
outcome. All arguments are passed by value, i.e., any assignments are
performed on fresh variable instances.
Before explaining how to use the Neuralang model builder, we present and analyse code that supports a fully functional architecture. First, look at the classify function, which is presented below for completeness. This takes two tensor inputs: nodes, which holds identifiers indicating which nodes should be classified (the output has a number of rows equal to the number of identifiers), and a node feature matrix h. It then computes and returns a softmax over the features of the specified nodes.
Aside from the main inputs, the function's signature also has several configuration values, whose defaults are indicated by a colon : (only configurations have defaults, and conversely). The same notation is used to set or overwrite configurations when calling functions, as we do for softmax to apply it row-wise. Think of configurations as keyword arguments of typical programming languages, with the difference that they control hyperparameters, like dimension sizes or regularization. Exclamation marks ! before numbers broadcast values to all subsequent function calls that have configurations with the same name. The broadcasted defaults overwrite already existing defaults of configurations with the same name anywhere in the code, and all defaults are replaced by values explicitly set when calling functions. For example, take advantage of this prioritization to force output layer dimensions to match your data. Importantly, broadcasted values are also stored within JGNN's Neuralang model builder; this is useful for Java integration, for example to retrieve training hyperparameters from the model.
fn classify(nodes, h, epochs: !3000, patience: !100, lr: !0.01) {
return softmax(h[nodes], dim: "row");
}
Write exact values for configurations, as for now no arithmetic can be used to compute them. For example, setting patience: 2*50 creates an error.
Configuration values have the following priority, from strongest to weakest:
1. Arguments set during the function's call.
2. Broadcasts (the last broadcasted value, including configurations set in Java).
3. Function signature defaults.
Next, let us look at the functions creating the main body of an architecture. In these, the question mark ? lets the autosize feature of JGNN determine dimension sizes based on a test run. The number of classes is unknown when writing the model, and is therefore externally declared with the extern keyword to signify that this value should always be provided from the Java side.
fn gcnlayer(A, h, hidden: 64, reg: 0.005) {
return A@h@matrix(?, hidden, reg) + vector(hidden);
}
fn gcn(A, h, classes: extern) {
h = gcnlayer(A, h);
h = dropout(relu(h), 0.5);
return gcnlayer(A, h, hidden: classes);
}
About the architecture.
First, gcnlayer accepts two parameters: an adjacency matrix A and an input feature matrix h. The configuration hidden: 64 in the function's signature specifies the default number of hidden units, whereas reg: 0.005 is the L2 regularization applied during machine learning. Finally, the function returns the outcome of a GCN layer.
Similarly, look at the gcn function. This declares the GCN architecture and has the number of output classes as a configuration. The function basically consists of two gcnlayer layers, where the second layer's hidden units are set to the number of output classes.
We now move to parsing our declarations with the Neuralang model builder and using them to create an architecture. To this end, save your code to a file and obtain it as a path with Path architecture = Paths.get("filename.nn");. Alternatively, avoid external files by inlining the definition within Java code as a multiline string per String architecture = """ ... """;. Below, the definition is parsed within a chain, where each method call returns the model builder instance so that more methods can be called.
long numSamples = dataset.samples().getSlice().size();
long numClasses = dataset.labels().getCols();
ModelBuilder modelBuilder = new Neuralang()
.parse(architecture)
.constant("A", dataset.graph())
.constant("h", dataset.features())
.var("nodes")
.config("classes", numClasses)
.config("hidden", numClasses+2) // custom number of hidden dimensions
.out("classify(nodes, gcn(A,h))") // expression to parse into a value
.autosize(new EmptyTensor(numSamples));
System.out.println("Preferred learning rate: "+modelBuilder.getConfig("lr"));
About the Java side.
The above snippet sets the remaining hyperparameters and overwrites the default value for "hidden". It also specifies that certain variables are constants, namely the adjacency matrix A and node representations h, as well as that the node identifiers form a variable that serves as the architecture's input. There could be multiple inputs, so the distinction between constants and variables depends mostly on which quantities change during training, and it is managed only by the Java side of the code. In the case of node classification, both the adjacency matrix and node features remain constant, as we work on one graph, but this is not the case for graph classification, where many graphs are encountered.
Finally, the definition sets a Neuralang expression as the architecture's output by calling the .out(String) method, and applies the .autosize(Tensor...) method to infer hyperparameter values denoted with ? from an example input. For faster completion of the model, we provide a dataless list of node identifiers as input.
3.4. Debugging
JGNN offers high-level tools for debugging architectures. Here we cover what diagnostics to run, and how to make sense of error messages to fix erroneous architectures.
We already mentioned that model builder
symbols should be assigned to before
subsequent use. For example, consider a FastBuilder
that
tries to parse the expression .layer("h{l+1}=relu(hl@matrix(features, 32, reg)+vector(32))")
,
where hl
is a typographical error of
h{l}
. In this case, an exception is thrown:
Exception in thread "main" java.lang.RuntimeException: Symbol hl not defined.
Internally, models are effectively directed acyclic graphs (DAGs)
that model builders create. DAGs should not be confused with the graphs
that GNN architectures analyse; they are just an organization of data flow between NNComponents. During parsing, builders
may create temporary variables, which start with
the _tmp
prefix and are followed by
a number. Those variables often link
components to others that use them.
The easiest way to understand execution DAGs is to look at them textually or visually. Two methods help with this: a .print() method that prints the built functional flow to the system console, and a .getExecutionGraphDot() method that returns a string holding the execution graph in .dot format for visualization with tools like GraphViz.
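For example, the following sketch prints the DAG and writes its .dot form to a file for rendering; the file name is arbitrary, Files and Paths come from java.nio.file (Java 11+ for writeString), and IOException from java.io:
modelBuilder.print(); // textual view of the execution DAG in the console
try {
    Files.writeString(Paths.get("model.dot"), modelBuilder.getExecutionGraphDot());
} catch (IOException e) {
    e.printStackTrace();
}
// render with GraphViz, e.g.: dot -Tpng model.dot -o model.png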
Another error-checking procedure consists of an assertion that all model operations eventually affect at least one output. Computational branches that lead nowhere mess up the DAG traversal during backpropagation and should be checked with the method .assertBackwardValidity(), which throws an exception if an invalid model is found. Performing this assertion early on in model building will likely throw exceptions that are not logical errors, given that independent outputs may be combined later. Backward validity errors look like the following example, which indicates that the component _tmp102 does not lead to an output; we should then look at the execution tree to understand its role.
Exception in thread "main" java.lang.RuntimeException: The component class mklab.JGNN.nn.operations.Multiply: _tmp102 = null does not lead to an output
at mklab.JGNN.nn.ModelBuilder.assertBackwardValidity(ModelBuilder.java:504)
at nodeClassification.APPNP.main(APPNP.java:45)
Some tensor or matrix methods, such as asTransposed(), do not correspond to numerical operations but are only responsible for naming dimensions. Matrices effectively have three dimension names: for their rows, for their columns, and for their inner data as long as they are treated as tensors.
How to manually set dimension names.
Operation | Comments
---|---
Tensor setDimensionName(String name) | For naming tensor dimensions (of the 1D space tensors lie in).
Tensor setRowName(String rowName) | For naming what kind of information matrix rows hold (e.g., "samples"). Defined only for matrices.
Tensor setColName(String colName) | For naming what kind of information matrix columns hold (e.g., "features"). Defined only for matrices.
Tensor setDimensionName(String rowName, String colName) | A shorthand for calling setRowName(rowName).setColName(colName). Defined only for matrices.
There are two main mechanisms for identifying logical errors within architectures: a) mismatched dimension sizes, and b) mismatched dimension names. Of the two, dimension sizes are easier to comprehend, since mismatches simply mean that operations are mathematically invalid. On the other hand, dimension names need to be determined for starting data, such as model inputs and parameters, and are automatically inferred from operations on such primitives. For in-line declarations of parameters in operations or layers, dimension names are copied from any hyperparameters. Therefore, for easier debugging, prefer using functional expressions that declare hyperparameters, similarly to the following example.
new ModelBuilder()
.config("features", 7)
.config("hidden", 64)
.var("x")
.operation("h = x@matrix(features, hidden)");
Both mismatched dimensions and mismatched dimension names throw runtime exceptions. Their console error traces should start with something like this:
java.lang.IllegalArgumentException: Mismatched matrix sizes between SparseMatrix (3327,32) 52523/106464 entries and DenseMatrix (64, classes 6)
During the forward pass of class mklab.JGNN.nn.operations.MatMul: _tmp4 = null with the following inputs:
class mklab.JGNN.nn.activations.Relu: h1 = SparseMatrix (3327,32) 52523/106464 entries
class mklab.JGNN.nn.inputs.Parameter: _tmp5 = DenseMatrix (64, classes 6)
java.lang.IllegalArgumentException: Mismatched matrix sizes between SparseMatrix (3327,32) 52523/106464 entries and DenseMatrix (64, classes 6)
at mklab.JGNN.core.Matrix.matmul(Matrix.java:258)
at mklab.JGNN.nn.operations.MatMul.forward(MatMul.java:21)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:180)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
at mklab.JGNN.nn.NNOperation.runPrediction(NNOperation.java:170)
...
How to read the error message.
This particular stack trace tells us that the architecture encounters mismatched matrix sizes when trying to multiply a 3327x32 SparseMatrix with a 64x6 dense matrix. Understanding the exact error is easy: the inner dimensions of the matrix multiplication do not agree. However, we need to find the error within our architecture to fix it. To do this, the error message states where the error occurs:
During the forward pass of class mklab.JGNN.nn.operations.MatMul: _tmp4 = null
This tells us that the problem occurs when trying to calculate _tmp4, which is currently assigned a null tensor as its value. Some more information is available to see what the operation's inputs look like; in this case, they coincide with the multiplication's inputs, but this will not always be the case. The important point is to go back to the execution tree and see during which exact operation this variable is defined. There, we will undoubtedly find that some dimension had 64 instead of 32 elements (or conversely), and we can fix the parsed expression responsible.
In addition to all other debugging mechanisms, JGNN provides a way to show when the forward and backward operations of specific code components are executed and with what kinds of arguments. This can be particularly useful when testing new components in real (complex) architectures. The practice consists of calling a monitor(...) function within operations. This does not affect what expressions do and only enables printing of the monitored components' execution tree operations. For example, the next snippet monitors the outcome of a matrix multiplication:
builder.operation("h = relu(monitor(x@matrix(features, 64)) + vector(64))")
4. Training
Here we describe how to train a JGNN model created with the previous section's builders. Broadly, we need to load some reference data and employ an optimization scheme to adjust trainable parameter values based on the differences between desired and current outputs. To this end, we start by describing generic patterns for creating graph and node feature data, and then move to specific data organizations for the tasks of node classification and graph classification. These tasks have helper classes that implement common training schemas (reach out with requests for helper classes for other kinds of predictive tasks in the project's GitHub issues).
4.1. Create data
JGNN contains dataset classes that automatically download and load
datasets for out-of-the-box experimentation. These datasets can be found
in the
adhoc.datasets Javadoc, and we already covered their usage patterns.
In practice, though, you will want to use your own data. In the simplest case, both the number of nodes or data samples and the number of feature dimensions are known beforehand. If so, create a preallocated feature matrix like in the following code, which uses the minimum memory necessary. The snippet initializes a sparse matrix; if features are dense (do not have a lot of zeros), consider using the DenseMatrix class instead.
Matrix features = new SparseMatrix(numNodes, numFeatures);
for(long nodeId=0; nodeId<numNodes; nodeId++)
    for(long featureId=0; featureId<numFeatures; featureId++)
        features.put(nodeId, featureId, 1);
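A dense alternative would look like the following sketch, assuming the DenseMatrix constructor mirrors SparseMatrix's (rows, columns) arguments:
Matrix denseFeatures = new DenseMatrix(numNodes, numFeatures);
for(long nodeId=0; nodeId<numNodes; nodeId++)
    for(long featureId=0; featureId<numFeatures; featureId++)
        denseFeatures.put(nodeId, featureId, 1); // same put-based filling as above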
Sometimes, it is easier to read node or sample features line-by-line, for instance, when reading a .csv file. In this case, store each line as a separate tensor. Convert a list of tensors representing row vectors into a feature matrix like in the example below.
ArrayList<Tensor> rows = new ArrayList<Tensor>();
try(BufferedReader reader = new BufferedReader(new FileReader(file))){
    String line = reader.readLine();
    while (line != null) {
        String[] cols = line.split(",");
        Tensor features = new SparseTensor(cols.length);
        for(int col=0;col<cols.length;col++)
            features.put(col, Double.parseDouble(cols[col]));
        rows.add(features);
        line = reader.readLine();
    }
}
Matrix features = new WrapRows(rows).toSparse();
Creating adjacency matrices is similar to creating preallocated feature matrices. When in doubt, use the sparse format for adjacency matrices, as the allocated memory of dense counterparts scales quadratically with the number of nodes. Note that many GNNs consider bidirectional (i.e., undirected) edges, in which case both directions should be added to the adjacency matrix. Use the following snippet as a template. Recall that JGNN follows a function chain notation, so each modification returns the matrix instance. Don't forget to normalize or apply the renormalization trick (self-edges) on matrices if your architecture requires it, for instance by calling adjacency.setMainDiagonal(1).setToSymmetricNormalization(); after matrix creation.
Matrix adjacency = new SparseMatrix(numNodes, numNodes);
for(Entry<Long, Long> edge : edges)
    adjacency
        .put(edge.getKey(), edge.getValue(), 1)
        .put(edge.getValue(), edge.getKey(), 1);
All tensor operations can be viewed in the core.tensor and core.matrix Javadoc. The Matrix class extends the concept of tensors with additional operations, like transposition, matrix multiplication, and row and column access. Under the hood, matrices store their elements linearly and compute the mapping between each element's (row, col) position and its respective linear position. The outcome of some methods inherited from tensors may need to be typecast back into a matrix (e.g., for all in-place operations).
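For example, the functional cast shown in the Quickstart can bring the result of an inherited in-place method back to a matrix type; the constructor arguments here are illustrative:
Matrix ones = new DenseMatrix(3, 3)
    .setToOnes()         // inherited from Tensor, so the chain currently holds a Tensor
    .cast(Matrix.class); // functional cast back to Matrix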
Operations can be split into arithmetics that combine the values
of two tensors to create a new one (e.g., Tensor add(Tensor)
),
in-place arithmetics that alter a tensor without creating
a new one (e.g., Tensor selfAdd(Tensor)
),
summary statistics that output simple numeric values (e.g., double Tensor.sum()
),
and element getters and setters.
In-place arithmetics follow the same naming conventions as base arithmetics, but their method names begin with a "self" prefix for pairwise operations and a "setTo" prefix for unary operations. Since they do not allocate new memory, prefer them for intermediate calculation steps. For example, the following code creates and normalizes a tensor of ones without using any additional memory.
Tensor normalized = new DenseTensor(10)
.setToOnes()
.setToNormalized();
Initialize a dense or sparse tensor (both of which represent one-dimensional vectors) with its number of elements. If many zeros are expected, prefer using a sparse tensor. For example, one-hot encodings for classification problems can be generated with the following code. This creates a dense tensor with numClasses elements and sets the element at position classId to 1:
int classId = 1;
int numClasses = 5;
Tensor oneHotEncoding = new mklab.JGNN.core.tensor.DenseTensor(numClasses).set(classId, 1); // creates the tensor [0,1,0,0,0]
The above snippets all make use of numerical node identifiers. To manage these, JGNN provides an IdConverter class; convert hashable objects (typically strings) to identifiers by calling IdConverter.getOrCreateId(object). Also use converters to one-hot encode class labels. To search only for previously registered identifiers, use IdConverter.get(object). For example, construct a label matrix with the following snippet, in which nodeLabels is a dictionary from node identifiers to node labels that is converted to a sparse matrix.
IdConverter nodeIds = new IdConverter();
IdConverter classIds = new IdConverter();
for(Entry<String, String> entry : nodeLabels.entrySet()) {
    nodeIds.getOrCreateId(entry.getKey());
    classIds.getOrCreateId(entry.getValue());
}
Matrix labels = new SparseMatrix(nodeIds.size(), classIds.size());
for(Entry<String, String> entry : nodeLabels.entrySet())
    labels.put(nodeIds.get(entry.getKey()), classIds.get(entry.getValue()), 1);
Reverse-search the converter to obtain the original object for a predicted identifier. The following example accesses one row of a label matrix, performs an argmax operation to find the position of the maximum element, and reconstructs the label for the corresponding row with the reverse search.
long nodeId = nodeIds.get("nodeName");
Tensor prediction = labels.accessRow(nodeId);
long predictedClassId = prediction.argmax();
System.out.println(classIds.get(predictedClassId));
4.2. Node classification
Node classification models can be trained with backpropagation by considering a list of node indices and the desired predictions for those nodes. We first show an automation of the training process that controls it in a predictable manner.
This section is under construction.
Slice nodes = dataset.samples().getSlice().shuffle(100); // or nodes = new Slice(0, numNodes).shuffle(100);
Model model = modelBuilder
    .getModel()
    .init(new XavierNormal())
    .train(trainer,
        nodes.samplesAsFeatures(),
        dataset.labels(),
        nodes.range(0, trainSplit),
        nodes.range(trainSplit, validationSplit));
4.3. Graph classification
Most neural network architectures are designed with the idea
of learning to classify nodes or samples. However, GNNs also
provide the capability to classify entire graphs based on
their structure. To define architectures for graph classification, we use
the generic LayeredBuilder
class. The main
difference compared to traditional neural networks is
that architecture inputs do not all exhibit the same
size (e.g., some graphs may have more nodes than others)
and therefore cannot be organized into tensors of common
dimensions. Instead, assume that training data are stored in the
following lists:
ArrayList<Matrix> adjacencyMatrices = new ArrayList<Matrix>();
ArrayList<Matrix> nodeFeatures = new ArrayList<Matrix>();
ArrayList<Tensor> graphLabels = new ArrayList<Tensor>();
The LayeredBuilder
class introduces the
input variable h0
for sample
features. We can use it to pass node features to the
architecture, so we only need to add a second input
storing the (sparse) adjacency matrix:
.var("A")
We can then proceed to define a GNN architecture, for instance as explained in previous tutorials. This time, though, we aim to classify entire graphs rather than individual nodes. For this reason, we need to pool top layer node representations, for instance by averaging them across all nodes:
.layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")
Finally, we need to set up the top layer as the built
model's output per: .out("h{l}")
An example architecture following these principles is shown below:
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", nodeLabelIds.size())
.config("classes", graphLabelIds.size())
.config("hidden", 16)
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden)))")
.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, classes)))")
.layer("h{l+1}=softmax(mean(h{l}, dim: 'row'))")
.out("h{l}");
For the time being, training architectures like the above on prepared data requires manually calling the backpropagation for each epoch and each graph in the training batch. To do this, first retrieve the model and initialize its parameters:
Model model = builder.getModel()
.init(new XavierNormal());
Next, define a loss function and set up a batch
optimization strategy wrapping any base optimizer and
accumulating parameter updates until
BatchOptimizer.updateAll()
is called later
on:
Loss loss = new CategoricalCrossEntropy();
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Finally, training can be conducted by iterating through epochs and training samples and appropriately calling Model.train for combinations of node feature and graph adjacency matrix inputs and graph label outputs. At the end of each batch (e.g., each epoch), don't forget to call the optimizer.updateAll() method to apply the accumulated gradients. This process can be realized with the following code:
for(int epoch=0; epoch<300; epoch++) {
    for(int graphId=0; graphId<graphLabels.size(); graphId++) {
        Matrix adjacency = adjacencyMatrices.get(graphId);
        Matrix features = nodeFeatures.get(graphId);
        Tensor label = graphLabels.get(graphId);
        model.train(loss, optimizer,
            Arrays.asList(features, adjacency),
            Arrays.asList(label));
    }
    optimizer.updateAll();
}
To speed up graph classification, use JGNN's parallelization capabilities to calculate gradients across multiple threads. Parallelization holds little meaning for node classification, since the same propagation mechanism would need to run on the same graph in parallel. However, it yields substantial speedups for graph classification. Parallelization can use JGNN's thread pooling to compute gradients, wait for the conclusion of submitted tasks, and then apply the accumulated gradient updates. This is achieved with a batch optimizer that accumulates gradients, as in the following example:
for(int epoch=0; epoch<500; epoch++) {
    // gradient updates computed in parallel across the training graphs
    for(int graphId=0; graphId<dtrain.adjucency.size(); graphId++) {
        int graphIdentifier = graphId;
        ThreadPool.getInstance().submit(new Runnable() {
            @Override
            public void run() {
                Matrix adjacency = dtrain.adjucency.get(graphIdentifier);
                Matrix features = dtrain.features.get(graphIdentifier);
                Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow();
                model.train(loss, optimizer,
                    Arrays.asList(features, adjacency),
                    Arrays.asList(graphLabel));
            }
        });
    }
    ThreadPool.getInstance().waitForConclusion(); // wait for all gradients to finish computing
    optimizer.updateAll();
    // evaluate on the test data once per epoch
    double acc = 0.0;
    for(int graphId=0; graphId<dtest.adjucency.size(); graphId++) {
        Matrix adjacency = dtest.adjucency.get(graphId);
        Matrix features = dtest.features.get(graphId);
        Tensor graphLabel = dtest.labels.get(graphId);
        if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
            acc += 1;
    }
    System.out.println("iter = " + epoch + " " + acc/dtest.adjucency.size());
}