Package mklab.JGNN.nn

Class NNOperation

java.lang.Object
mklab.JGNN.nn.NNOperation
Direct Known Subclasses:
Add, Attention, Complement, Concat, Dropout, Exp, From, Gather, Identity, L1, Log, LRelu, MatMul, Max, Mean, Multiply, NExp, Parameter, PRelu, Reduce, Relu, Repeat, Reshape, Sigmoid, SoftMax, Sort, Sum, Tanh, To, Transpose, Variable

public abstract class NNOperation extends Object
This class defines an abstract neural network operation with forward and backpropagation capabilities. Defined operations create execution trees based on input dependencies, which can then be run by Model instances to make predictions. The execution tree is built by calling the addInput(NNOperation) method. Each operation should be given the correct number of inputs; individual operations are responsible for checking compliance with this rule during forward passes.
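As a rough illustration of how input dependencies form an execution tree, the sketch below uses hypothetical stand-in classes (ToyOperation, ToyConstant, ToyAdd) that operate on plain doubles rather than Tensors. It is not the JGNN implementation; in particular, it assumes that addInput returns the node it was called on so that calls can be chained, and it checks the input count inside forward as the description above suggests.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an operation node; not the JGNN class.
abstract class ToyOperation {
    private final List<ToyOperation> inputs = new ArrayList<>();

    // Registers a dependency; assumed here to return this node for chaining.
    ToyOperation addInput(ToyOperation input) {
        inputs.add(input);
        return this;
    }

    // Evaluates all dependencies first, then this node.
    double evaluate() {
        List<Double> values = new ArrayList<>();
        for (ToyOperation input : inputs) {
            values.add(input.evaluate());
        }
        return forward(values);
    }

    protected abstract double forward(List<Double> inputs);
}

class ToyConstant extends ToyOperation {
    private final double value;

    ToyConstant(double value) {
        this.value = value;
    }

    @Override
    protected double forward(List<Double> inputs) {
        if (!inputs.isEmpty()) {
            throw new IllegalStateException("a constant takes no inputs");
        }
        return value;
    }
}

class ToyAdd extends ToyOperation {
    @Override
    protected double forward(List<Double> inputs) {
        // Each operation checks its own input count during the forward pass.
        if (inputs.size() != 2) {
            throw new IllegalStateException("add requires exactly 2 inputs");
        }
        return inputs.get(0) + inputs.get(1);
    }
}

class ToyTreeDemo {
    public static void main(String[] args) {
        ToyOperation sum = new ToyAdd()
                .addInput(new ToyConstant(2.0))
                .addInput(new ToyConstant(3.0));
        System.out.println(sum.evaluate()); // prints 5.0
    }
}
```

Adding only one input to ToyAdd would not fail at construction time; the error surfaces when evaluate() reaches its forward method.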
Operations are thread-safe in the sense that they store gradients for backward passes on different objects across different threads. This way, models can perform learning passes that are all synchronized when backpropagation eventually feeds Parameter updates to an Optimizer.
The internal state of an operation can be inspected with getPrediction(), which returns its last Tensor output (this output depends on the thread calling the operation), and getLastTapeError(), which returns the last gradient obtained through backpropagation.
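The idea that the last output depends on the calling thread can be sketched with Java's standard ThreadLocal. This page does not say how JGNN stores its per-thread data, so ThreadLocal is only an assumed stand-in, and PerThreadState and PerThreadDemo are hypothetical names.

```java
// Hypothetical holder that keeps one "last output" per calling thread.
class PerThreadState {
    private final ThreadLocal<Double> lastOutput = new ThreadLocal<>();

    void store(double value) {
        lastOutput.set(value);
    }

    // Returns null if the calling thread has not stored anything yet.
    Double last() {
        return lastOutput.get();
    }
}

class PerThreadDemo {
    // Runs a worker thread and reports what it saw before and after storing its own value.
    static Double[] observeFromWorker(PerThreadState state, double workerValue) {
        Double[] seen = new Double[2];
        Thread worker = new Thread(() -> {
            seen[0] = state.last();
            state.store(workerValue);
            seen[1] = state.last();
        });
        worker.start();
        try {
            worker.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for worker", e);
        }
        return seen;
    }

    public static void main(String[] args) {
        PerThreadState state = new PerThreadState();
        state.store(1.0);
        Double[] seen = observeFromWorker(state, 2.0);
        System.out.println(seen[0]);      // null: the worker starts with no stored value
        System.out.println(seen[1]);      // 2.0
        System.out.println(state.last()); // 1.0: the main thread's value is untouched
    }
}
```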
Author:
Emmanouil Krasanakis
  • Field Details

    • debugging

      public boolean debugging
  • Constructor Details

    • NNOperation

      protected NNOperation()
  • Method Details

    • data

      protected NNOperation.ThreadData data()
    • setDescription

      public void setDescription(String description)
    • getDescription

      public String getDescription()
    • describe

      public String describe()
      Retrieves a concise description of the operation that shows metadata and potential data descriptions processed by the current thread.
      Returns:
      A String description.
    • view

      public String view()
      Retrieves a string that views internal data being processed by the current thread, including gradients. This may
      Returns:
      A String view.
    • getInputs

      public ArrayList<NNOperation> getInputs()
      Retrieves a list of input operations within a model's execution graph.
      Returns:
      A list of NNOperations.
    • getOutputs

      public ArrayList<NNOperation> getOutputs()
      Retrieves a list of output operations within a model's execution graph.
      Returns:
      A list of NNOperations.
    • isConstant

      public boolean isConstant()
      Checks whether the operation yields a constant output, so that propagation does not try to compute partial derivatives for it.
      Returns:
      A boolean value.
    • isCachable

      public boolean isCachable()
      Checks whether the operation's output should be cached given that it is a constant. This returns false only for randomized components that yield different outputs for the same inputs, such as dropouts.
      Returns:
      A boolean value.
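One way to picture how the isConstant and isCachable flags could interact is the sketch below: a node reuses its output only when it is both constant and cachable, and recomputes otherwise. CachingNode and its computations counter are hypothetical names introduced for illustration, and the actual caching logic in JGNN may differ.

```java
import java.util.function.DoubleSupplier;

// Hypothetical node that reuses its output only when it is both constant and cachable.
class CachingNode {
    private final DoubleSupplier computation;
    private final boolean constant;
    private final boolean cachable;
    private Double cached;   // null until the first cached evaluation
    int computations = 0;    // counts real evaluations, for illustration only

    CachingNode(DoubleSupplier computation, boolean constant, boolean cachable) {
        this.computation = computation;
        this.constant = constant;
        this.cachable = cachable;
    }

    double output() {
        if (constant && cachable && cached != null) {
            return cached; // reuse the stored value
        }
        computations++;
        double value = computation.getAsDouble();
        if (constant && cachable) {
            cached = value;
        }
        return value;
    }
}

class CachingDemo {
    public static void main(String[] args) {
        CachingNode fixed = new CachingNode(() -> 7.0, true, true);
        fixed.output();
        fixed.output();
        System.out.println(fixed.computations); // prints 1

        // A dropout-like node is recomputed on every call.
        CachingNode randomized = new CachingNode(Math::random, true, false);
        randomized.output();
        randomized.output();
        System.out.println(randomized.computations); // prints 2
    }
}
```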
    • getNonLinearity

      public double getNonLinearity(int inputId, double inputMass, double outputNonLinearity)
      Retrieves the degree of non-linearity of the operation to be used by VariancePreservingInitializer. Default is one for operations like addition, multiplication, and matrix multiplication, and is different only for activation functions.
      Parameters:
      inputId - The input for which the non-linearity is calculated.
      inputMass - The fraction of (matrix) parameters affecting the calculation coming from the respective input.
      outputNonLinearity - The output's non-linearity gain.
      Returns:
      A double describing the non-linearity.
    • clearPrediction

      public final void clearPrediction()
    • addInput

      public NNOperation addInput(NNOperation inputComponent)
    • getLastTapeError

      public final Tensor getLastTapeError()
    • getPrediction

      public final Tensor getPrediction()
    • isOutputNeededForDerivative

      protected boolean isOutputNeededForDerivative()
    • isInputNeededForDerivative

      protected boolean isInputNeededForDerivative(int inputId)
    • runPrediction

      public final Tensor runPrediction()
    • run

      public final Tensor run(List<Tensor> inputs)
      Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This is just a way to replicate the operation at the tensor level; it neither affects nor is affected by any dependent inputs added with addInput(NNOperation).
      Parameters:
      inputs - A list of input tensors needed by the operation.
      Returns:
      A Tensor with the operation's outcome.
    • run

      public final Tensor run(Tensor... inputs)
      Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This is just a way to replicate the operation at the tensor level; it neither affects nor is affected by any dependent inputs added with addInput(NNOperation).
      Parameters:
      inputs - The input tensors needed by the operation, passed as separate arguments.
      Returns:
      A Tensor with the operation's outcome.
    • trainParameters

      protected void trainParameters(Optimizer optimizer, Tensor error)
    • forward

      protected abstract Tensor forward(List<Tensor> inputs)
    • partial

      protected abstract Tensor partial(int inputId, List<Tensor> inputs, Tensor output, Tensor error)
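The forward and partial signatures are undocumented here, so the following is only a plausible reading: forward computes the output from the inputs, and partial returns the gradient with respect to the input at position inputId, given the gradient arriving at the output (error). The sketch uses hypothetical scalar classes (ScalarOp, ScalarMultiply) instead of Tensors and applies the chain rule to a two-input product.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical scalar operation mirroring the forward/partial split.
abstract class ScalarOp {
    abstract double forward(List<Double> inputs);

    // Gradient with respect to input inputId, given the gradient at the output.
    abstract double partial(int inputId, List<Double> inputs, double output, double error);
}

class ScalarMultiply extends ScalarOp {
    @Override
    double forward(List<Double> inputs) {
        return inputs.get(0) * inputs.get(1);
    }

    @Override
    double partial(int inputId, List<Double> inputs, double output, double error) {
        if (inputId != 0 && inputId != 1) {
            throw new IllegalArgumentException("multiply has inputs 0 and 1 only");
        }
        // d(a*b)/da = b and d(a*b)/db = a, scaled by the incoming error (chain rule).
        return error * inputs.get(1 - inputId);
    }
}

class PartialDemo {
    public static void main(String[] args) {
        ScalarOp multiply = new ScalarMultiply();
        List<Double> inputs = Arrays.asList(3.0, 4.0);
        double output = multiply.forward(inputs);
        System.out.println(output);                                   // prints 12.0
        System.out.println(multiply.partial(0, inputs, output, 1.0)); // prints 4.0
        System.out.println(multiply.partial(1, inputs, output, 1.0)); // prints 3.0
    }
}
```

The product rule never uses the output argument, which is the kind of case a hook such as isOutputNeededForDerivative() could plausibly report; that connection is an inference from the method name, not something this page states.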
    • getSimpleDescription

      public String getSimpleDescription()
      Provides a simple description to show when drawing .dot format diagrams.
      Returns:
      A string description, usually the component's class name.
    • runPredictionAndAutosize

      public Tensor runPredictionAndAutosize()
    • autosize

      protected void autosize(ArrayList<Tensor> lastInputs)