Package mklab.JGNN.nn
Class NNOperation
java.lang.Object
mklab.JGNN.nn.NNOperation
- Direct Known Subclasses:
Add, Attention, Complement, Concat, Dropout, Exp, From, Gather, Identity, L1, Log, LRelu, MatMul, Max, Mean, Multiply, NExp, Parameter, PRelu, Reduce, Relu, Repeat, Reshape, Sigmoid, SoftMax, Sort, Sum, Tanh, To, Transpose, Variable
This class defines an abstract neural network operation with forward and
backpropagation capabilities. Defined operations create execution trees based
on input dependencies, which can then be run by Model instances to make
predictions. The execution tree is created with the addInput(NNOperation)
method. The correct number of inputs should be added to each operation, and
compliance with this rule must be checked by individual operations during
forward passes.
Operations are thread-safe in the sense that they store gradients for backward
passes on different objects across different threads. This way, models can
perform learning passes that are all synchronized when backpropagation
eventually feeds Parameter updates to an Optimizer.
The internal state of operations can be obtained with getPrediction(), which
returns their last Tensor output (this output depends on the thread calling the
operation), and getLastTapeError(), which returns the last gradient obtained
through backpropagation.
- Author:
- Emmanouil Krasanakis
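The execution-tree idea above can be illustrated with a minimal, self-contained sketch. It assumes scalar double values instead of JGNN Tensor objects, and every class in it (ToyOperation, ToyConstant, ToyBinary, ToyAdd, ToyMultiply, ToyOperationDemo) is hypothetical rather than part of the JGNN API. It shows inputs being wired with an addInput-style method, each operation checking its own input count during the forward pass, and the last output being stored per calling thread; a ThreadLocal is used here as an assumed storage mechanism.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, scalar-valued stand-in for NNOperation (JGNN itself operates on Tensor objects).
abstract class ToyOperation {
    private final List<ToyOperation> inputs = new ArrayList<>();
    private final List<ToyOperation> outputs = new ArrayList<>();
    // Each thread sees only the output that it computed itself.
    private final ThreadLocal<Double> lastOutput = new ThreadLocal<>();

    // Registers a dependency in both directions and returns this operation for chaining.
    ToyOperation addInput(ToyOperation input) {
        inputs.add(input);
        input.outputs.add(this);
        return this;
    }

    List<ToyOperation> getInputs() { return inputs; }

    List<ToyOperation> getOutputs() { return outputs; }

    // Computes this operation's output from already evaluated input values.
    protected abstract double forward(List<Double> inputValues);

    // Evaluates all inputs recursively, then this operation, storing the result for the calling thread.
    final double runPrediction() {
        List<Double> values = new ArrayList<>();
        for (ToyOperation input : inputs) {
            values.add(input.runPrediction());
        }
        double output = forward(values);
        lastOutput.set(output);
        return output;
    }

    // Returns the last output computed by the calling thread, or null if it never ran this operation.
    final Double getPrediction() {
        return lastOutput.get();
    }
}

class ToyConstant extends ToyOperation {
    private final double value;

    ToyConstant(double value) { this.value = value; }

    @Override
    protected double forward(List<Double> inputValues) {
        if (!inputValues.isEmpty()) {
            throw new IllegalStateException("A constant takes no inputs");
        }
        return value;
    }
}

abstract class ToyBinary extends ToyOperation {
    // Each operation checks its own input count during the forward pass.
    @Override
    protected final double forward(List<Double> inputValues) {
        if (inputValues.size() != 2) {
            throw new IllegalStateException(getClass().getSimpleName() + " expects 2 inputs");
        }
        return apply(inputValues.get(0), inputValues.get(1));
    }

    protected abstract double apply(double left, double right);
}

class ToyAdd extends ToyBinary {
    @Override
    protected double apply(double left, double right) { return left + right; }
}

class ToyMultiply extends ToyBinary {
    @Override
    protected double apply(double left, double right) { return left * right; }
}

class ToyOperationDemo {
    public static void main(String[] args) throws InterruptedException {
        // Execution tree for (2 + 3) * 4.
        ToyOperation sum = new ToyAdd().addInput(new ToyConstant(2)).addInput(new ToyConstant(3));
        ToyOperation product = new ToyMultiply().addInput(sum).addInput(new ToyConstant(4));
        System.out.println(product.runPrediction()); // prints 20.0
        // Another thread has not run the tree, so it sees no stored prediction.
        Thread other = new Thread(() -> System.out.println(product.getPrediction())); // prints null
        other.start();
        other.join();
    }
}
```

Because the toy recursion re-evaluates shared inputs, a graph that reuses a sub-expression computes it repeatedly; the sketch does not attempt to reproduce how Model schedules or caches evaluations.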
-
Nested Class Summary
-
Field Summary
-
Constructor Summary
-
Method Summary
- addInput(NNOperation inputComponent)
- protected void autosize
- final void clearPrediction()
- protected NNOperation.ThreadData data()
- String describe(): Retrieves a concise description of the operation that shows metadata and potential data descriptions processed by the current thread.
- protected abstract Tensor forward
- getDescription
- getInputs: Retrieves a list of input operations within a model's execution graph.
- final Tensor getLastTapeError()
- double getNonLinearity(int inputId, double inputMass, double outputNonLinearity): Retrieves the degree of non-linearity of the operation to be used by VariancePreservingInitializer.
- getOutputs: Retrieves a list of output operations within a model's execution graph.
- final Tensor getPrediction()
- String getSimpleDescription: Provides a simple description to show when drawing .dot format diagrams.
- boolean isCachable(): Checks whether the operation's output should be cached given that it is a constant.
- boolean isConstant(): Checks whether the operation yields a constant output, so that propagation does not try to compute partial derivatives for it.
- protected boolean isInputNeededForDerivative(int inputId)
- protected boolean isOutputNeededForDerivative()
- protected abstract Tensor partial
- final Tensor run (two overloads): Performs a forward pass in the operation without inducing any kind of learning or storing the outcome.
- final Tensor runPrediction
- runPredictionAndAutosize
- void setDescription(String description)
- protected void trainParameters(Optimizer optimizer, Tensor error)
- String view(): Retrieves a string that views internal data being processed by the current thread, including gradients.
-
Field Details
-
debugging
public boolean debugging
-
-
Constructor Details
-
NNOperation
protected NNOperation()
-
-
Method Details
-
data
-
setDescription
-
getDescription
-
describe
Retrieves a concise description of the operation that shows metadata and potential data descriptions processed by the current thread.
- Returns:
- A String description.
- See Also:
-
view
Retrieves a string that views internal data being processed by the current thread, including gradients. This may
- Returns:
- A String view.
- See Also:
-
getInputs
Retrieves a list of input operations within a model's execution graph.
- Returns:
- A list of NNOperations.
-
getOutputs
Retrieves a list of output operations within a model's execution graph.
- Returns:
- A list of NNOperations.
-
isConstant
public boolean isConstant()
Checks whether the operation yields a constant output, so that propagation does not try to compute partial derivatives for it.
- Returns:
- A boolean value.
-
isCachable
public boolean isCachable()
Checks whether the operation's output should be cached given that it is a constant. This returns false only for randomized components that yield different outputs across runs even for the same inputs, such as dropouts.
- Returns:
- A boolean value.
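A minimal sketch of the caching rule described for isConstant() and isCachable(), assuming scalar outputs. CachingOperation, FixedValue, RandomMask, and CachingDemo are hypothetical names; this is not JGNN's implementation, only an illustration of why a randomized, dropout-like component should opt out of caching even when it needs no derivatives.

```java
import java.util.Random;

// Hypothetical sketch of caching constant outputs; not JGNN's implementation.
abstract class CachingOperation {
    private Double cached = null;
    int computeCalls = 0; // number of times compute() actually ran

    // Constant outputs need no partial derivatives during backpropagation.
    boolean isConstant() { return false; }

    // Only randomized components (e.g., dropout) should opt out of caching.
    boolean isCachable() { return true; }

    protected abstract double compute();

    private double evaluate() {
        computeCalls++;
        return compute();
    }

    // Reuses the cached value only when the output is both constant and cachable.
    final double output() {
        if (isConstant() && isCachable()) {
            if (cached == null) {
                cached = evaluate();
            }
            return cached;
        }
        return evaluate();
    }
}

class FixedValue extends CachingOperation {
    @Override boolean isConstant() { return true; }
    @Override protected double compute() { return 1.5; }
}

// Dropout-like component over a constant input: no gradients are needed,
// yet each call draws a fresh random mask, so its output must not be cached.
class RandomMask extends CachingOperation {
    private final Random random = new Random(42);
    @Override boolean isConstant() { return true; }
    @Override boolean isCachable() { return false; }
    @Override protected double compute() { return random.nextBoolean() ? 1.5 : 0.0; }
}

class CachingDemo {
    public static void main(String[] args) {
        FixedValue fixed = new FixedValue();
        RandomMask mask = new RandomMask();
        for (int i = 0; i < 3; i++) {
            fixed.output();
            mask.output();
        }
        System.out.println(fixed.computeCalls + " " + mask.computeCalls); // prints "1 3"
    }
}
```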
-
getNonLinearity
public double getNonLinearity(int inputId, double inputMass, double outputNonLinearity)
Retrieves the degree of non-linearity of the operation to be used by VariancePreservingInitializer. Default is one for operations like addition, multiplication, and matrix multiplication, and is different only for activation functions.
- Parameters:
inputId - The input for which the non-linearity is calculated.
inputMass - The fraction of (matrix) parameters affecting the calculation coming from the respective input.
outputNonLinearity - The output's non-linearity gain.
- Returns:
- A double describing the non-linearity.
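The sketch below illustrates, under stated assumptions, how a per-operation gain could feed a variance-preserving initialization. The GainAware interface, LinearOp, ReluLikeOp, and GainInitSketch are hypothetical; the He-style rule std = gain / sqrt(fanIn) and the ReLU gain of sqrt(2) are borrowed from common practice and are not taken from VariancePreservingInitializer, whose exact formula this page does not state. The default gain of one mirrors the documented behavior for non-activation operations.

```java
import java.util.Random;

// Hypothetical hook mirroring getNonLinearity; not JGNN's classes.
interface GainAware {
    // Default of one, as documented for addition, multiplication, and matrix multiplication.
    default double getNonLinearity(int inputId, double inputMass, double outputNonLinearity) {
        return 1.0;
    }
}

class LinearOp implements GainAware {
}

class ReluLikeOp implements GainAware {
    // Assumption: the He et al. gain of sqrt(2) for ReLU; JGNN's actual values may differ.
    @Override
    public double getNonLinearity(int inputId, double inputMass, double outputNonLinearity) {
        return Math.sqrt(2.0);
    }
}

class GainInitSketch {
    // He-style standard deviation: gain / sqrt(fanIn). The arguments assume input 0,
    // all parameters coming from that input (mass 1), and a unit output gain.
    static double std(GainAware op, int fanIn) {
        return op.getNonLinearity(0, 1.0, 1.0) / Math.sqrt(fanIn);
    }

    // Draws a fanIn x fanOut weight matrix from a zero-mean Gaussian with that deviation.
    static double[][] initialize(int fanIn, int fanOut, GainAware op, Random random) {
        double sigma = std(op, fanIn);
        double[][] weights = new double[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++) {
            for (int j = 0; j < fanOut; j++) {
                weights[i][j] = random.nextGaussian() * sigma;
            }
        }
        return weights;
    }

    public static void main(String[] args) {
        System.out.println(std(new LinearOp(), 4));   // prints 0.5
        System.out.println(std(new ReluLikeOp(), 2)); // prints 1.0
        double[][] weights = initialize(16, 8, new ReluLikeOp(), new Random(7));
        System.out.println(weights.length + "x" + weights[0].length); // prints 16x8
    }
}
```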
-
clearPrediction
public final void clearPrediction()
-
addInput
-
getLastTapeError
-
getPrediction
-
isOutputNeededForDerivative
protected boolean isOutputNeededForDerivative()
-
isInputNeededForDerivative
protected boolean isInputNeededForDerivative(int inputId)
-
runPrediction
-
run
Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This is just a way to replicate the operation at the tensor level; it neither affects nor is affected by any dependent inputs added with addInput(mklab.JGNN.nn.NNOperation).
- Parameters:
inputs - A list of input tensors needed by the operation.
- Returns:
- A Tensor with the operation's outcome.
-
run
Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This is just a way to replicate the operation at the tensor level; it neither affects nor is affected by any dependent inputs added with addInput(mklab.JGNN.nn.NNOperation).
- Parameters:
inputs - A list of input tensors needed by the operation.
- Returns:
- A Tensor with the operation's outcome.
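The distinction between graph-level and tensor-level forward passes can be sketched with scalar values and hypothetical classes (RunnableOp, ConstOp, SumOp, RunDemo). The two run overloads taking a list or varargs are an assumption about how the two documented variants might differ; the point shown is that run uses only the values passed to it, ignores inputs registered through addInput, and stores nothing.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, scalar-valued sketch; method names echo NNOperation but this is not JGNN code.
abstract class RunnableOp {
    private final List<RunnableOp> inputs = new ArrayList<>();
    private Double lastPrediction = null;

    RunnableOp addInput(RunnableOp input) {
        inputs.add(input);
        return this;
    }

    protected abstract double forward(List<Double> values);

    // Graph-level pass: evaluates the registered inputs and stores the outcome.
    final double runPrediction() {
        List<Double> values = new ArrayList<>();
        for (RunnableOp input : inputs) {
            values.add(input.runPrediction());
        }
        lastPrediction = forward(values);
        return lastPrediction;
    }

    // Tensor-level replica: applies forward() to the given values only.
    // Registered inputs are ignored and nothing is stored.
    final double run(List<Double> values) {
        return forward(new ArrayList<>(values));
    }

    // Varargs convenience overload delegating to run(List).
    final double run(double... values) {
        List<Double> boxed = new ArrayList<>();
        for (double value : values) {
            boxed.add(value);
        }
        return run(boxed);
    }

    final Double getPrediction() {
        return lastPrediction;
    }
}

class ConstOp extends RunnableOp {
    private final double value;

    ConstOp(double value) { this.value = value; }

    @Override
    protected double forward(List<Double> values) { return value; }
}

class SumOp extends RunnableOp {
    @Override
    protected double forward(List<Double> values) {
        double total = 0;
        for (double value : values) {
            total += value;
        }
        return total;
    }
}

class RunDemo {
    public static void main(String[] args) {
        RunnableOp sum = new SumOp().addInput(new ConstOp(1)).addInput(new ConstOp(2));
        System.out.println(sum.run(10, 20));     // prints 30.0 (registered inputs ignored)
        System.out.println(sum.getPrediction()); // prints null (run stored nothing)
        System.out.println(sum.runPrediction()); // prints 3.0 (evaluated through the graph)
    }
}
```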
-
trainParameters
-
forward
-
partial
-
getSimpleDescription
Provides a simple description to show when drawing .dot format diagrams.
- Returns:
- A string description, usually the component's class name.
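As an illustration of how simple descriptions might label a .dot diagram, the hypothetical DotSketch class below traverses an operation graph and emits Graphviz DOT text, labeling each node with its class name as the documented default suggests. The node classes and the traversal are assumptions for illustration, not JGNN's diagram code.

```java
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of drawing an operation graph in Graphviz .dot format; not JGNN code.
class DotSketch {
    static class Node {
        final List<Node> inputs = new ArrayList<>();

        Node addInput(Node input) {
            inputs.add(input);
            return this;
        }

        // Like the documented default, the label is the component's class name.
        String getSimpleDescription() {
            return getClass().getSimpleName();
        }
    }

    static class Input extends Node {}
    static class MatMul extends Node {}
    static class Relu extends Node {}

    // Emits one labeled DOT node per operation and one edge from each input to its consumer.
    static String toDot(Node output) {
        StringBuilder dot = new StringBuilder("digraph operations {\n");
        visit(output, new IdentityHashMap<>(), dot);
        return dot.append("}\n").toString();
    }

    // Depth-first traversal; the identity map ensures shared operations are drawn once.
    private static String visit(Node node, Map<Node, String> ids, StringBuilder dot) {
        String id = ids.get(node);
        if (id != null) {
            return id;
        }
        id = "n" + ids.size();
        ids.put(node, id);
        dot.append("  ").append(id).append(" [label=\"").append(node.getSimpleDescription()).append("\"];\n");
        for (Node input : node.inputs) {
            String inputId = visit(input, ids, dot);
            dot.append("  ").append(inputId).append(" -> ").append(id).append(";\n");
        }
        return id;
    }

    public static void main(String[] args) {
        Node x = new Input();
        Node w = new Input();
        Node relu = new Relu().addInput(new MatMul().addInput(x).addInput(w));
        System.out.print(toDot(relu));
    }
}
```

Running main prints a digraph with one node per operation and the edges n2 -> n1, n3 -> n1, and n1 -> n0, which Graphviz's dot tool can render.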
-
runPredictionAndAutosize
-
autosize
-