Package mklab.JGNN.adhoc
Class ModelTraining
java.lang.Object
mklab.JGNN.adhoc.ModelTraining
- Direct Known Subclasses:
- AGFTraining, SampleClassification
This is a helper class that automates the definition of training processes of Model instances by defining the number of epochs, loss functions, number of batches, and the ability to use ThreadPool for parallelized batch computations.
- Author:
- Emmanouil Krasanakis
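A typical training process is configured through the fluent setter chain and then applied with train(Model). A minimal sketch, assuming the SampleClassification subclass together with JGNN's Adam optimizer and CategoricalCrossEntropy loss (exact package paths may differ between versions):

```java
import mklab.JGNN.adhoc.ModelTraining;
import mklab.JGNN.adhoc.train.SampleClassification;
import mklab.JGNN.nn.optimizers.Adam;
import mklab.JGNN.nn.loss.CategoricalCrossEntropy;

// Each setter returns the ModelTraining instance, so calls can be chained.
ModelTraining trainer = new SampleClassification()
        .setOptimizer(new Adam(0.01))           // wrapped by a BatchOptimizer internally
        .setLoss(new CategoricalCrossEntropy()) // also used for validation if no validation loss is set
        .setEpochs(300)                         // maximum number of epochs
        .setPatience(100);                      // early stopping patience
model = trainer.train(model);                   // train with the configured settings
```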
-
Field Summary
Modifier and Type / Field
protected int epochs
protected Loss loss
protected int numBatches
protected BatchOptimizer optimizer
protected boolean paralellization
protected int patience
protected boolean stochasticGradientDescent
protected Loss validationLoss
protected boolean verbose
-
Constructor Summary
-
Method Summary
Modifier and Type / Method / Description
ModelTraining configFrom(ModelBuilder modelBuilder)
  Retrieves the learning rate (lr), epochs, batches, and patience parameters from the configurations of a ModelBuilder.
List<BatchData> getBatchData(int batch, int epoch)
  Returns a list of BatchData instances to be used for a specific batch and training epoch.
List<BatchData> getValidationData(int epoch)
  Returns a list of BatchData instances to be used for validation at a given training epoch.
protected void onEndTraining()
  Performs any cleanup operations at the end of the train(Model) loop.
protected abstract void onStartEpoch(int epoch)
  Performs necessary training operations at the beginning of each epoch.
ModelTraining setEpochs(int epochs)
  Sets the maximum number of epochs for which training runs.
ModelTraining setLoss(Loss loss)
  Sets which Loss should be applied on training batches (the loss is averaged across batches, but is aggregated as a sum within each batch by BatchOptimizer).
ModelTraining setNumBatches(int numBatches)
  Sets the number of batches training data slices should be split into.
ModelTraining setOptimizer(Optimizer optimizer)
  Sets an Optimizer instance to control parameter updates during training.
ModelTraining setParallelizedStochasticGradientDescent(boolean paralellization)
  Sets whether the training strategy should reflect stochastic gradient descent by randomly sampling from the training data samples.
ModelTraining setPatience(int patience)
  Sets the patience of the training strategy that performs early stopping.
ModelTraining setValidationLoss(Loss loss)
  Sets which Loss should be applied on validation data on each epoch.
ModelTraining setVerbose(boolean verbose)
  Deprecated. This method was available in earlier JGNN versions but will be gradually phased out.
Model train(Model model)
  Trains the parameters of a Model based on current settings and the data.
Model train(Model model, Matrix features, Matrix labels, Slice trainingSamples, Slice validationSamples)
  Deprecated. This method's full implementation has been moved to train(Model).
-
Field Details
-
optimizer
protected BatchOptimizer optimizer -
numBatches
protected int numBatches -
epochs
protected int epochs -
patience
protected int patience -
paralellization
protected boolean paralellization -
stochasticGradientDescent
protected boolean stochasticGradientDescent -
loss
protected Loss loss -
validationLoss
protected Loss validationLoss -
verbose
protected boolean verbose
-
-
Constructor Details
-
ModelTraining
public ModelTraining()
-
-
Method Details
-
setVerbose
Deprecated. This method was available in earlier JGNN versions but will be gradually phased out. Instead, wrap the validation loss within VerboseLoss to replicate the same behavior.
- Parameters:
verbose - Whether an error message will be printed.
- Returns:
- The model training instance.
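Per the deprecation note, the same behavior is obtained by wrapping the validation loss. A sketch, assuming JGNN's VerboseLoss wrapper and Accuracy loss (package paths may vary between versions):

```java
import mklab.JGNN.nn.loss.Accuracy;
import mklab.JGNN.nn.loss.report.VerboseLoss;

// Instead of trainer.setVerbose(true), wrap the validation loss so that
// its value is printed during training while still driving early stopping.
trainer.setValidationLoss(new VerboseLoss(new Accuracy()));
```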
-
setLoss
Sets which Loss should be applied on training batches (the loss is averaged across batches, but is aggregated as a sum within each batch by BatchOptimizer). Model training mainly uses the loss's Loss.derivative(Tensor, Tensor) method, alongside Loss.onEndEpoch() and Loss.onEndTraining(). If no validation loss is set, the training loss is also used for validation.
- Parameters:
loss - The loss instance.
- Returns:
- The model training instance.
- See Also:
-
setValidationLoss
Sets which Loss should be applied on validation data on each epoch. The loss's Loss.onEndEpoch(), Loss.onEndTraining(), and Loss.evaluate(Tensor, Tensor) methods are used. In the case where validation is split into multiple instances of batch data, which may be necessary for complex scenarios like graph classification, the loss value is averaged across those batches. The methods mentioned above are not used by losses employed in training.
- Parameters:
loss - The loss instance.
- Returns:
- The model training instance.
- See Also:
-
setOptimizer
Sets an Optimizer instance to control parameter updates during training. If the provided optimizer is not an instance of BatchOptimizer, it is forcefully wrapped by the latter. Training calls the batch optimizer's update method after every batch. Each batch could contain multiple instances of batch data. However, the total number of applied gradient updates is always equal to the value set by setNumBatches(int).
- Parameters:
optimizer - The desired optimizer.
- Returns:
this
model training instance.- See Also:
-
setNumBatches
Sets the number of batches training data slices should be split into.
- Parameters:
numBatches - The desired number of batches. Default is 1.
- Returns:
this
model training instance.- See Also:
-
setParallelizedStochasticGradientDescent
Sets whether the training strategy should reflect stochastic gradient descent by randomly sampling from the training data samples. If true, both this feature and thread-based parallelization are enabled. Parallelization uses JGNN's ThreadPool.
- Parameters:
paralellization - A boolean value indicating whether this feature is enabled.
- Returns:
this
model training instance.- See Also:
-
setEpochs
Sets the maximum number of epochs for which training runs. If no patience has been set, training runs for exactly this number of epochs.
- Parameters:
epochs - The maximum number of epochs.
- Returns:
this
model training instance.- See Also:
-
setPatience
Sets the patience of the training strategy that performs early stopping. If training does not encounter a smaller validation loss for this number of epochs, it stops.
- Parameters:
patience - The number of patience epochs. Default is Integer.MAX_VALUE, which effectively disables this feature and lets training always reach the maximum number of set epochs.
- Returns:
this
model training instance.- See Also:
-
train
public Model train(Model model, Matrix features, Matrix labels, Slice trainingSamples, Slice validationSamples)
Deprecated. This method's full implementation has been moved to train(Model). This is a leftover method from an earlier version of JGNN's interface. For the time being, there is no good alternative, but it will be phased out. -
onStartEpoch
protected abstract void onStartEpoch(int epoch)
Performs necessary training operations at the beginning of each epoch. These typically consist of dataset shuffling if setParallelizedStochasticGradientDescent(boolean) is enabled.
- Parameters:
epoch - The epoch that now starts. Takes values 0,1,2,...,epochs-1, though early stopping may stop training before the maximum number is reached.
-
onEndTraining
protected void onEndTraining()
Performs any cleanup operations at the end of the train(Model) loop. This method is mostly used to "unlock" data insertions to the training process. -
getBatchData
Returns a list of BatchData instances to be used for a specific batch and training epoch. This list may have only one entry if the whole batch can be organized into one pair of model inputs-outputs (e.g., in node classification). This method is overloaded by classes extending ModelTraining to let them work as dataset loaders. Batch data may be created anew each time, though they are often transparent views of parts of the training data. Batch data generation may be parallelized, depending on whether setParallelizedStochasticGradientDescent(boolean) is enabled. If some operations (e.g., data shuffling) take place at the beginning of each epoch, they instead reside in the onStartEpoch(int) method.
- Parameters:
batch - The batch identifier. Takes values 0,1,2,...,numBatches-1.
epoch - The epoch in which the batch is extracted. Takes values 0,1,2,...,epochs-1, though early stopping may stop training before the maximum number is reached.
- Returns:
- A list of batch data instances.
-
getValidationData
Returns a list of BatchData instances to be used for validation at a given training epoch. This list may have only one entry if the whole batch can be organized into one pair of model inputs-outputs (e.g., in node classification). This method is overloaded by classes extending ModelTraining to let them work as dataset loaders. Batch data may be created anew each time, though they are often transparent views of parts of the training data.
- Parameters:
epoch - The epoch in which the batch is extracted. Takes values 0,1,2,...,epochs-1, though early stopping may stop training before the maximum number is reached.
- Returns:
- A list of batch data instances.
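Subclasses implement these hooks to serve data to the training loop. A schematic of the pattern (the getBatchData signature and placeholder class name are assumptions to be checked against the actual sources):

```java
import java.util.List;

// Hypothetical dataset loader extending ModelTraining.
public class MyTraining extends ModelTraining {
    @Override
    protected void onStartEpoch(int epoch) {
        // shuffle sample order here when stochastic gradient descent is
        // enabled, so each epoch draws batches from a fresh permutation
    }

    @Override
    protected List<BatchData> getBatchData(int batch, int epoch) {
        // return transparent views over slices of the training data;
        // a single entry suffices when the batch forms one inputs-outputs pair
        throw new UnsupportedOperationException("fill in dataset-specific batching");
    }
}
```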
-
train
Trains the parameters of a Model based on current settings and the data.
- Parameters:
model - The model instance to train.
-
configFrom
Retrieves the learning rate (lr), epochs, batches, and patience parameters from the configurations of a ModelBuilder.
- Parameters:
modelBuilder - The model builder whose configurations are read.
- Returns:
this model training instance.
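For instance, hyperparameters declared on a builder can be transferred to the trainer in one call. A sketch, assuming ModelBuilder exposes a config(String, double) method and that these configuration names match the ones configFrom reads:

```java
ModelBuilder modelBuilder = new ModelBuilder()
        .config("lr", 0.01)      // learning rate read by configFrom
        .config("epochs", 300)
        .config("batches", 1)
        .config("patience", 100);

// Copy lr, epochs, batches, and patience into the training process.
trainer.configFrom(modelBuilder);
```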
-