GravesLSTM
Assume there are no peephole connections.
Architecture
The architecture of the RNN cell is shown in the following figure:
Calculation
As the annotation in LSTMHelpers.java states, this class is an implementation of the following document: vector impl
The calculation is straightforward except for one point. There is another style of calculation, demonstrated in this document: recursion impl
As stated in that document, the gradient of W/R is a recursion of a recursion over t, while in the previous document the gradient is just a single recursion over t.
The point is that the inner recursion in the vector implementation is hidden inside δ. Take δWz as an example:
δWz = Σ(t)(<δzHat(t), x(t)>), where <a, b> denotes the outer product and & (used below) denotes the element-wise product
δzHat(t) = δc(t) & i(t) & g'(zHat(t))
δc(t) depends on both step t and step (t + 1), and therefore so does δzHat(t).
Split δc(t) into a term at step t and a term that depends on step (t + 1):
u(t) = δc(t)
v(t) = δy(t) & o(t) & h'(c(t))
u(t) = v(t) + u(t + 1) & f(t + 1)
= v(t) + (v(t + 1) + u(t + 2) & f(t + 2)) & f(t + 1)
= v(t) + v(t + 1) & f(t + 1) + u(t + 2) & f(t + 2) & f(t + 1)
= v(t) + v(t + 1) & f(t + 1) + (v(t + 2) + u(t + 3) & f(t + 3)) & f(t + 2) & f(t + 1)
= v(t) + v(t + 1) & f(t + 1) + v(t + 2) & f(t + 2) & f(t + 1) + u(t + 3) & f(t + 3) & f(t + 2) & f(t + 1)
= v(t) + Σ(r: r = t + 1 ~ T)(v(r) & π(s: s = t + 1, r)f(s))
u(t) is therefore a sum over products of forget gates; this is the hidden inner recursion.
It has the same form as δE(t)/δW(i, j) in [Gradient for an RNN].
Forward Pass
All symbols follow the notation of the vector impl document.
LSTMHelpers::activateHelper
{
    // Forward pass of a GravesLSTM layer without peepholes (excerpt). For every time step t:
    //   [zHat, fHat, oHat, iHat] = x(t) * [Wz, Wf, Wo, Wi] + y(t - 1) * [Rz, Rf, Ro, Ri] + b
    //   z = g(zHat), f = σ(fHat), o = σ(oHat), i = σ(iHat)
    //   c(t) = f & c(t - 1) + i & z
    //   y(t) = h(c(t)) & o
    // g and h are both the layer activation function afn; σ is gateActivationFn.
    // Each gate occupies a column block of width hiddenLayerSize, in the order used by the slices below.

    // W in form [Wz, Wf, Wo, Wi]
    INDArray inputWeights = originalInputWeights;
    // y(t - 1)
    INDArray prevOutputActivations = originalPrevOutputActivations;
    // c(t - 1): a freshly created array when no previous state is supplied, otherwise an 'f'-order copy
    INDArray prevMemCellState;
    if (originalPrevMemCellState == null) {
        prevMemCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');
    } else {
        prevMemCellState = originalPrevMemCellState.dup('f');
    }
    // R in form [Rz, Rf, Ro, Ri]: the first 4 * hiddenLayerSize columns of recurrentWeights
    // (presumably any further columns hold the peephole weights, which are ignored here — TODO confirm)
    INDArray recurrentWeightsIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize)).dup('f');
    boolean sigmoidGates = gateActivationFn instanceof ActivationSigmoid;
    // g() and h(): the layer activation function
    IActivation afn = layer.layerConf().getActivationFn();
    for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; iTimeIndex++) {
        int time = iTimeIndex;
        // x(t)
        INDArray miniBatchData = (is2dInput ? input : input.tensorAlongDimension(time, 1, 0)); //[Expected shape: [m,nIn]. Also deals with edge case of T=1, with 'time series' data of shape [m,nIn], equiv. to [m,nIn,1]
        miniBatchData = Shape.toMmulCompatible(miniBatchData);
        // ifogActivations = x * [Wz, Wf, Wo, Wi] = [x * Wz, x * Wf, x * Wo, x * Wi]
        INDArray ifogActivations = miniBatchData.mmul(inputWeights); //Shape: [miniBatch,4*layerSize]
        // ifogActivations += y(t - 1) * [Rz, Rf, Ro, Ri]
        Nd4j.gemm(prevOutputActivations, recurrentWeightsIFOG, ifogActivations, false, false, 1.0, 1.0);
        // ifogActivations += b
        ifogActivations.addiRowVector(biases);

        // The gate arrays below are views into ifogActivations. The activation functions are applied
        // in place; their return values are not used.
        // zHat = inputActivations = ifogActivations[0]
        INDArray inputActivations =
                ifogActivations.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
        // z = g(zHat)
        afn.getActivation(inputActivations, training);
        // fHat = forgetGateActivations = ifogActivations[1]
        INDArray forgetGateActivations = ifogActivations.get(NDArrayIndex.all(),
                NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
        // f = σ(fHat)
        gateActivationFn.getActivation(forgetGateActivations, training);
        // iHat = inputModGateActivations = ifogActivations[3]
        INDArray inputModGateActivations = ifogActivations.get(NDArrayIndex.all(),
                NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));
        // i = σ(iHat)
        gateActivationFn.getActivation(inputModGateActivations, training);

        // c(t) = f & c(t - 1); muli overwrites the f slice, which is not read again in this path
        INDArray currentMemoryCellState = forgetGateActivations.muli(prevMemCellState);
        // inputModMulInput = i & z (overwrites the i slice in place)
        INDArray inputModMulInput = inputModGateActivations.muli(inputActivations);
        // c(t) = c(t) + i & z = f & c(t - 1) + i & z
        l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState); //currentMemoryCellState.addi(inputModMulInput)

        // oHat = outputGateActivations = ifogActivations[2]
        INDArray outputGateActivations = ifogActivations.get(NDArrayIndex.all(),
                NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
        // o = σ(oHat)
        gateActivationFn.getActivation(outputGateActivations, training);
        // h(c(t)), computed on a copy so that c(t) itself is preserved
        INDArray currMemoryCellActivation = afn.getActivation(currentMemoryCellState.dup('f'), training);
        // y(t) = h(c(t)) & o
        INDArray currHiddenUnitActivations = currMemoryCellActivation.muli(outputGateActivations); //Expected shape: [m,hiddenLayerSize]

        // Carry the recurrent state into step t + 1: y(t) becomes y(t - 1) and c(t) becomes c(t - 1)
        prevOutputActivations = currHiddenUnitActivations;
        prevMemCellState = currentMemoryCellState;
    }
}
Backward Pass
Vocabulary (code variable = symbol):
iz = zHat
ia = z
fa = f
ga = i
oa = o
currentMemoryCellState = c(t)
currHiddenUnitActivations = y
fwdPassOutputAsArrays = y
memCellState = c(t)
memCellActivations = h(c(t))
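epsilonNext = δx: the gradient with respect to the layer input x(t), passed on to the previous layer (one slice per time step)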
nablaCellStateNext = δc(t + 1)
deltaifogNext = [deltaiNext, deltafNext, deltaoNext, deltagNext]
deltaiNext = δzHat(t)
deltafNext = δfHat(t)
deltaoNext = δoHat(t)
deltagNext = δiHat(t)
iwGradientsOut = [δWz, δWf, δWo, δWi]
rwGradientsOut = [δRz, δRf, δRo, δRi]
bGradientsOut = [δbz, δbf, δbo, δbi]
wIFOG = [Rz, Rf, Ro, Ri]
nablaCellState = δc(t)
prevMemCellState = c(t - 1)
prevHiddenUnitActivation = y(t - 1)
currMemCellState = c(t)
epsilonSlice = ∆(t)
nablaOut = δy(t)
sigmahOfS = h(c(t))
ao = oa = o
deltao = δoHat(t)
af = f(t)
ag = i(t)
ai = z(t)
zi = zHat(t)
LSTMHelpers::activateHelper()
{
    // Excerpt: only the statements that run when forBackprop is true. The other forward-pass
    // computations (gate slicing and activations, h(c(t)), ...) sit between these statements.
    // toReturn collects, per time step, the arrays that backpropGradientHelper later reads as fwdPass.

    // Initialise the per-time-step storage in toReturn
    if (forBackprop) {
        toReturn.fwdPassOutputAsArrays = new INDArray[timeSeriesLength];
        toReturn.memCellState = new INDArray[timeSeriesLength];
        toReturn.memCellActivations = new INDArray[timeSeriesLength];
        toReturn.iz = new INDArray[timeSeriesLength];
        toReturn.ia = new INDArray[timeSeriesLength];
        toReturn.fa = new INDArray[timeSeriesLength];
        toReturn.oa = new INDArray[timeSeriesLength];
        toReturn.ga = new INDArray[timeSeriesLength];
    }
    for (int iTimeIndex = 0; iTimeIndex < timeSeriesLength; iTimeIndex++) {
        int time = iTimeIndex;
        if (forBackprop) {
            // iz[time] = zHat(t): a copy taken before g() is applied to inputActivations in place
            toReturn.iz[time] = inputActivations.dup('f');
        }
        // ia[time] = z(t), i.e. inputActivations after g() has been applied
        if (forBackprop) {
            toReturn.ia[time] = inputActivations;
        }
        // fa[time] = f(t)
        if (forBackprop) {
            toReturn.fa[time] = forgetGateActivations;
        }
        // ga[time] = i(t)
        if (forBackprop) {
            toReturn.ga[time] = inputModGateActivations;
        }
        if (forBackprop) {
            // currentMemoryCellState = c(t - 1) & f(t), computed on a copy so that
            // forgetGateActivations (stored above as fa[time]) is not overwritten by muli
            currentMemoryCellState = prevMemCellState.dup('f').muli(forgetGateActivations);
            // inputModMulInput = z(t) & i(t), computed on a copy so that ga[time] stays intact
            inputModMulInput = inputModGateActivations.dup('f').muli(inputActivations);
        }
        // currentMemoryCellState = c(t) = f(t) & c(t - 1) + i(t) & z(t)
        l1BLAS.axpy(currentMemoryCellState.length(), 1.0, inputModMulInput, currentMemoryCellState); //currentMemoryCellState.addi(inputModMulInput)
        // oa[time] = o(t)
        if (forBackprop) {
            toReturn.oa[time] = outputGateActivations;
        }
        if (forBackprop) {
            // currHiddenUnitActivations = y(t) = h(c(t)) & o(t), computed on a copy so that
            // currMemoryCellActivation (stored below as memCellActivations[time]) keeps h(c(t))
            currHiddenUnitActivations = currMemoryCellActivation.dup('f').muli(outputGateActivations); //Expected shape: [m,hiddenLayerSize]
        }
        if (forBackprop) {
            // fwdPassOutputAsArrays[time] = y(t)
            toReturn.fwdPassOutputAsArrays[time] = currHiddenUnitActivations;
            // memCellState[time] = c(t)
            toReturn.memCellState[time] = currentMemoryCellState;
            // memCellActivations[time] = h(c(t))
            toReturn.memCellActivations[time] = currMemoryCellActivation;
        }
    }
}
The toReturn object built in activateHelper is passed into backpropGradientHelper as fwdPass.
LSTMHelpers::backpropGradientHelper()
{
    // Backpropagation through time for a GravesLSTM layer without peepholes (excerpt).
    // Walks from the last time step down to endIdx. fwdPass is the toReturn object filled by
    // activateHelper. The deltas computed for step t are written into deltaifogNext; while processing
    // step t - 1 they are read back as the "(t + 1)" deltas.
    // NOTE(review): only the sigmoid-gate path (sigmoidGates == true) is shown. With any other gate
    // activation function, deltao is not multiplied by a gate derivative and deltaf/deltag are not
    // recomputed in this excerpt.

    // wIFOG = [Rz, Rf, Ro, Ri]: the first 4 * hiddenLayerSize columns of recurrentWeights
    INDArray wIFOG = recurrentWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

    // epsilonNext = δx for every time step: the gradient passed on to the previous layer
    INDArray epsilonNext = Nd4j.create(new int[] {miniBatchSize, prevLayerSize, timeSeriesLength}, 'f'); //i.e., what would be W^L*(delta^L)^T. Shape: [m,n^(L-1),T]
    // nablaCellStateNext = δc(t + 1); unset until the last time step has been processed
    INDArray nablaCellStateNext = null;
    // deltaifogNext = [δzHat, δfHat, δoHat, δiHat] of the step being processed, with one view per gate
    INDArray deltaifogNext = Nd4j.create(new int[] {miniBatchSize, 4 * hiddenLayerSize}, 'f');
    INDArray deltaiNext = deltaifogNext.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
    INDArray deltafNext = deltaifogNext.get(NDArrayIndex.all(),
            NDArrayIndex.interval(hiddenLayerSize, 2 * hiddenLayerSize));
    INDArray deltaoNext = deltaifogNext.get(NDArrayIndex.all(),
            NDArrayIndex.interval(2 * hiddenLayerSize, 3 * hiddenLayerSize));
    INDArray deltagNext = deltaifogNext.get(NDArrayIndex.all(),
            NDArrayIndex.interval(3 * hiddenLayerSize, 4 * hiddenLayerSize));

    // Gradient accumulators: δW, δR and δb, summed over the processed time steps
    INDArray iwGradientsOut = gradientViews.get(inputWeightKey);
    INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G,FF,OO,GG}
    INDArray bGradientsOut = gradientViews.get(biasWeightKey);
    // [δRz, δRf, δRo, δRi]: the same column range of rwGradientsOut as wIFOG takes from recurrentWeights
    INDArray rwGradientsIFOG = rwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4 * hiddenLayerSize));

    for (int iTimeIndex = timeSeriesLength - 1; iTimeIndex >= endIdx; iTimeIndex--) {
        int time = iTimeIndex;
        int inext = 1;
        // nablaCellState = δc(t), accumulated below
        INDArray nablaCellState = Nd4j.create(new int[] {miniBatchSize, hiddenLayerSize}, 'f');
        // prevMemCellState = c(t - 1); at t == 0 this is the externally supplied state, which may be null
        INDArray prevMemCellState = (iTimeIndex == 0 ? fwdPass.prevMemCell : fwdPass.memCellState[time - inext]);
        // prevHiddenUnitActivation = y(t - 1); at t == 0 this is externally supplied and may be null
        INDArray prevHiddenUnitActivation =
                (iTimeIndex == 0 ? fwdPass.prevAct : fwdPass.fwdPassOutputAsArrays[time - inext]);
        // c(t)
        INDArray currMemCellState = fwdPass.memCellState[time];
        // ∆(t): the error arriving from the layer above
        INDArray epsilonSlice = (is2dInput ? epsilon : epsilon.tensorAlongDimension(time, 1, 0)); //(w^{L+1}*(delta^{(L+1)t})^T)^T or equiv.
        // nablaOut = δy(t), starting from a copy of ∆(t)
        INDArray nablaOut = Shape.toOffsetZeroCopy(epsilonSlice, 'f'); //Shape: [m,n^L]
        if (iTimeIndex != timeSeriesLength - 1) {
            // deltaifogNext still holds the deltas of step t + 1:
            // δy(t) += [δzHat(t + 1), δfHat(t + 1), δoHat(t + 1), δiHat(t + 1)] * [Rz, Rf, Ro, Ri]^T
            Nd4j.gemm(deltaifogNext, wIFOG, nablaOut, false, true, 1.0, 1.0);
        }

        // Output gate: δoHat(t) = δy(t) & h(c(t)) & σ'(oHat(t))
        // sigmahOfS = h(c(t))
        INDArray sigmahOfS = fwdPass.memCellActivations[time];
        // ao = o(t)
        INDArray ao = fwdPass.oa[time];
        INDArray deltao = deltaoNext;
        // deltao = h(c(t)) & δy(t)
        Nd4j.getExecutioner().exec(new OldMulOp(nablaOut, sigmahOfS, deltao));
        if (sigmoidGates) {
            // σ'(oHat(t)) = o(t) & (1 - o(t)), computed on a copy of o(t)
            INDArray sigmaoPrimeOfZo = Nd4j.getExecutioner().execAndReturn(new TimesOneMinus(ao.dup('f'))); //Equivalent to sigmoid deriv on zo
            deltao.muli(sigmaoPrimeOfZo);
        }

        // Memory cell: δc(t) = h'(c(t)) & o(t) & δy(t) + f(t + 1) & δc(t + 1)
        // ao.muli overwrites fwdPass.oa[time] in place; o(t) is not read again after this point.
        INDArray temp = afn.backprop(currMemCellState.dup('f'), ao.muli(nablaOut)).getFirst(); //TODO activation functions with params
        l1BLAS.axpy(nablaCellState.length(), 1.0, temp, nablaCellState);
        if (iTimeIndex != timeSeriesLength - 1) {
            // f(t + 1); muli overwrites fwdPass.fa[time + 1], which is not read again going backward
            INDArray nextForgetGateAs = fwdPass.fa[time + inext];
            int length = nablaCellState.length();
            l1BLAS.axpy(length, 1.0, nextForgetGateAs.muli(nablaCellStateNext), nablaCellState); //nablaCellState.addi(nextForgetGateAs.mul(nablaCellStateNext))
        }
        // δc(t) becomes δc(t + 1) for step t - 1
        // NOTE(review): presumably leverage() keeps the array valid outside the current workspace — confirm.
        nablaCellStateNext = workspace == null ? nablaCellState : nablaCellState.leverage();

        // Forget gate: δfHat(t) = σ'(fHat(t)) & δc(t) & c(t - 1)
        // af = f(t)
        INDArray af = fwdPass.fa[time];
        INDArray deltaf = null;
        if (iTimeIndex > 0 || prevMemCellState != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0
            //Note that prevMemCellState may be non-null at t=0 for TBPTT
            deltaf = deltafNext;
            if (sigmoidGates) {
                // deltaf = σ'(fHat(t)) = f(t) & (1 - f(t))
                Nd4j.getExecutioner().exec(new TimesOneMinus(af, deltaf));
                deltaf.muli(nablaCellState);
                deltaf.muli(prevMemCellState);
            }
        }
        // Otherwise deltafNext is not updated here and keeps whatever the previous iteration wrote;
        // the gradient updates below skip the f columns in that case.

        // Input modulation gate: δiHat(t) = σ'(iHat(t)) & z(t) & δc(t)
        // ag = i(t)
        INDArray ag = fwdPass.ga[time];
        // ai = z(t)
        INDArray ai = fwdPass.ia[time];
        INDArray deltag = deltagNext;
        if (sigmoidGates) {
            // deltag = σ'(iHat(t)) = i(t) & (1 - i(t))
            Nd4j.getExecutioner().exec(new TimesOneMinus(ag, deltag)); //Equivalent to sigmoid deriv on zg
            deltag.muli(ai);
            deltag.muli(nablaCellState);
        }

        // Block input: δzHat(t) = g'(zHat(t)) & (i(t) & δc(t))
        // zi = zHat(t)
        INDArray zi = fwdPass.iz[time];
        INDArray deltai = deltaiNext;
        // temp = i(t) & δc(t), written into a fresh array
        temp = Nd4j.getExecutioner().execAndReturn(
                new OldMulOp(ag, nablaCellState, Nd4j.createUninitialized(deltai.shape(), 'f')));
        deltai.assign(afn.backprop(zi, temp).getFirst());

        // deltaifogNext now holds [δzHat(t), δfHat(t), δoHat(t), δiHat(t)].
        // If t == 0 and there is no previous state, c(t - 1) = 0, so δfHat(t) = 0. It was not computed
        // above, so the f columns are skipped in the updates below.
        // NOTE(review): this assumes prevMemCellState and prevHiddenUnitActivation are either both null
        // or both non-null at t == 0.
        // NOTE(review): bGradientsOut (δb) is not accumulated in this excerpt — confirm against LSTMHelpers.java.

        // prevLayerActivationSlice = x(t)
        INDArray prevLayerActivationSlice =
                Shape.toMmulCompatible(is2dInput ? input : input.tensorAlongDimension(time, 1, 0));
        // δW += x(t)^T * [δzHat, δfHat, δoHat, δiHat]
        if (iTimeIndex > 0 || prevHiddenUnitActivation != null) { //For time == 0 && no prevMemCellState, equivalent to muli by 0
            //Note that prevHiddenUnitActivations may be non-null at t=0 for TBPTT
            Nd4j.gemm(prevLayerActivationSlice, deltaifogNext, iwGradientsOut, true, false, 1.0, 1.0);
        } else {
            // Update the z columns, then the o and i columns; the f columns are skipped
            INDArray iwGradients_i =
                    iwGradientsOut.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltai, iwGradients_i, true, false, 1.0, 1.0);
            INDArray iwGradients_og = iwGradientsOut.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            Nd4j.gemm(prevLayerActivationSlice, deltaog, iwGradients_og, true, false, 1.0, 1.0);
        }

        // δR += y(t - 1)^T * [δzHat, δfHat, δoHat, δiHat]
        if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
            //If t==0 and prevHiddenUnitActivation==null, equiv. to zeros(n^L,n^L), so dL/dW for recurrent weights
            // will end up as 0 anyway
            Nd4j.gemm(prevHiddenUnitActivation, deltaifogNext, rwGradientsIFOG, true, false, 1.0, 1.0);
        }

        // δx(t) = [δzHat, δfHat, δoHat, δiHat] * [Wz, Wf, Wo, Wi]^T, written into epsilonNext
        INDArray epsilonNextSlice = epsilonNext.tensorAlongDimension(time, 1, 0); //This slice: f order and contiguous, due to epsilonNext being defined as f order.
        if (iTimeIndex > 0 || prevHiddenUnitActivation != null) {
            //Note that prevHiddenUnitActivation may be non-null at t=0 for TBPTT
            Nd4j.gemm(deltaifogNext, inputWeights, epsilonNextSlice, false, true, 1.0, 1.0);
        } else {
            // No contribution from the forget gate at t == 0: use the z, o and i parts only
            INDArray wi = inputWeights.get(NDArrayIndex.all(), NDArrayIndex.interval(0, hiddenLayerSize));
            Nd4j.gemm(deltai, wi, epsilonNextSlice, false, true, 1.0, 1.0);
            INDArray deltaog = deltaifogNext.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            INDArray wog = inputWeights.get(NDArrayIndex.all(),
                    NDArrayIndex.interval(2 * hiddenLayerSize, 4 * hiddenLayerSize));
            Nd4j.gemm(deltaog, wog, epsilonNextSlice, false, true, 1.0, 1.0); //epsilonNextSlice.addi(deltao.mmul(woTranspose)).addi(deltag.mmul(wgTranspose));
        }
    }
}