This chapter is from the book

6.4 Implementing A2C

We now have all of the elements needed to implement the Actor-Critic algorithms. The main components are:

  • Advantage estimation, for example with n-step returns or GAE

  • Value loss and policy loss

  • The training loop

In what follows, we discuss an implementation of each of these components, ending with the training loop, which brings them all together. Since Actor-Critic conceptually extends REINFORCE, it is implemented by inheriting from the Reinforce class.

Actor-Critic is also an on-policy algorithm since the actor component learns a policy using the policy gradient. Consequently, we train Actor-Critic algorithms using an on-policy Memory, such as OnPolicyReplay which trains in an episodic manner, or OnPolicyBatchReplay which trains using batches of data. The code that follows applies to either approach.

6.4.1 Advantage Estimation

6.4.1.1 Advantage Estimation with n-Step Returns

The main trick when implementing the n-step Qπ estimate is to notice that we have access to the rewards received in an episode in sequence. We can calculate the discounted sum of n rewards for each element of the batch in parallel by taking advantage of vector arithmetic, as shown in Code 6.1. The approach goes as follows:

  1. Initialize a vector rets to populate with the Q-value estimates—that is, n-step returns (line 4).

  2. For efficiency, the computation runs backward from the last term to the first. future_ret is a placeholder that accumulates the summed rewards as we work backward (line 5). It is initialized to next_v_pred because the last term in the n-step Q-estimate is the discounted critic estimate $\hat{V}^{\pi}$ of the state that follows the final reward. The recursion on line 8 applies the discounting.

  3. not_dones (line 6) is a binary mask that is 0 at the step where an episode ends. It handles the episode boundary and stops the sum from propagating across episodes.

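  4. Note that the n-step Q-estimate is defined recursively, that is, $Q_{\text{current}} = r_{\text{current}} + \gamma Q_{\text{next}}$. Line 8 mirrors this exactly, with not_dones zeroing out $Q_{\text{next}}$ at an episode boundary.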
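Unrolling that recursion shows what each entry of rets ends up holding. Write the batch indices as $t = 0, \dots, n-1$, let $\hat{V}^{\pi}(s_n)$ stand for next_v_pred, and ignore episode boundaries (all not_dones equal to 1):

```latex
\text{rets}[t] = \sum_{k=t}^{n-1} \gamma^{\,k-t}\, r_k \;+\; \gamma^{\,n-t}\, \hat{V}^{\pi}(s_n),
\qquad t = 0, \dots, n-1
```

Only rets[0] sums all n rewards before bootstrapping. Later entries bootstrap after fewer rewards. A done at step k truncates every sum that reaches k and drops the bootstrap term for those entries.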

Code 6.1 Actor-Critic implementation: calculate n-step $\hat{Q}^{\pi}(s_t, a_t)$

1   # slm_lab/lib/math_util.py
2
3   def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
4       rets = torch.zeros_like(rewards)
5       future_ret = next_v_pred
6       not_dones = 1 - dones
7       for t in reversed(range(n)):
8           rets[t] = future_ret = rewards[t] + gamma * future_ret * not_dones[t]
9       return rets
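The backward loop is easy to get subtly wrong at episode boundaries, so a cheap sanity check is to compare it against a direct forward sum. The sketch below is written under a few assumptions. It uses plain Python lists instead of torch tensors. It treats n as the batch length, as the loop in Code 6.1 implies. Both function names are hypothetical and not part of SLM Lab.

```python
def nstep_returns_backward(rewards, dones, next_v_pred, gamma):
    """List version of the Code 6.1 recursion: rets[t] = r_t + gamma * rets[t+1] * (1 - done_t)."""
    n = len(rewards)  # assumption: n equals the batch length
    rets = [0.0] * n
    future_ret = next_v_pred  # bootstrap from V of the state after the last reward
    for t in reversed(range(n)):
        future_ret = rewards[t] + gamma * future_ret * (1 - dones[t])
        rets[t] = future_ret
    return rets


def nstep_returns_forward(rewards, dones, next_v_pred, gamma):
    """Direct discounted sum per t; stops at the first done, else adds the bootstrap term."""
    n = len(rewards)
    rets = []
    for t in range(n):
        total, discount, bootstrap = 0.0, 1.0, True
        for k in range(t, n):
            total += discount * rewards[k]
            discount *= gamma
            if dones[k]:
                bootstrap = False  # episode ended: there is no next-state value to add
                break
        if bootstrap:
            total += discount * next_v_pred  # discount is now gamma ** (n - t)
        rets.append(total)
    return rets


print(nstep_returns_backward([1.0, 1.0, 1.0], [0, 0, 0], 10.0, 0.5))  # [3.0, 4.0, 6.0]
print(nstep_returns_backward([1.0, 1.0, 1.0], [0, 1, 0], 10.0, 0.5))  # [1.5, 1.0, 6.0]
```

In the second call, the done at index 1 cuts the first two entries off from the bootstrap value. Index 2 is treated as the start of a new episode and still bootstraps from next_v_pred.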

Now, the ActorCritic class needs a method to compute the advantage estimates and the target V-values used in the policy and value losses, respectively. This is relatively straightforward, as shown in Code 6.2.

One detail about advs and v_targets is important to highlight: they do not carry gradients, as ensured by the torch.no_grad() and .detach() operations in lines 9–11. In the policy loss (Equation 6.1), the advantage only acts as a scalar multiplier on the gradient of the policy log probability. As for the value loss from Algorithm 6.1 (lines 13–14), we assume the target V-value is fixed, and the goal is to train the critic to predict V-values that closely match it.

Code 6.2 Actor-Critic implementation: calculate n-step advantages and V-target values

 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3   class ActorCritic(Reinforce):
 4       ...
 5
 6       def calc_nstep_advs_v_targets(self, batch, v_preds):
 7           next_states = batch['next_states'][-1]
 8           ...
 9           with torch.no_grad():
10               next_v_pred = self.calc_v(next_states, use_cache=False)
11           v_preds = v_preds.detach()  # adv does not accumulate grad
12           ...
13           nstep_rets = math_util.calc_nstep_returns(batch['rewards'],
                 batch['dones'], next_v_pred, self.gamma, self.num_step_returns)
14           advs = nstep_rets - v_preds
15           v_targets = nstep_rets
16           ...
17           return advs, v_targets
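To make the relationship between nstep_rets, advs, and v_targets concrete, here is a small list-based sketch of what calc_nstep_advs_v_targets returns. It uses plain Python with no torch, so there is no gradient to detach. The function name nstep_advs_v_targets and the flat-list inputs are illustrative assumptions. The real method works on torch tensors taken from batch.

```python
def nstep_advs_v_targets(rewards, dones, v_preds, next_v_pred, gamma):
    """Illustrative list version of Code 6.2: returns (advs, v_targets)."""
    n = len(rewards)
    nstep_rets = [0.0] * n
    future_ret = next_v_pred  # bootstrap from the critic's value of the last next_state
    for t in reversed(range(n)):
        future_ret = rewards[t] + gamma * future_ret * (1 - dones[t])
        nstep_rets[t] = future_ret
    advs = [q - v for q, v in zip(nstep_rets, v_preds)]  # Q-estimate minus baseline V
    v_targets = nstep_rets  # the critic regresses onto the n-step returns themselves
    return advs, v_targets


advs, v_targets = nstep_advs_v_targets(
    rewards=[1.0, 1.0, 1.0], dones=[0, 0, 0],
    v_preds=[2.5, 4.5, 5.0], next_v_pred=10.0, gamma=0.5)
print(advs)       # [0.5, -0.5, 1.0]
print(v_targets)  # [3.0, 4.0, 6.0]
```

A positive advantage (first and last steps) means the observed rewards plus the bootstrap did better than the critic predicted. The policy loss therefore pushes up the log probabilities of those actions. A negative advantage pushes them down.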

6.4.1.2 Advantage Estimation with GAE

The implementation of GAE shown in Code 6.3 has a very similar form to that of n-step. It uses the same backward computation, except that we need an extra step to compute the δ term at each time step (line 11).

Code 6.3 Actor-Critic implementation: calculate GAE

 1  # slm_lab/lib/math_util.py
 2
 3   def calc_gaes(rewards, dones, v_preds, gamma, lam):
 4       T = len(rewards)
 5       assert T + 1 == len(v_preds)  # v_preds includes states and 1 last next_state
 6       gaes = torch.zeros_like(rewards)
 7       future_gae = torch.tensor(0.0, dtype=rewards.dtype)
 8       # to multiply with not_dones to handle episode boundary (last state has no V(s'))
 9       not_dones = 1 - dones
10       for t in reversed(range(T)):
11           delta = (rewards[t] + gamma * v_preds[t + 1] * not_dones[t]
                      - v_preds[t])
12           gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
13       return gaes
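Two limiting cases make a handy check on Code 6.3. With λ = 0, each GAE entry collapses to the one-step TD error δ_t. With λ = 1, it equals the discounted return (bootstrapped from the final V-value) minus V(s_t). That is the same advantage the n-step code produces when the batch length equals n. The gaes function below is a hypothetical plain-Python transcription of calc_gaes written for that check, not the library code.

```python
def gaes(rewards, dones, v_preds, gamma, lam):
    """Plain-list transcription of the backward GAE recursion."""
    T = len(rewards)
    assert len(v_preds) == T + 1  # V(s_0..s_{T-1}) plus V of the last next_state
    out = [0.0] * T
    future_gae = 0.0
    for t in reversed(range(T)):
        not_done = 1 - dones[t]  # 0 at a terminal step: no V(s') and no carry-over
        delta = rewards[t] + gamma * v_preds[t + 1] * not_done - v_preds[t]
        future_gae = delta + gamma * lam * not_done * future_gae
        out[t] = future_gae
    return out


rewards, dones = [1.0, 1.0, 1.0], [0, 0, 0]
v_preds = [2.5, 4.5, 5.0, 10.0]  # last entry plays the role of next_v_pred
print(gaes(rewards, dones, v_preds, gamma=0.5, lam=0.0))  # [0.75, -1.0, 1.0]
print(gaes(rewards, dones, v_preds, gamma=0.5, lam=1.0))  # [0.5, -0.5, 1.0]
```

The λ = 1 output matches the n-step advantages for the same rewards and values. The λ = 0 output contains only single TD errors. Intermediate values of λ trade off between these two extremes.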

Likewise, in Code 6.4, the ActorCritic class method that computes the advantages and target V-values closely follows the n-step version, with two important differences. First, calc_gaes (line 14) returns the full advantage estimates, whereas calc_nstep_returns in the n-step case returns Q-value estimates. To recover the target V-values, we therefore need to add the predicted V-values back (line 15). Second, it is good practice to standardize the GAE advantage estimates (line 16). Only the advantages are standardized. v_targets is computed beforehand, so the critic still regresses onto values on their original scale.

Code 6.4 Actor-Critic implementation: calculate GAE advantages and V-target values

 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3   class ActorCritic(Reinforce):
 4       ...
 5
 6       def calc_gae_advs_v_targets(self, batch, v_preds):
 7           next_states = batch['next_states'][-1]
 8           ...
 9           with torch.no_grad():
10               next_v_pred = self.calc_v(next_states, use_cache=False)
11           v_preds = v_preds.detach()  # adv does not accumulate grad
12           ...
13           v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
14           advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                 v_preds_all, self.gamma, self.lam)
15           v_targets = advs + v_preds
16           advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
17           ...
18           return advs, v_targets
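The ordering on lines 15–16 matters. v_targets is built from the raw GAE advantages before they are standardized, so the critic's targets (the λ-returns) stay on the original value scale. Only the policy's signal is rescaled. The sketch below shows this with plain lists. Here standardize is a hypothetical stand-in that assumes math_util.standardize subtracts the mean and divides by the standard deviation plus a small epsilon. The real helper's exact std convention may differ.

```python
import statistics


def standardize(values, eps=1e-8):
    """Hypothetical stand-in for math_util.standardize: zero mean, unit std."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population std; an assumption, the real helper may differ
    return [(v - mean) / (std + eps) for v in values]


def gae_advs_v_targets(raw_advs, v_preds):
    """Mirror lines 15-16 of Code 6.4 with lists."""
    v_targets = [a + v for a, v in zip(raw_advs, v_preds)]  # built from unstandardized advs
    advs = standardize(raw_advs)  # only the policy signal is standardized
    return advs, v_targets


advs, v_targets = gae_advs_v_targets([0.5, -0.5, 1.0], [2.5, 4.5, 5.0])
print(v_targets)  # [3.0, 4.0, 6.0]
```

Swapping the two lines would make v_targets depend on the batch's advantage statistics, which is not what the critic should learn.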

6.4.2 Calculating Value Loss and Policy Loss

In Code 6.5, the policy loss has the same form as in the REINFORCE implementation. The only difference is that it uses the advantages instead of returns as a reinforcing signal, so we can inherit and reuse the method from REINFORCE (line 7).

The value loss is simply a measure of the error between $\hat{V}^{\pi}(s_t)$ (v_preds) and $V^{\pi}_{\text{tar}}(s_t)$ (v_targets). We are free to choose any appropriate measure, such as MSE, by setting the net.loss_spec param in the spec file. This initializes the loss function self.net.loss_fn used in line 11, whose result is then scaled by self.val_loss_coef.

Code 6.5 Actor-Critic implementation: two loss functions

 1  # slm_lab/agent/algorithm/actor_critic.py
 2
 3   class ActorCritic(Reinforce):
 4       ...
 5
 6       def calc_policy_loss(self, batch, pdparams, advs):
 7           return super().calc_policy_loss(batch, pdparams, advs)
 8
 9       def calc_val_loss(self, v_preds, v_targets):
10           assert v_preds.shape == v_targets.shape, f'{v_preds.shape} != {v_targets.shape}'
11           val_loss = self.val_loss_coef * self.net.loss_fn(v_preds, v_targets)
12           return val_loss
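For reference, here is what the two losses compute on a toy batch, written as a plain-Python sketch. policy_loss assumes the usual REINFORCE surrogate, the negative mean of log probability times advantage. It omits any extra terms, such as an entropy bonus, that the inherited method may add. val_loss_coef=0.5 and the Huber option are illustrative choices, not values taken from the spec file.

```python
import math


def policy_loss(log_probs, advs):
    """REINFORCE-style surrogate: advantages act as constant weights on log-probs."""
    return -sum(lp * a for lp, a in zip(log_probs, advs)) / len(advs)


def mse(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def huber(preds, targets, delta=1.0):
    """Quadratic for small errors, linear for large ones (less sensitive to outliers)."""
    total = 0.0
    for p, t in zip(preds, targets):
        err = abs(p - t)
        total += 0.5 * err ** 2 if err <= delta else delta * (err - 0.5 * delta)
    return total / len(preds)


def val_loss(v_preds, v_targets, val_loss_coef=0.5, loss_fn=mse):
    assert len(v_preds) == len(v_targets), f'{len(v_preds)} != {len(v_targets)}'
    return val_loss_coef * loss_fn(v_preds, v_targets)


v_preds, v_targets = [2.5, 4.5, 5.0], [3.0, 4.0, 6.0]
print(val_loss(v_preds, v_targets))                 # 0.25
print(val_loss(v_preds, v_targets, loss_fn=huber))  # 0.125
print(policy_loss([math.log(0.5), math.log(0.25)], [1.0, -2.0]))  # about -1.04
```

The Huber variant corresponds to what PyTorch provides as SmoothL1Loss (with beta=1) or HuberLoss (with delta=1). It is a common alternative when value targets occasionally have large errors.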

6.4.3 Actor-Critic Training Loop

The actor and critic can be implemented with either separate networks or a single shared network. This is reflected in the train method in Code 6.6. Lines 10–15 calculate the policy and value losses for training. If the implementation uses a shared network (line 16), the two losses are summed and used to train the network (lines 17–18). If the actor and critic are separate networks, each loss trains only its own network (lines 20–21). The sum on line 22 is computed so that train can still return a single loss value (line 25). Section 6.5 goes into more detail on network architecture.

Code 6.6 Actor-Critic implementation: training method

 1 # slm_lab/agent/algorithm/actor_critic.py
 2
 3 class ActorCritic(Reinforce):
 4     ...
 5
 6     def train(self):
 7         ...
 8         clock = self.body.env.clock
 9         if self.to_train == 1:
10             batch = self.sample()
11             clock.set_batch_size(len(batch))
12             pdparams, v_preds = self.calc_pdparam_v(batch)
13             advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
14             policy_loss = self.calc_policy_loss(batch, pdparams, advs)  # from actor
15             val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
16             if self.shared:  # shared network
17                 loss = policy_loss + val_loss
18                 self.net.train_step(loss, self.optim, self.lr_scheduler,
                       clock=clock, global_net=self.global_net)
19             else:
20                 self.net.train_step(policy_loss, self.optim,
                       self.lr_scheduler, clock=clock,
                       global_net=self.global_net)
21                 self.critic_net.train_step(val_loss, self.critic_optim,
                       self.critic_lr_scheduler, clock=clock,
                       global_net=self.global_critic_net)
22                 loss = policy_loss + val_loss
23             # reset
24             self.to_train = 0
25             return loss.item()
26         else:
27             return np.nan
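The only structural difference between the two architectures is which optimizer step sees which loss. The toy sketch below mirrors the self.shared branch using a hypothetical StubNet. Its train_step simply records the loss it receives and stands in for net.train_step with its optimizer, scheduler, and clock arguments. It is meant to show the routing, not to train anything.

```python
class StubNet:
    """Hypothetical stand-in for a network: remembers each loss it is trained on."""

    def __init__(self):
        self.losses = []

    def train_step(self, loss):
        self.losses.append(loss)


def route_losses(policy_loss, val_loss, shared, net, critic_net=None):
    """Mirror the shared/separate branch of Code 6.6 and return the reported loss."""
    if shared:
        loss = policy_loss + val_loss
        net.train_step(loss)  # one network, one combined update
    else:
        net.train_step(policy_loss)  # actor is updated on the policy loss only
        critic_net.train_step(val_loss)  # critic is updated on the value loss only
        loss = policy_loss + val_loss  # summed only for the return value
    return loss


shared_net = StubNet()
route_losses(-1.0, 0.25, shared=True, net=shared_net)
actor, critic = StubNet(), StubNet()
route_losses(-1.0, 0.25, shared=False, net=actor, critic_net=critic)
print(shared_net.losses, actor.losses, critic.losses)  # [-0.75] [-1.0] [0.25]
```

In the shared case, val_loss_coef from Code 6.5 sets how strongly the critic's objective pulls on the common layers relative to the actor's. In the separate case it never reaches the actor, because each network only ever sees its own loss.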