
A3C Cartpole

The figures in the following article explain what A3C is doing: Simple Reinforcement Learning with Tensorflow Part 8: Asynchronous Actor-Critic Agents (A3C)

Architecture

(architecture diagram)

Sequence

(sequence diagram)

Calculation

In this example, the value net and the policy net are two independent neural networks, and they are trained and updated independently.
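For orientation, here is a minimal sketch of what "two independent networks" means in practice: the critic and the actor are separate MultiLayerNetwork instances that are queried separately. This is not the RL4J source; the wrapper class and its method names are made up for illustration.

    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Hypothetical wrapper, only to illustrate the separate value/policy structure
    public class SeparateActorCritic {
        private final MultiLayerNetwork valueNet;   // critic: estimates the scalar V(s)
        private final MultiLayerNetwork policyNet;  // actor: outputs the distribution pi(a|s)

        public SeparateActorCritic(MultiLayerNetwork valueNet, MultiLayerNetwork policyNet) {
            this.valueNet = valueNet;
            this.policyNet = policyNet;
        }

        public INDArray value(INDArray observation) {
            return valueNet.output(observation);    // shape [batch, 1]
        }

        public INDArray policy(INDArray observation) {
            return policyNet.output(observation);   // shape [batch, numActions], softmax probabilities
        }
    }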

Global Thread

AsyncGlobal polls the gradient queue and applies the incoming gradients to the global network as they arrive. In the double-network case it also copies the current network into the target network every targetDqnUpdateFreq steps.

    public void run() {

        while (!isTrainingComplete() && running) {
            if (!queue.isEmpty()) {
                // take one contribution from a worker: the gradients plus the number of steps they cover
                Pair<Gradient[], Integer> pair = queue.poll();
                T.addAndGet(pair.getSecond());          // T is the global step counter
                Gradient[] gradient = pair.getFirst();
                synchronized (this) {
                    // apply the worker's gradients to the global (current) network
                    current.applyGradient(gradient, pair.getSecond());
                }
                // if a target network is configured, refresh it whenever T crosses a multiple of targetDqnUpdateFreq
                if (a3cc.getTargetDqnUpdateFreq() != -1
                                && T.get() / a3cc.getTargetDqnUpdateFreq() > (T.get() - pair.getSecond())
                                                / a3cc.getTargetDqnUpdateFreq()) {
                    log.info("TARGET UPDATE at T = " + T.get());
                    synchronized (this) {
                        target.copy(current);
                    }
                }
            }
        }

    }

Worker threads send the gradients for the value net and the policy net as a two-element array, and ActorCriticSeparate.applyGradient() applies each of them to its corresponding network.

    // in ActorCriticSeparate
    public void applyGradient(Gradient[] gradient, int batchSize) {
        // gradient[0] belongs to the value net, gradient[1] to the policy net
        MultiLayerConfiguration valueConf = valueNet.getLayerWiseConfigurations();
        int valueIterationCount = valueConf.getIterationCount();
        // let the updater (Adam by default, see below) transform the raw gradient, then apply it
        valueNet.getUpdater().update(valueNet, gradient[0], valueIterationCount, batchSize);
        valueNet.params().subi(gradient[0].gradient());
        Collection<IterationListener> valueIterationListeners = valueNet.getListeners();
        if (valueIterationListeners != null && valueIterationListeners.size() > 0) {
            for (IterationListener listener : valueIterationListeners) {
                listener.iterationDone(valueNet, valueIterationCount);
            }
        }
        valueConf.setIterationCount(valueIterationCount + 1);

        // same procedure for the policy net
        MultiLayerConfiguration policyConf = policyNet.getLayerWiseConfigurations();
        int policyIterationCount = policyConf.getIterationCount();
        policyNet.getUpdater().update(policyNet, gradient[1], policyIterationCount, batchSize);
        policyNet.params().subi(gradient[1].gradient());
        Collection<IterationListener> policyIterationListeners = policyNet.getListeners();
        if (policyIterationListeners != null && policyIterationListeners.size() > 0) {
            for (IterationListener listener : policyIterationListeners) {
                listener.iterationDone(policyNet, policyIterationCount);
            }
        }
        policyConf.setIterationCount(policyIterationCount + 1);
    }

The updaters for both networks are initialized in ActorCriticFactorySeparateStdDense.buildActorCritic() with the default updater, Adam.
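For reference, here is a rough sketch of how such a pair of networks could be configured in DL4J with Adam as the updater. This is not the actual factory code; the layer sizes, activations, learning rate, and the ActorCriticLoss import path are illustrative assumptions.

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.rl4j.network.ac.ActorCriticLoss;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.learning.config.Adam;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class ActorCriticConfigSketch {

        // Critic: regresses the scalar V(s) with an MSE loss
        public static MultiLayerNetwork buildValueNet(int numInputs) {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(new Adam(1e-3))    // Adam, as mentioned above; the learning rate is an assumption
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(16)
                            .activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .activation(Activation.IDENTITY).nIn(16).nOut(1).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            return net;
        }

        // Actor: softmax over actions, trained with the custom ActorCriticLoss described below
        public static MultiLayerNetwork buildPolicyNet(int numInputs, int numActions) {
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(new Adam(1e-3))
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(16)
                            .activation(Activation.RELU).build())
                    .layer(1, new OutputLayer.Builder(new ActorCriticLoss())
                            .activation(Activation.SOFTMAX).nIn(16).nOut(numActions).build())
                    .build();
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            return net;
        }
    }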

Worker Thread

The worker runs in much the same way as the DQN worker, with three minor differences.

These differences can be verified in *.trainSubEpoch() and A3CThreadDiscrete.calcGradient().

The last one is a special loss function for the policy network, implemented as ActorCriticLoss.
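The A (advantage) values that this loss consumes are the labels built in trainSubEpoch from the collected rewards and the value-net estimates. The following self-contained sketch shows that computation on made-up numbers (plain Java, not the RL4J code; gamma, rewards, values and the bootstrap are assumptions):

    public class AdvantageSketch {
        public static void main(String[] args) {
            double gamma = 0.99;                    // assumed discount factor
            double[] rewards = {1.0, 1.0, 1.0};     // rewards collected during one sub-epoch
            double[] values  = {10.0, 9.5, 9.0};    // V(s_t) from the value net for each visited state
            double bootstrap = 8.5;                 // V(s_T) of the final state (0 if the episode ended)

            double[] returns = new double[rewards.length];
            double[] advantages = new double[rewards.length];

            // Work backwards: R_t = r_t + gamma * R_{t+1}, seeded with the bootstrap value
            double r = bootstrap;
            for (int t = rewards.length - 1; t >= 0; t--) {
                r = rewards[t] + gamma * r;
                returns[t] = r;                     // target for the value net
                advantages[t] = r - values[t];      // A_t = R_t - V(s_t), label for the policy net
            }

            for (int t = 0; t < rewards.length; t++) {
                System.out.printf("t=%d  R=%.3f  A=%.3f%n", t, returns[t], advantages[t]);
            }
        }
    }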

ActorCriticLoss

    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        // pi = activationFn(preOutput) (softmax for the policy net), plus a small epsilon for numerical stability
        INDArray output = activationFn.getActivation(preOutput.dup(), true).addi(1e-5);
        INDArray logOutput = Transforms.log(output, true);
        // 1 + log(pi) = -d(H)/d(pi), see the derivation below
        INDArray entropyDev = logOutput.addi(1);
        // d(L)/d(pi) = A / pi - BETA * (1 + log(pi)), negated for gradient descent
        INDArray dLda = output.rdivi(labels).subi(entropyDev.muli(BETA)).negi();
        // chain through the activation function to get the gradient w.r.t. preOutput
        INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();

        if (mask != null) {
            LossUtil.applyMask(grad, mask);
        }
        return grad;
    }

As stated in the reference, the loss function for the policy network is L = log(π) * (R - V) + BETA * H(π), where H(π) is the entropy of the policy.

Let A = R - V (the advantage).

π is the output of the policy network, and θ denotes its weights (the W in the previous network descriptions).

Get π and log(π)

        INDArray output = activationFn.getActivation(preOutput.dup(), true).addi(1e-5);
        INDArray logOutput = Transforms.log(output, true);

Get (- d(H) / d(π))

H(π) = - (Σπi * log(πi))

d(H) / d(πi) = -(1 + log(πi)), so -d(H) / d(πi) = 1 + log(πi)

        INDArray entropyDev = logOutput.addi(1);

Get d(L) / d(π)

d(L) / d(πi) = A / πi - BETA * (1 + log(πi))

A is passed in as the labels, computed in trainSubEpoch(). The expression is negated (negi()) because the updater performs gradient descent, while L is to be maximized.

        INDArray dLda = output.rdivi(labels).subi(entropyDev.muli(BETA)).negi();

d(L) / d(θ) = (d(L) / d(π)) * (d(π) / d(θ))

In the following statement, backprop does this job: dLda is d(L) / d(π), and activationFn.backprop() chains it through the activation function to give the gradient with respect to preOutput, which the network then backpropagates further to obtain d(L) / d(θ).

        INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();
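As a sanity check, the same arithmetic can be reproduced on concrete numbers outside the network. The sketch below uses assumed values for π, the labels and BETA, and stops before the softmax backprop step; the label layout (A at the chosen action's index, 0 elsewhere) is an assumption about how trainSubEpoch fills them in.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class ActorCriticLossSketch {
        public static void main(String[] args) {
            double beta = 0.01;                                    // assumed entropy coefficient
            // pi: softmax output of the policy net for one state with two actions (assumed values)
            INDArray output = Nd4j.create(new double[]{0.7, 0.3}).addi(1e-5);
            // labels: advantage A at the chosen action's index, 0 elsewhere (assumed convention)
            INDArray labels = Nd4j.create(new double[]{0.0, 1.2});

            INDArray logOutput = Transforms.log(output, true);     // log(pi)
            INDArray entropyDev = logOutput.add(1);                // 1 + log(pi) = -d(H)/d(pi)
            // d(L)/d(pi) = A / pi - BETA * (1 + log(pi)), negated because the updater minimizes
            INDArray dLda = output.rdiv(labels).subi(entropyDev.muli(beta)).negi();

            System.out.println("dL/dpi = " + dLda);
            // activationFn.backprop(preOutput, dLda) would then chain this through the softmax
        }
    }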