Package mklab.JGNN.nn
Class NNOperation
java.lang.Object
mklab.JGNN.nn.NNOperation
- Direct Known Subclasses:
Add, Attention, Complement, Concat, Dropout, Exp, From, Gather, Identity, L1, Log, LRelu, MatMul, Max, Mean, Multiply, NExp, Parameter, PRelu, Reduce, Relu, Repeat, Reshape, Sigmoid, SoftMax, Sort, Sum, Tanh, To, Transpose, Variable
This class defines an abstract neural network operation with forward and
backpropagation capabilities. Defined operations create execution trees based
on input dependencies, which can then be run by
Model instances to
make predictions. Creating the execution tree can be done by using the
addInput(NNOperation) method. The correct number of inputs should be
added to each operation. Individual operations are responsible for checking
compliance with this rule during forward passes. Operations are thread-safe in the sense that they store gradients for backward passes on different objects across different threads. This way, models can perform learning passes that are all synchronized when backpropagation eventually feeds
Parameter updates to an
Optimizer. The internal state of operations can be inspected with
getPrediction(), which returns their last Tensor output (this
output depends on the thread calling the operation), and
getLastTapeError(), which returns the last gradient obtained through
backpropagation.
- Author:
- Emmanouil Krasanakis
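The execution-tree idea from the class description can be illustrated with a minimal, self-contained sketch. It does not use JGNN: `ToyOp`, `ToyInput`, `ToyAdd`, and `ToyMultiply` are hypothetical stand-ins that work on `double` values instead of `Tensor`, and only the names `addInput`, `forward`, and `runPrediction` mirror methods listed on this page. Returning `this` from `addInput` is an assumption made here to allow chaining. The input-count check sits in `forward`, following the rule that each operation verifies its own inputs during the forward pass.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an operation node; NOT the JGNN class.
abstract class ToyOp {
    private final List<ToyOp> inputs = new ArrayList<>();

    // Registers a dependency; returning this node is an assumption that allows chaining.
    ToyOp addInput(ToyOp input) {
        inputs.add(input);
        return this;
    }

    // Evaluates all inputs recursively, then applies this node's own computation.
    final double runPrediction() {
        List<Double> values = new ArrayList<>();
        for (ToyOp input : inputs) {
            values.add(input.runPrediction());
        }
        return forward(values);
    }

    // Concrete operations check their own input count here.
    protected abstract double forward(List<Double> values);
}

class ToyInput extends ToyOp {
    private final double value;

    ToyInput(double value) {
        this.value = value;
    }

    @Override
    protected double forward(List<Double> values) {
        if (!values.isEmpty()) {
            throw new IllegalStateException("an input node takes no inputs");
        }
        return value;
    }
}

class ToyAdd extends ToyOp {
    @Override
    protected double forward(List<Double> values) {
        if (values.size() != 2) {
            throw new IllegalStateException("addition takes exactly 2 inputs");
        }
        return values.get(0) + values.get(1);
    }
}

class ToyMultiply extends ToyOp {
    @Override
    protected double forward(List<Double> values) {
        if (values.size() != 2) {
            throw new IllegalStateException("multiplication takes exactly 2 inputs");
        }
        return values.get(0) * values.get(1);
    }
}

class ToyTreeDemo {
    // Builds the tree for (2 + 3) * 4 and evaluates it from the root.
    static double evaluate() {
        ToyOp sum = new ToyAdd().addInput(new ToyInput(2)).addInput(new ToyInput(3));
        ToyOp product = new ToyMultiply().addInput(sum).addInput(new ToyInput(4));
        return product.runPrediction();
    }

    public static void main(String[] args) {
        System.out.println(evaluate());
    }
}
```

In this sketch a node with the wrong number of inputs can still be constructed; the mistake only surfaces as an exception when the forward pass reaches it.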
-
Method Summary
Modifier and Type / Method / Description
- addInput(NNOperation inputComponent)
- protected void autosize
- final void clearPrediction()
- protected NNOperation.ThreadData data()
- describe(): Retrieves a concise description of the operation that shows metadata and potential data descriptions processed by the current thread.
- protected abstract Tensor forward
- getDescription
- getInputs: Retrieves a list of input operations within a model's execution graph.
- final Tensor getLastTapeError()
- double getNonLinearity(int inputId, double inputMass, double outputNonLinearity): Retrieves the degree of non-linearity of the operation to be used by VariancePreservingInitializer.
- getOutputs: Retrieves a list of output operations within a model's execution graph.
- final Tensor getPrediction()
- getSimpleDescription: Provides a simple description to show when drawing .dot format diagrams.
- boolean isCachable(): Checks whether the operation's output should be cached given that it is a constant.
- boolean isConstant(): Checks whether the operation yields a constant output, so that propagation does not try to compute partial derivatives for it.
- protected boolean isInputNeededForDerivative(int inputId)
- protected boolean isOutputNeededForDerivative()
- protected abstract Tensor partial
- final Tensor run: Performs a forward pass in the operation without inducing any kind of learning or storing the outcome.
- final Tensor run: Performs a forward pass in the operation without inducing any kind of learning or storing the outcome.
- final Tensor runPrediction
- runPredictionAndAutosize
- void setDescription(String description)
- protected void trainParameters(Optimizer optimizer, Tensor error)
- view(): Retrieves a string that views internal data being processed by the current thread, including gradients.
-
Field Details
-
debugging
public boolean debugging
-
-
Constructor Details
-
NNOperation
protected NNOperation()
-
-
Method Details
-
data
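The class description states that operations keep their state on different objects across different threads. As a rough illustration of that idea only, the sketch below keeps one state object per calling thread with `ThreadLocal`. It does not use JGNN: `ToyThreadData` and `ToyThreadAwareOp` are hypothetical names, and whether JGNN uses `ThreadLocal` internally is not stated on this page.

```java
// Hypothetical per-thread state holder; NOT the JGNN ThreadData class.
class ToyThreadData {
    Double lastOutput; // last forward result seen by the owning thread (gradients could be kept the same way)
}

class ToyThreadAwareOp {
    // One state object per calling thread, created lazily on first access.
    private final ThreadLocal<ToyThreadData> perThread = ThreadLocal.withInitial(ToyThreadData::new);

    // Returns the state object that belongs to the calling thread.
    ToyThreadData data() {
        return perThread.get();
    }

    // Doubles the input and remembers the result for the calling thread only.
    double forward(double x) {
        double y = 2 * x;
        data().lastOutput = y;
        return y;
    }

    // Returns the calling thread's last output, or null if that thread never ran forward().
    Double getPrediction() {
        return data().lastOutput;
    }
}

class ToyThreadDemo {
    // Runs forward passes on two threads and reports what each thread sees as its own last output.
    static double[] predictionsSeenByTwoThreads() {
        ToyThreadAwareOp op = new ToyThreadAwareOp();
        double[] seen = new double[2];
        Thread first = new Thread(() -> { op.forward(1); seen[0] = op.getPrediction(); });
        Thread second = new Thread(() -> { op.forward(5); seen[1] = op.getPrediction(); });
        first.start();
        second.start();
        try {
            first.join();
            second.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return seen;
    }
}
```

Each thread reads back only the value it produced itself, which matches the remark that the output of `getPrediction()` depends on the thread calling the operation.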
-
setDescription
-
getDescription
-
describe
Retrieves a concise description of the operation that shows metadata and potential data descriptions processed by the current thread.
- Returns:
- A String description.
- See Also:
-
view
Retrieves a string that views internal data being processed by the current thread, including gradients. This may
- Returns:
- A String view.
- See Also:
-
getInputs
Retrieves a list of input operations within a model's execution graph.
- Returns:
- A list of NNOperations.
-
getOutputs
Retrieves a list of output operations within a model's execution graph.
- Returns:
- A list of NNOperations.
-
isConstant
public boolean isConstant()
Checks whether the operation yields a constant output, so that propagation does not try to compute partial derivatives for it.
- Returns:
- A boolean value.
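A minimal sketch of why a constant flag is useful during backpropagation: partial derivatives are computed only for inputs that are not constant. The code does not use JGNN; `ToyLeaf` and `ToyProduct` are hypothetical classes on `double` values, and the `partialsComputed` counter exists only to make the skipped work visible.

```java
// Hypothetical leaf value that may be flagged as constant; NOT a JGNN class.
class ToyLeaf {
    final double value;
    private final boolean constant;
    double gradient = 0; // accumulated only when the leaf is not constant

    ToyLeaf(double value, boolean constant) {
        this.value = value;
        this.constant = constant;
    }

    boolean isConstant() {
        return constant;
    }
}

// Hypothetical product node y = left * right.
class ToyProduct {
    private final ToyLeaf left;
    private final ToyLeaf right;
    int partialsComputed = 0; // counts how many partial derivatives were evaluated

    ToyProduct(ToyLeaf left, ToyLeaf right) {
        this.left = left;
        this.right = right;
    }

    double forward() {
        return left.value * right.value;
    }

    // Chain rule for a product: error times the value of the other input.
    private double partial(int inputId, double error) {
        partialsComputed++;
        return error * (inputId == 0 ? right.value : left.value);
    }

    // Propagates the error only towards inputs that are not constant.
    void backpropagate(double error) {
        if (!left.isConstant()) {
            left.gradient += partial(0, error);
        }
        if (!right.isConstant()) {
            right.gradient += partial(1, error);
        }
    }
}
```

With a trainable weight of 3 and a constant data value of 4, a backward pass with error 1 evaluates a single partial derivative and leaves the constant's gradient untouched.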
-
isCachable
public boolean isCachable()
Checks whether the operation's output should be cached given that it is a constant. This returns false only for randomized components that yield different outputs from different inputs, such as dropouts.
- Returns:
- A boolean value.
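The interplay between a constant flag and a cachable flag can be sketched as follows. This is an illustration under assumptions, not JGNN code: `ToyCachedOp`, `ToyConstant`, and `ToyRandomMask` are hypothetical, every node here is treated as constant, and the `evaluations` counter is added only to show when recomputation happens.

```java
import java.util.Random;

// Hypothetical node whose output may be reused between forward passes; NOT a JGNN class.
abstract class ToyCachedOp {
    private Double cached;   // remembered output, if any
    int evaluations = 0;     // counts how often compute() actually ran

    // Assumption of this sketch: every node has a fixed, input-free output.
    boolean isConstant() {
        return true;
    }

    boolean isCachable() {
        return true;
    }

    protected abstract double compute();

    // Reuses the stored output only when the node is both constant and cachable.
    final double runPrediction() {
        boolean reuse = isConstant() && isCachable();
        if (reuse && cached != null) {
            return cached;
        }
        evaluations++;
        double output = compute();
        if (reuse) {
            cached = output;
        }
        return output;
    }
}

class ToyConstant extends ToyCachedOp {
    @Override
    protected double compute() {
        return 42.0;
    }
}

// Randomized component in the spirit of dropout: it must be recomputed on every pass.
class ToyRandomMask extends ToyCachedOp {
    private final Random random = new Random(0);

    @Override
    boolean isCachable() {
        return false;
    }

    @Override
    protected double compute() {
        return random.nextBoolean() ? 1.0 : 0.0;
    }
}
```

Running `ToyConstant` twice evaluates it once, while `ToyRandomMask` is evaluated on every call because caching would freeze its random outcome.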
-
getNonLinearity
public double getNonLinearity(int inputId, double inputMass, double outputNonLinearity)
Retrieves the degree of non-linearity of the operation to be used by VariancePreservingInitializer. The default is one for operations like addition, multiplication, and matrix multiplication, and differs only for activation functions.
- Parameters:
- inputId - The input for which the non-linearity is calculated.
- inputMass - The fraction of (matrix) parameters affecting the calculation coming from the respective input.
- outputNonLinearity - The output's non-linearity gain.
- Returns:
- A double describing the non-linearity.
-
clearPrediction
public final void clearPrediction()
-
addInput
-
getLastTapeError
-
getPrediction
-
isOutputNeededForDerivative
protected boolean isOutputNeededForDerivative()
-
isInputNeededForDerivative
protected boolean isInputNeededForDerivative(int inputId)
-
runPrediction
-
run
Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This only replicates the operation at the tensor level; it neither affects nor is affected by any dependent inputs registered through addInput(mklab.JGNN.nn.NNOperation).
- Parameters:
- inputs - A list of input tensors needed by the operation.
- Returns:
- A Tensor with the operation's outcome.
-
run
Performs a forward pass in the operation without inducing any kind of learning or storing the outcome. This only replicates the operation at the tensor level; it neither affects nor is affected by any dependent inputs registered through addInput(mklab.JGNN.nn.NNOperation).
- Parameters:
- inputs - A list of input tensors needed by the operation.
- Returns:
- A Tensor with the operation's outcome.
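The difference between a graph-driven prediction and a direct tensor-level run can be sketched as below. The code is a hedged illustration that does not use JGNN: `ToyRunnableOp`, `ToyValue`, and `ToySum` are hypothetical classes on `double` values, and the two `run` overloads (a list and a varargs form) are an assumption modeled on the two `run` entries listed on this page.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical node supporting graph evaluation and direct evaluation; NOT a JGNN class.
abstract class ToyRunnableOp {
    private final List<ToyRunnableOp> inputs = new ArrayList<>();
    private Double lastPrediction; // written only by runPrediction()

    ToyRunnableOp addInput(ToyRunnableOp input) {
        inputs.add(input);
        return this;
    }

    Double getPrediction() {
        return lastPrediction;
    }

    protected abstract double forward(List<Double> values);

    // Graph evaluation: pulls values from the registered inputs and stores the outcome.
    final double runPrediction() {
        List<Double> values = new ArrayList<>();
        for (ToyRunnableOp input : inputs) {
            values.add(input.runPrediction());
        }
        lastPrediction = forward(values);
        return lastPrediction;
    }

    // Direct evaluation: uses only the supplied values and stores nothing.
    final double run(List<Double> values) {
        return forward(values);
    }

    // Convenience overload that forwards to the list version.
    final double run(Double... values) {
        return run(Arrays.asList(values));
    }
}

class ToyValue extends ToyRunnableOp {
    private final double value;

    ToyValue(double value) {
        this.value = value;
    }

    @Override
    protected double forward(List<Double> values) {
        return value;
    }
}

class ToySum extends ToyRunnableOp {
    @Override
    protected double forward(List<Double> values) {
        double total = 0;
        for (double v : values) {
            total += v;
        }
        return total;
    }
}
```

For a `ToySum` wired to the values 1 and 2, `run(10.0, 20.0)` returns 30.0 without consulting the registered inputs, and `getPrediction()` keeps reporting whatever the last `runPrediction()` stored.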
-
trainParameters
-
forward
-
partial
-
getSimpleDescription
Provides a simple description to show when drawing .dot format diagrams.
- Returns:
- A string description, usually the component's class name.
-
runPredictionAndAutosize
-
autosize
-