Class ModelTraining

java.lang.Object
mklab.JGNN.adhoc.ModelTraining
Direct Known Subclasses:
AGFTraining, SampleClassification

public abstract class ModelTraining extends Object
Helper class that automates the training of Model instances. It lets you configure the number of epochs, the loss functions, the number of batches, and whether ThreadPool is used for parallelized batch computations.
Author:
Emmanouil Krasanakis
  • Field Details

    • optimizer

      protected BatchOptimizer optimizer
    • numBatches

      protected int numBatches
    • epochs

      protected int epochs
    • patience

      protected int patience
    • paralellization

      protected boolean paralellization
    • stochasticGradientDescent

      protected boolean stochasticGradientDescent
    • loss

      protected Loss loss
    • validationLoss

      protected Loss validationLoss
    • verbose

      protected boolean verbose
  • Constructor Details

    • ModelTraining

      public ModelTraining()
  • Method Details

    • setVerbose

      public ModelTraining setVerbose(boolean verbose)
      Deprecated.
      This method was available in earlier JGNN versions but will be gradually phased out. Instead, wrap the validation loss within VerboseLoss to replicate the same behavior.
      Parameters:
      verbose - Whether an error message will be printed.
      Returns:
      The model training instance.
    • setLoss

      public ModelTraining setLoss(Loss loss)
      Sets which Loss should be applied on training batches. The loss is averaged across batches, but is aggregated as a sum within each batch by BatchOptimizer. Model training mainly uses the loss's Loss.derivative(Tensor, Tensor) method, alongside Loss.onEndEpoch() and Loss.onEndTraining(). If no validation loss is set, the training loss is also used for validation.
      Parameters:
      loss - The loss's instance.
      Returns:
      The model training instance.
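      The aggregation rule described for setLoss (sum within a batch, average across batches) can be pictured with the minimal sketch below. It uses plain arrays rather than JGNN tensors, and LossAggregationSketch and epochLoss are hypothetical names that are not part of the library.

```java
import java.util.List;

// Hypothetical helper, not part of JGNN: illustrates the aggregation rule only.
class LossAggregationSketch {
    // Each array holds the loss contributions gathered within one batch.
    static double epochLoss(List<double[]> contributionsByBatch) {
        if (contributionsByBatch.isEmpty())
            return 0;
        double total = 0;
        for (double[] batch : contributionsByBatch) {
            double batchSum = 0; // contributions are summed within a batch
            for (double contribution : batch)
                batchSum += contribution;
            total += batchSum;
        }
        return total / contributionsByBatch.size(); // batch sums are averaged
    }
}
```

      For two batches with contributions {1.0, 2.0} and {3.0}, both batch sums are 3.0, so epochLoss gives 3.0.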
    • setValidationLoss

      public ModelTraining setValidationLoss(Loss loss)
      Sets which Loss should be applied on validation data on each epoch. The loss's Loss.onEndEpoch(), Loss.onEndTraining(), and Loss.evaluate(Tensor, Tensor) methods are used. In the case where validation is split into multiple instances of batch data, which may be necessary for complex scenarios like graph classification, the loss value is averaged across those batches. The methods mentioned above are not used by losses employed in training.
      Parameters:
      loss - The loss's instance.
      Returns:
      The model training instance.
    • setOptimizer

      public ModelTraining setOptimizer(Optimizer optimizer)
      Sets an Optimizer instance to control parameter updates during training. If the provided optimizer is not an instance of BatchOptimizer, it is automatically wrapped in one. Training calls the batch optimizer's update method after every batch. Each batch could contain multiple instances of batch data. However, the total number of applied gradient updates is always equal to the value set by setNumBatches(int).
      Parameters:
      optimizer - The desired optimizer.
      Returns:
      this model training instance.
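      To illustrate why the number of applied updates matches the number of batches even when a batch holds several batch data instances, here is a minimal sketch of a wrapper that only accumulates gradients until the end of a batch. StepRuleSketch and BatchAccumulatorSketch are hypothetical stand-ins working on a single scalar parameter; JGNN's Optimizer and BatchOptimizer have different signatures.

```java
// Hypothetical types for illustration; JGNN's Optimizer and BatchOptimizer differ.
interface StepRuleSketch {
    double step(double parameter, double gradient);
}

class BatchAccumulatorSketch {
    private final StepRuleSketch base;
    private double accumulated = 0;
    int appliedUpdates = 0;

    BatchAccumulatorSketch(StepRuleSketch base) {
        this.base = base;
    }

    // Called for every batch data instance: stores the gradient without updating.
    void accumulate(double gradient) {
        accumulated += gradient;
    }

    // Called once per batch: applies one update with the summed gradient.
    double updateAll(double parameter) {
        double updated = base.step(parameter, accumulated);
        accumulated = 0;
        appliedUpdates++;
        return updated;
    }

    // Runs one epoch; each inner array holds the gradients of one batch.
    double runEpoch(double parameter, double[][] gradientsByBatch) {
        for (double[] batch : gradientsByBatch) {
            for (double gradient : batch)
                accumulate(gradient);
            parameter = updateAll(parameter);
        }
        return parameter;
    }
}
```

      With the rule (p, g) -> p - 0.5 * g and batches {1.0, 1.0} and {2.0}, three gradients are accumulated but only two updates are applied, one per batch.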
    • setNumBatches

      public ModelTraining setNumBatches(int numBatches)
      Sets the number of batches training data slices should be split into.
      Parameters:
      numBatches - The desired number of batches. Default is 1.
      Returns:
      this model training instance.
    • setParallelizedStochasticGradientDescent

      public ModelTraining setParallelizedStochasticGradientDescent(boolean paralellization)
      Sets whether the training strategy should reflect stochastic gradient descent by randomly sampling from the training data samples. If true, both this feature and thread-based parallelization are enabled. Parallelization uses JGNN's ThreadPool.
      Parameters:
      paralellization - A boolean value indicating whether this feature is enabled.
      Returns:
      this model training instance.
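      The combination of random sampling and thread-based batch computation can be sketched as follows. This is only an illustration under stated assumptions: it uses java.util.concurrent instead of JGNN's ThreadPool, a per-batch sum stands in for the real batch computation, and ParallelSgdSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Hypothetical sketch: uses java.util.concurrent instead of JGNN's ThreadPool.
class ParallelSgdSketch {
    // Shuffles the samples, then processes each batch on a worker thread.
    static double epochSum(List<Double> samples, int numBatches, long seed) {
        List<Double> shuffled = new ArrayList<>(samples);
        Collections.shuffle(shuffled, new Random(seed)); // random sampling order
        ExecutorService pool = Executors.newFixedThreadPool(numBatches);
        try {
            List<Future<Double>> pending = new ArrayList<>();
            for (int batch = 0; batch < numBatches; batch++) {
                int start = batch * shuffled.size() / numBatches;
                int end = (batch + 1) * shuffled.size() / numBatches;
                List<Double> view = shuffled.subList(start, end); // read-only use
                Callable<Double> task =
                        () -> view.stream().mapToDouble(Double::doubleValue).sum();
                pending.add(pool.submit(task));
            }
            double total = 0;
            for (Future<Double> result : pending)
                total += result.get(); // waits for that batch to finish
            return total;
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }
}
```

      The shuffled list is never modified while the workers run, which is what makes sharing the sublist views across threads safe in this sketch.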
    • setEpochs

      public ModelTraining setEpochs(int epochs)
      Sets the maximum number of epochs for which training runs. If no patience has been set, training runs for exactly this number of epochs.
      Parameters:
      epochs - The maximum number of epochs.
      Returns:
      this model training instance.
    • setPatience

      public ModelTraining setPatience(int patience)
      Sets the patience of the training strategy that performs early stopping. If training does not encounter a smaller validation loss for this number of epochs, it stops.
      Parameters:
      patience - The number of patience epochs. Default is Integer.MAX_VALUE to effectively disable this feature and let training always reach the maximum number of set epochs.
      Returns:
      this model training instance.
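      The interaction between setEpochs(int) and setPatience(int) can be sketched as below. This is a minimal illustration that replays a fixed sequence of validation losses; EarlyStoppingSketch is a hypothetical name, and the exact epoch at which JGNN stops may differ by one from this sketch.

```java
// Hypothetical sketch of patience-based early stopping; not JGNN source code.
class EarlyStoppingSketch {
    // Returns how many epochs run before training stops.
    static int epochsRun(double[] validationLosses, int maxEpochs, int patience) {
        double best = Double.POSITIVE_INFINITY;
        int sinceImprovement = 0;
        int ran = 0;
        int limit = Math.min(maxEpochs, validationLosses.length);
        for (int epoch = 0; epoch < limit; epoch++) {
            ran++;
            if (validationLosses[epoch] < best) {
                best = validationLosses[epoch]; // smaller loss found: reset the counter
                sinceImprovement = 0;
            } else {
                sinceImprovement++;
            }
            if (sinceImprovement >= patience)
                break; // no improvement for `patience` epochs
        }
        return ran;
    }
}
```

      With validation losses {5, 4, 3, 3.5, 3.2, 3.1, 2.0}, at most 7 epochs, and a patience of 2, the sketch stops after 5 epochs and never sees the final improvement. With a patience of Integer.MAX_VALUE it runs all 7 epochs.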
    • train

      public Model train(Model model, Matrix features, Matrix labels, Slice trainingSamples, Slice validationSamples)
      Deprecated.
      This method's full implementation has been moved to train(Model).
      This is a leftover method from an earlier version of JGNN's interface. For the time being, there is no good alternative, but it will be phased out.
    • onStartEpoch

      protected abstract void onStartEpoch(int epoch)
      Performs necessary training operations at the beginning of each epoch. These typically consist of dataset shuffling if setParallelizedStochasticGradientDescent(boolean) is enabled.
      Parameters:
      epoch - The epoch that now starts. Takes values 0,1,2,...,epochs-1, though early stopping may not reach the maximum number.
    • onEndTraining

      protected void onEndTraining()
      Performs any cleanup operations at the end of the train(Model) loop. This method is mostly used to "unlock" data insertions to the training process.
    • getBatchData

      protected abstract List<BatchData> getBatchData(int batch, int epoch)
      Returns a list of BatchData instances to be used for a specific batch and training epoch. This list may have only one entry if the whole batch can be organized into one pair of model inputs-outputs (e.g., in node classification). This method is overridden by classes extending ModelTraining to let them work as dataset loaders. Batch data may be created anew each time, though they are often transparent views of parts of the training data. Batch data generation may be parallelized, depending on whether setParallelizedStochasticGradientDescent(boolean) is enabled. Operations that take place once at the beginning of each epoch (e.g., data shuffling) reside in the onStartEpoch(int) method instead.
      Parameters:
      batch - The batch identifier. Takes values 0,1,2,...,numBatches-1.
      epoch - The epoch in which the batch is extracted. Takes values 0,1,2,...,epochs-1, though early stopping may not reach the maximum number.
      Returns:
      A list of batch data instances.
    • getValidationData

      protected abstract List<BatchData> getValidationData(int epoch)
      Returns a list of BatchData instances to be used for validation at a given training epoch. This list may have only one entry if the whole validation set can be organized into one pair of model inputs-outputs (e.g., in node classification). This method is overridden by classes extending ModelTraining to let them work as dataset loaders. Batch data may be created anew each time, though they are often transparent views of parts of the training data.
      Parameters:
      epoch - The epoch in which the batch is extracted. Takes values 0,1,2,...,epochs-1, though early stopping may not reach the maximum number.
      Returns:
      A list of batch data instances.
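      The division of labor between onStartEpoch(int), getBatchData(int, int), and getValidationData(int) can be sketched with a small loader. This is an illustration under assumptions: ShufflingLoaderSketch is a hypothetical class, lists of sample ids stand in for BatchData, and the contiguous slicing rule is one possible choice rather than JGNN's confirmed behavior.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical loader, not a JGNN class: lists of sample ids stand in for BatchData.
class ShufflingLoaderSketch {
    private final List<Integer> trainingIds;
    private final List<Integer> validationIds;
    private final int numBatches;
    private final Random random;

    ShufflingLoaderSketch(List<Integer> trainingIds, List<Integer> validationIds,
                          int numBatches, long seed) {
        this.trainingIds = new ArrayList<>(trainingIds);
        this.validationIds = List.copyOf(validationIds);
        this.numBatches = numBatches;
        this.random = new Random(seed);
    }

    // Epoch-level work such as shuffling lives here, not in getBatchData.
    void onStartEpoch(int epoch) {
        Collections.shuffle(trainingIds, random);
    }

    // One-entry list whose entry is a view (not a copy) of the batch's ids.
    List<List<Integer>> getBatchData(int batch, int epoch) {
        int start = batch * trainingIds.size() / numBatches;
        int end = (batch + 1) * trainingIds.size() / numBatches;
        return List.of(trainingIds.subList(start, end));
    }

    // Validation data are not shuffled and fit in a single entry here.
    List<List<Integer>> getValidationData(int epoch) {
        return List.of(validationIds);
    }
}
```

      Because the batches are views over the shuffled list, every training id appears in exactly one batch per epoch without copying any data.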
    • train

      public Model train(Model model)
      Trains the parameters of a Model based on current settings and the data.
      Parameters:
      model - The model instance to train.
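      As a rough picture of how the hooks documented above could fit together inside a training loop, consider the skeleton below. The call order is an assumption inferred from the method descriptions, not taken from JGNN's source, and TrainingLoopSketch and LoggingLoopSketch are hypothetical classes that only record events.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical skeleton, not JGNN source: records the order in which hooks run.
abstract class TrainingLoopSketch {
    protected int epochs = 2;
    protected int numBatches = 2;
    final List<String> events = new ArrayList<>();

    protected abstract void onStartEpoch(int epoch);
    protected abstract List<double[]> getBatchData(int batch, int epoch);
    protected abstract List<double[]> getValidationData(int epoch);

    protected void onEndTraining() {
        events.add("end");
    }

    void train() {
        for (int epoch = 0; epoch < epochs; epoch++) {
            onStartEpoch(epoch);
            for (int batch = 0; batch < numBatches; batch++) {
                for (double[] data : getBatchData(batch, epoch)) {
                    // a real loop would accumulate gradients for `data` here
                }
                events.add("update"); // one optimizer update per batch
            }
            getValidationData(epoch); // a real loop would evaluate the validation loss
        }
        onEndTraining();
    }
}

class LoggingLoopSketch extends TrainingLoopSketch {
    @Override
    protected void onStartEpoch(int epoch) {
        events.add("start " + epoch);
    }

    @Override
    protected List<double[]> getBatchData(int batch, int epoch) {
        events.add("batch " + batch);
        return List.of(new double[]{1.0});
    }

    @Override
    protected List<double[]> getValidationData(int epoch) {
        events.add("validate " + epoch);
        return List.of(new double[]{1.0});
    }
}
```

      Early stopping is left out of this skeleton to keep the hook order visible.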
    • configFrom

      public ModelTraining configFrom(ModelBuilder modelBuilder)
      Retrieves the learning rate (lr), epochs, batches, and patience parameters from the configurations of a ModelBuilder.
      Parameters:
      modelBuilder - The ModelBuilder whose configurations are read.
      Returns:
      this model training instance.
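      The idea of reading training settings from named configuration values can be sketched as follows. This is a minimal illustration with a plain Map in place of a ModelBuilder. ConfigFromSketch is a hypothetical class, the key names "lr", "epochs", "batches", and "patience" are assumed from the description above, and the lr and epochs defaults are arbitrary placeholders.

```java
import java.util.Map;

// Hypothetical sketch: a plain Map stands in for a ModelBuilder's configurations.
class ConfigFromSketch {
    double learningRate = 0.01;        // arbitrary placeholder default
    int epochs = 300;                  // arbitrary placeholder default
    int numBatches = 1;                // documented default of setNumBatches
    int patience = Integer.MAX_VALUE;  // documented default of setPatience

    // Overrides only the settings present in the map; key names are assumptions.
    ConfigFromSketch configFrom(Map<String, Double> config) {
        learningRate = config.getOrDefault("lr", learningRate);
        epochs = config.getOrDefault("epochs", (double) epochs).intValue();
        numBatches = config.getOrDefault("batches", (double) numBatches).intValue();
        patience = config.getOrDefault("patience", (double) patience).intValue();
        return this; // returning this allows chaining with other setters
    }
}
```

      Settings missing from the map keep their current values, so a call such as configFrom(Map.of("lr", 0.1, "epochs", 50.0)) changes only the learning rate and the number of epochs.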