## 6.4 Implementing A2C

We now have all of the elements needed to implement the Actor-Critic algorithms. The main components are:

- Advantage estimation, for example, *n*-step returns or GAE
- Value loss and policy loss
- The training loop

In what follows, we discuss an implementation of each of these components, ending with the training loop which brings them all together. Since Actor-Critic conceptually extends REINFORCE, it is implemented by inheriting from the `Reinforce` class.

Actor-Critic is also an on-policy algorithm since the actor component learns a policy using the policy gradient. Consequently, we train Actor-Critic algorithms using an on-policy `Memory`, such as `OnPolicyReplay` which trains in an episodic manner, or `OnPolicyBatchReplay` which trains using batches of data. The code that follows applies to either approach.

### 6.4.1 Advantage Estimation

#### 6.4.1.1 Advantage Estimation with *n*-Step Returns

The main trick when implementing the *n*-step $Q^{\pi}$ estimate is to notice that we have access to the rewards received in an episode in sequence. We can calculate the discounted sum of *n* rewards for each element of the batch in parallel by taking advantage of vector arithmetic, as shown in Code 6.1. The approach goes as follows:

1. Initialize a vector `rets` to populate with the *Q*-value estimates, that is, the *n*-step returns (line 4).
2. For efficiency, the computation is done from the last term to the first. `future_ret` is a placeholder that accumulates the summed rewards as we work backward (line 5). It is initialized to `next_v_pred` since the last term in the *n*-step *Q*-estimate is the critic's value estimate $\hat{V}^{\pi}$ of the state reached after the final reward; the discount factors are applied as the loop works backward. `not_dones` (line 6) is a binary mask that handles episode boundaries: it is zero wherever an episode ended, which stops the sum from propagating across episodes.
3. Note that the *n*-step *Q*-estimate is defined recursively, that is, $Q_{\text{current}} = r_{\text{current}} + \gamma Q_{\text{next}}$, with the $\gamma Q_{\text{next}}$ term dropped at an episode boundary. This is mirrored exactly by line 8.

#### Code 6.1 Actor-Critic implementation: calculate *n*-step returns

```python
1  # slm_lab/lib/math_util.py
2
3  def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
4      rets = torch.zeros_like(rewards)
5      future_ret = next_v_pred
6      not_dones = 1 - dones
7      for t in reversed(range(n)):
8          rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
9      return rets
```
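To convince ourselves that the backward loop computes the intended sums, here is a minimal sketch on plain Python lists rather than tensors (an assumption made only to keep the example self-contained). It takes `n` to be the full batch length and checks the recursion against a direct, forward evaluation of each discounted sum. Both function names are hypothetical and not part of SLM Lab.

```python
def nstep_returns_backward(rewards, dones, next_v_pred, gamma):
    """Backward recursion in the style of Code 6.1, with n = len(rewards)."""
    rets = [0.0] * len(rewards)
    future_ret = next_v_pred  # bootstrap from V of the last next_state
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]  # 0.0 at an episode boundary
        future_ret = rewards[t] + gamma * future_ret * not_done
        rets[t] = future_ret
    return rets


def nstep_returns_forward(rewards, dones, next_v_pred, gamma):
    """Direct evaluation: sum discounted rewards until an episode ends,
    adding the discounted bootstrap value only if no episode ended."""
    T = len(rewards)
    rets = []
    for t in range(T):
        ret, discount = 0.0, 1.0
        for k in range(t, T):
            ret += discount * rewards[k]
            discount *= gamma
            if dones[k]:
                break
        else:  # no terminal step between t and the end of the batch
            ret += discount * next_v_pred
        rets.append(ret)
    return rets


rewards = [1.0, 0.0, 2.0, 1.0]
dones = [0.0, 1.0, 0.0, 0.0]
backward = nstep_returns_backward(rewards, dones, next_v_pred=5.0, gamma=0.9)
forward = nstep_returns_forward(rewards, dones, next_v_pred=5.0, gamma=0.9)
assert all(abs(b - f) < 1e-9 for b, f in zip(backward, forward))
print(backward)
```

In this example the return at index 1 is just its own reward, 0.0, because that step ends an episode, and the return at index 0 therefore never sees `next_v_pred`.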

Now, the `ActorCritic` class needs a method to compute the advantage estimates and the target *V* values for computing the policy and value losses, respectively. This is relatively straightforward, as shown in Code 6.2.

One detail about `advs` and `v_targets` is important to highlight: they do not carry a gradient, as can be seen from the `torch.no_grad()` and `.detach()` operations in lines 9–11. In the policy loss (Equation 6.1), the advantage acts only as a scalar multiplier on the gradient of the policy log probability, so no gradient should flow through it back into the critic. As for the value loss from Algorithm 6.1 (lines 13–14), we assume the target *V*-value is fixed, and the goal is to train the critic to predict a *V*-value that closely matches it.

#### Code 6.2 Actor-Critic implementation: calculate *n*-step advantages and *V*-target values

```python
 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3  class ActorCritic(Reinforce):
 4      ...
 5
 6      def calc_nstep_advs_v_targets(self, batch, v_preds):
 7          next_states = batch['next_states'][-1]
 8          ...
 9          with torch.no_grad():
10              next_v_pred = self.calc_v(next_states, use_cache=False)
11          v_preds = v_preds.detach()  # adv does not accumulate grad
12          ...
13          nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['dones'],
                next_v_pred, self.gamma, self.num_step_returns)
14          advs = nstep_rets - v_preds
15          v_targets = nstep_rets
16          ...
17          return advs, v_targets
```
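Stripped of the tensor and gradient bookkeeping, lines 13–15 are simple arithmetic. The following sketch uses a hypothetical helper name and plain floats instead of tensors (so there is nothing to detach); it repeats the backward recursion and then forms the advantages and targets.

```python
def nstep_advs_v_targets(rewards, dones, v_preds, next_v_pred, gamma):
    """Plain-float analogue of lines 13-15 in Code 6.2 (hypothetical helper)."""
    # n-step returns via the backward recursion of Code 6.1 (n = batch length)
    rets = [0.0] * len(rewards)
    future_ret = next_v_pred
    for t in reversed(range(len(rewards))):
        future_ret = rewards[t] + gamma * future_ret * (1.0 - dones[t])
        rets[t] = future_ret
    # the returns estimate Q, so subtracting the predicted V gives the advantage
    advs = [q - v for q, v in zip(rets, v_preds)]
    v_targets = list(rets)  # the critic is trained towards the n-step returns
    return advs, v_targets


advs, v_targets = nstep_advs_v_targets(
    rewards=[1.0, 1.0, 1.0], dones=[0.0, 0.0, 0.0],
    v_preds=[2.0, 2.0, 2.0], next_v_pred=2.0, gamma=0.5)
print(advs, v_targets)
```

When the critic is already consistent with the rewards, as here with a constant reward of 1, $\gamma = 0.5$, and every *V*-prediction equal to 2, each *n*-step return is also 2 and every advantage is zero, so the actor receives no push in either direction.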

#### 6.4.1.2 Advantage Estimation with GAE

The implementation of GAE shown in Code 6.3 has a very similar form to that of *n*-step. It uses the same backward computation, except that we need an extra step to compute the *δ* term at each time step (line 11).
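Written out per time step, using $d_t$ for `dones[t]` (our notation) and $\hat{V}^{\pi}$ for the critic's predictions, the quantities computed in Code 6.3 are:

$$
\delta_t = r_t + \gamma\,(1 - d_t)\,\hat{V}^{\pi}(s_{t+1}) - \hat{V}^{\pi}(s_t)
$$

$$
A_t = \delta_t + \gamma \lambda\,(1 - d_t)\,A_{t+1}, \qquad A_T = 0
$$

Here $\hat{V}^{\pi}(s_T)$ is the prediction for the final `next_state`, the extra entry at the end of `v_preds`. Setting $\lambda = 0$ reduces $A_t$ to the one-step estimate $\delta_t$, while $\lambda = 1$ telescopes into the discounted return minus $\hat{V}^{\pi}(s_t)$.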

#### Code 6.3 Actor-Critic implementation: calculate GAE

```python
 1  # slm_lab/lib/math_util.py
 2
 3  def calc_gaes(rewards, dones, v_preds, gamma, lam):
 4      T = len(rewards)
 5      assert T + 1 == len(v_preds)  # v_preds includes states and 1 last next_state
 6      gaes = torch.zeros_like(rewards)
 7      future_gae = torch.tensor(0.0, dtype=rewards.dtype)
 8      # to multiply with not_dones to handle episode boundary (last state has no V(s'))
 9      not_dones = 1 - dones
10      for t in reversed(range(T)):
11          delta = rewards[t] + gamma * v_preds[t + 1] * not_dones[t] - v_preds[t]
12          gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
13      return gaes
```
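GAE has two limiting cases that make a handy sanity check: with λ = 0 each estimate reduces to the single δ term, and with λ = 1 it equals the backward-computed return of Code 6.1 minus the predicted *V*-value. The sketch below ports both functions to plain Python lists (the names `gaes` and `nstep_returns` are hypothetical; the SLM Lab versions operate on tensors) and checks both cases on a small batch that contains an episode boundary.

```python
def gaes(rewards, dones, v_preds, gamma, lam):
    """List-based port of calc_gaes (Code 6.3)."""
    T = len(rewards)
    assert T + 1 == len(v_preds)  # last entry is V of the final next_state
    out = [0.0] * T
    future_gae = 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * v_preds[t + 1] * not_done - v_preds[t]
        future_gae = delta + gamma * lam * not_done * future_gae
        out[t] = future_gae
    return out


def nstep_returns(rewards, dones, next_v_pred, gamma):
    """List-based port of calc_nstep_returns (Code 6.1) with n = len(rewards)."""
    rets = [0.0] * len(rewards)
    future_ret = next_v_pred
    for t in reversed(range(len(rewards))):
        future_ret = rewards[t] + gamma * future_ret * (1.0 - dones[t])
        rets[t] = future_ret
    return rets


rewards = [1.0, 0.0, 2.0, 1.0]
dones = [0.0, 1.0, 0.0, 0.0]
v_preds = [0.5, 1.0, 1.5, 2.0, 3.0]  # 4 states plus the final next_state
gamma = 0.9

# lambda = 0: each estimate is just the one-step TD error delta_t
deltas = [rewards[t] + gamma * v_preds[t + 1] * (1.0 - dones[t]) - v_preds[t]
          for t in range(len(rewards))]
assert all(abs(a - d) < 1e-9
           for a, d in zip(gaes(rewards, dones, v_preds, gamma, 0.0), deltas))

# lambda = 1: each estimate equals the backward return minus V(s_t)
mc_advs = [r - v for r, v in
           zip(nstep_returns(rewards, dones, v_preds[-1], gamma), v_preds[:-1])]
assert all(abs(a - m) < 1e-9
           for a, m in zip(gaes(rewards, dones, v_preds, gamma, 1.0), mc_advs))

print(gaes(rewards, dones, v_preds, gamma, 0.95))
```

Values of λ between the two extremes blend these estimates; the episode-boundary mask is applied identically in both the δ term and the recursion.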

Likewise, in Code 6.4, the Actor-Critic class method to compute the advantages and target *V*-values closely follows that of *n*-step, with two important differences. First, `calc_gaes` (line 14) returns the full advantage estimates, whereas `calc_nstep_returns` in the *n*-step case returns *Q*-value estimates. To recover the target *V*-values, we therefore need to add the predicted *V*-values back (line 15), since $A = Q - V$ implies $Q = A + V$. Second, it is good practice to standardize the GAE advantage estimates (line 16); the *V*-targets are computed before this step, so they are not affected by it.

#### Code 6.4 Actor-Critic implementation: calculate GAE advantages and *V*-target values

```python
 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3  class ActorCritic(Reinforce):
 4      ...
 5
 6      def calc_gae_advs_v_targets(self, batch, v_preds):
 7          next_states = batch['next_states'][-1]
 8          ...
 9          with torch.no_grad():
10              next_v_pred = self.calc_v(next_states, use_cache=False)
11          v_preds = v_preds.detach()  # adv does not accumulate grad
12          ...
13          v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
14          advs = math_util.calc_gaes(batch['rewards'], batch['dones'], v_preds_all,
                self.gamma, self.lam)
15          v_targets = advs + v_preds
16          advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
17          ...
18          return advs, v_targets
```
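The ordering in lines 15–16 matters: the *V*-targets are built from the raw GAE values, and only afterwards are the advantages standardized. Here is a minimal sketch on plain lists, with a hypothetical `standardize` that uses the population standard deviation and a small epsilon; SLM Lab's `math_util.standardize` may differ in these details.

```python
import statistics


def standardize(values, eps=1e-8):
    """Shift to zero mean and scale to unit (population) standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    return [(v - mean) / (std + eps) for v in values]


def gae_advs_v_targets(raw_advs, v_preds):
    """Hypothetical analogue of lines 15-16 in Code 6.4."""
    # targets use the raw GAE values: A + V recovers a Q-style estimate
    v_targets = [a + v for a, v in zip(raw_advs, v_preds)]
    # only the actor's signal is standardized; the targets keep their scale
    advs = standardize(raw_advs)
    return advs, v_targets


advs, v_targets = gae_advs_v_targets([1.0, 2.0, 3.0, 6.0], [0.5, 0.5, 0.5, 0.5])
print(advs, v_targets)
```

If the two steps were swapped, the critic would be regressed towards standardized advantages plus *V*, which no longer corresponds to a return estimate.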

### 6.4.2 Calculating Value Loss and Policy Loss

In Code 6.5, the policy loss has the same form as in the REINFORCE implementation. The only difference is that it uses the advantages instead of returns as a reinforcing signal, so we can inherit and reuse the method from REINFORCE (line 7).

The value loss is simply a measure of the error between the predicted values (`v_preds`) and the target values (`v_targets`). We are free to choose any appropriate measure, such as MSE, by setting the `net.loss_spec` param in the spec file. This initializes the loss function `self.net.loss_fn` used in line 11, whose output is then scaled by `self.val_loss_coef`.

#### Code 6.5 Actor-Critic implementation: two loss functions

```python
 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3  class ActorCritic(Reinforce):
 4      ...
 5
 6      def calc_policy_loss(self, batch, pdparams, advs):
 7          return super().calc_policy_loss(batch, pdparams, advs)
 8
 9      def calc_val_loss(self, v_preds, v_targets):
10          assert v_preds.shape == v_targets.shape, f'{v_preds.shape} != {v_targets.shape}'
11          val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
12          return val_loss
```
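For reference, the two losses reduce to the following arithmetic when written on plain lists. This is a sketch under assumptions: the policy loss takes precomputed log-probabilities of the sampled actions, MSE stands in for whatever `loss_fn` the spec selects, and any extra terms or coefficients the library applies beyond `val_loss_coef` are omitted. Both function names are hypothetical.

```python
def policy_loss(log_probs, advs):
    """-mean(log pi(a|s) * A): minimizing it raises the log-probability of
    actions with positive advantage and lowers it for negative ones."""
    assert len(log_probs) == len(advs)
    return -sum(lp * a for lp, a in zip(log_probs, advs)) / len(advs)


def value_loss(v_preds, v_targets, val_loss_coef=1.0):
    """MSE between predictions and fixed targets, scaled by a coefficient."""
    assert len(v_preds) == len(v_targets), f'{len(v_preds)} != {len(v_targets)}'
    mse = sum((p - t) ** 2 for p, t in zip(v_preds, v_targets)) / len(v_preds)
    return val_loss_coef * mse


print(policy_loss([-1.0, -2.0], [1.0, 0.5]), value_loss([1.0, 2.0], [1.0, 4.0], 0.5))
# 1.0 1.0
```

In the real implementation the advantages and targets are detached tensors, so gradients of these losses flow only into the log-probabilities (actor) and the predictions (critic), respectively.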

### 6.4.3 Actor-Critic Training Loop

The actor and critic can be implemented with either separate networks or a single shared network. This is reflected in the `train` method in Code 6.6. Lines 10–15 sample a batch and calculate the policy and value losses for training. If the implementation uses a shared network (line 16), the two losses are combined and used to train the network (lines 17–18). If the actor and critic are separate networks, the two losses are used separately to train the relevant networks (lines 20–21); their sum is still computed (line 22) so that `train` can return a single loss value (line 25). Section 6.5 goes into more detail on network architecture.

#### Code 6.6 Actor-Critic implementation: training method

```python
 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3  class ActorCritic(Reinforce):
 4      ...
 5
 6      def train(self):
 7          ...
 8          clock = self.body.env.clock
 9          if self.to_train == 1:
10              batch = self.sample()
11              clock.set_batch_size(len(batch))
12              pdparams, v_preds = self.calc_pdparam_v(batch)
13              advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
14              policy_loss = self.calc_policy_loss(batch, pdparams, advs)  # from actor
15              val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
16              if self.shared:  # shared network
17                  loss = policy_loss + val_loss
18                  self.net.train_step(loss, self.optim, self.lr_scheduler,
                        clock=clock, global_net=self.global_net)
19              else:
20                  self.net.train_step(policy_loss, self.optim, self.lr_scheduler,
                        clock=clock, global_net=self.global_net)
21                  self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler,
                        clock=clock, global_net=self.global_critic_net)
22                  loss = policy_loss + val_loss
23              # reset
24              self.to_train = 0
25              return loss.item()
26          else:
27              return np.nan
```
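To see the whole loop in miniature, the sketch below trains separate actor and critic parameters, in the spirit of lines 19–21, on a hypothetical two-armed bandit in plain Python. Every episode lasts one step, so the *n*-step return is just the reward and the advantage is `r - v`. Gradients are written out by hand instead of using autograd, and all names, rewards, and learning rates are illustrative assumptions rather than SLM Lab settings.

```python
import math
import random


def softmax(prefs):
    """Numerically stable softmax over action preferences."""
    m = max(prefs)
    exps = [math.exp(p - m) for p in prefs]
    total = sum(exps)
    return [e / total for e in exps]


def train_toy_actor_critic(iterations=200, batch_size=16, seed=0):
    rng = random.Random(seed)
    arm_rewards = [0.0, 1.0]  # hypothetical bandit: arm 1 is better
    prefs = [0.0, 0.0]        # actor parameters (softmax preferences)
    v = 0.0                   # critic parameter: V of the single state
    actor_lr, critic_lr = 0.5, 0.1
    for _ in range(iterations):
        # 1. sample an on-policy batch with the current policy
        probs = softmax(prefs)
        actions = [0 if rng.random() < probs[0] else 1 for _ in range(batch_size)]
        rewards = [arm_rewards[a] for a in actions]
        # 2. advantages and V-targets; one-step episodes, so the return is r
        v_targets = rewards
        advs = [r - v for r in rewards]  # v acts as a fixed baseline here
        # 3. actor step: grad is the gradient of mean(adv * log pi(a)),
        #    using d log pi(a) / d prefs[i] = 1[i == a] - probs[i];
        #    ascending it is the same as descending the policy loss
        for i in range(len(prefs)):
            grad = sum(adv * ((1.0 if a == i else 0.0) - probs[i])
                       for a, adv in zip(actions, advs)) / batch_size
            prefs[i] += actor_lr * grad
        # 4. critic step: descend mean((v - target)^2), gradient 2 * mean(v - target)
        v_grad = 2.0 * sum(v - t for t in v_targets) / batch_size
        v -= critic_lr * v_grad
    return softmax(prefs), v


probs, v = train_toy_actor_critic()
print(f"P(arm 1) = {probs[1]:.3f}, V = {v:.3f}")
```

Because each loss updates its own parameters with its own learning rate, this mirrors the separate-network branch. With a shared network, one would instead sum the two losses and take a single step on the shared parameters, as in lines 17–18.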