Learning objectives

  • Implement a complete training loop: forward pass → loss → backprop → weight update
  • Understand the role of mini-batches and epochs in training efficiency
  • Track loss over epochs and interpret a learning curve
  • Connect the training loop pattern to DQN’s replay buffer training

Concept and real-world motivation

Training a neural network means repeatedly: (1) run a forward pass to get predictions, (2) compute the loss, (3) run backpropagation to get gradients, (4) update weights using an optimizer. This loop runs for many epochs (full passes over the training data). Each epoch is divided into mini-batches — subsets of the data processed together.
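
A minimal NumPy sketch of how one epoch is split into shuffled mini-batches. The helper name iterate_minibatches and the toy data are illustrative assumptions, not part of the exercise. The sketch uses NumPy's Generator.permutation, the newer-API counterpart of np.random.permutation.

```python
import numpy as np

def iterate_minibatches(X, y, batch_size, rng):
    """Hypothetical helper: yield (X_batch, y_batch) pairs covering every sample once (one epoch)."""
    n = X.shape[0]
    idx = rng.permutation(n)                 # fresh random order for this epoch
    for start in range(0, n, batch_size):
        batch_idx = idx[start:start + batch_size]
        yield X[batch_idx], y[batch_idx]     # the last batch may be smaller than batch_size

rng = np.random.default_rng(0)
X = np.arange(20, dtype=float).reshape(10, 2)  # 10 toy samples with 2 features each
y = np.arange(10)                              # toy labels; row i of X is [2i, 2i + 1]

for epoch in range(2):
    sizes = [len(yb) for _, yb in iterate_minibatches(X, y, batch_size=4, rng=rng)]
    print(f"epoch {epoch}: batch sizes {sizes}")
# prints "epoch 0: batch sizes [4, 4, 2]" and the same for epoch 1
```

With 10 samples and a batch size of 4, each epoch contains three mini-batches. The order of samples changes from epoch to epoch, but every sample is visited exactly once.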

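Why mini-batches? Computing the gradient on one sample at a time (pure stochastic gradient descent) gives cheap but noisy updates. Computing it on the whole dataset (full-batch gradient descent) gives the exact gradient of the training loss, but each step is slow. Mini-batches balance the two. A batch of, say, 32 to 256 samples gives a much less noisy gradient estimate than a single sample, and the whole batch is processed at once with vectorized matrix operations.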
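
To see the noise trade-off concretely, the sketch below evaluates mini-batch gradients of a mean squared error on synthetic linear-regression data. It measures how far each estimate lands from the full-dataset gradient. The data, the model, and the batch sizes are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 10_000, 5
X = rng.normal(size=(n, d))
true_w = rng.normal(size=d)
y = X @ true_w + 0.5 * rng.normal(size=n)    # synthetic regression targets with noise

w = np.zeros(d)  # all gradient estimates are taken at the same (untrained) weights

def mse_grad(Xb, yb, w):
    """Gradient of 0.5 * mean((Xb @ w - yb)**2) with respect to w."""
    return Xb.T @ (Xb @ w - yb) / len(yb)

full_grad = mse_grad(X, y, w)                # exact gradient over the whole dataset

mean_error = {}
for batch_size in [1, 32, 1024]:
    errors = []
    for _ in range(500):                     # 500 random mini-batches per batch size
        idx = rng.choice(n, size=batch_size, replace=False)
        errors.append(np.linalg.norm(mse_grad(X[idx], y[idx], w) - full_grad))
    mean_error[batch_size] = float(np.mean(errors))
    print(f"batch_size={batch_size:5d}  mean distance from full gradient={mean_error[batch_size]:.3f}")
```

The error should shrink roughly like one over the square root of the batch size. Each larger batch costs more computation per step, which is the trade-off the paragraph describes.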

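In RL: The DQN training loop samples a mini-batch of transitions from the replay buffer and does a forward pass to compute Q-values. It then computes the TD loss (the squared error, or a Huber loss, between Q(s, a) and the TD target), runs backprop through the Q-network, and updates the weights with an optimizer such as Adam. The replay buffer plays the role of the training dataset. The key difference from supervised learning is that the targets are computed from the network's own Q-estimates, so they shift every time the weights change. This moving-target instability is why DQN uses a target network: a periodically synced copy of the Q-network that keeps the targets fixed between syncs.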
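
The same loop shape applies in DQN. Below is a minimal sketch under several labeled assumptions. A linear Q-function (Q(s, ·) = s @ W) stands in for the Q-network. The transitions are random toy data rather than environment experience. A plain gradient step replaces Adam to keep it short. Syncing the target network every 100 steps is an arbitrary choice. ReplayBuffer, q_values and dqn_update are hypothetical names.

```python
import random
from collections import deque

import numpy as np

class ReplayBuffer:
    """Fixed-size FIFO store of (s, a, r, s_next, done) transitions."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)           # oldest transitions drop out first

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)  # uniform, without replacement
        s, a, r, s_next, done = map(np.array, zip(*batch))
        return s, a, r, s_next, done.astype(float)

    def __len__(self):
        return len(self.buffer)

def q_values(W, states):
    """Hypothetical linear Q-function standing in for the Q-network: Q(s, .) = s @ W."""
    return states @ W

def dqn_update(W, W_target, batch, gamma=0.99, lr=0.01):
    """One semi-gradient step on the mean squared TD error of a sampled mini-batch."""
    s, a, r, s_next, done = batch
    # TD targets come from the frozen target network and are treated as constants.
    targets = r + gamma * (1.0 - done) * q_values(W_target, s_next).max(axis=1)
    q_pred = q_values(W, s)[np.arange(len(a)), a]      # Q(s, a) for the actions taken
    td_error = q_pred - targets
    loss = np.mean(td_error ** 2)
    # Only the column of W for the taken action affects Q(s, a), so only it gets gradient.
    grad = np.zeros_like(W)
    for i in range(len(a)):
        grad[:, a[i]] += 2.0 * td_error[i] * s[i] / len(a)
    return W - lr * grad, loss

random.seed(0)
rng = np.random.default_rng(0)
state_dim, n_actions = 4, 2
buffer = ReplayBuffer(capacity=1000)
for _ in range(200):                                   # random toy transitions
    buffer.push(rng.normal(size=state_dim), int(rng.integers(n_actions)),
                float(rng.normal()), rng.normal(size=state_dim), bool(rng.random() < 0.1))

W = rng.normal(scale=0.1, size=(state_dim, n_actions))
W_target = W.copy()
for step in range(1, 501):
    W, loss = dqn_update(W, W_target, buffer.sample(32))
    if step % 100 == 0:
        W_target = W.copy()                            # periodically sync the target network
```

Between syncs, W_target does not change, so each mini-batch is fit against fixed targets, just as in supervised learning.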

Illustration — Training loop flow:

┌─────────────────────────────────────────────────────────┐
│  Initialize weights                                     │
│         ↓                                               │
│  For each epoch:                                        │
│    For each mini-batch:                                 │
│      → Forward pass (compute predictions)               │
│      → Compute loss                                     │
│      → Backpropagation (compute gradients)              │
│      → Update weights (SGD / Adam)                      │
│         ↓                                               │
│  Log loss every N epochs → learning curve               │
└─────────────────────────────────────────────────────────┘

Exercise: Complete the full training loop for a 2-layer MLP on synthetic binary classification data.
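
Below is one possible completion, written as a minimal NumPy sketch. The following are all assumptions chosen for illustration, not values from the original exercise: the dataset (label 1 when both features have the same sign), the hidden size of 16, the tanh hidden activation, the learning rate of 0.5, the batch size of 32, and the 200 epochs.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic binary classification data (assumed): XOR-like quadrants.
n = 1000
X = rng.uniform(-1, 1, size=(n, 2))
y = (X[:, 0] * X[:, 1] > 0).astype(float).reshape(-1, 1)

# 2-layer MLP: 2 inputs -> 16 tanh units -> 1 sigmoid output
hidden = 16
W1 = rng.normal(scale=1.0, size=(2, hidden))
b1 = np.zeros((1, hidden))
W2 = rng.normal(scale=1.0 / np.sqrt(hidden), size=(hidden, 1))
b2 = np.zeros((1, 1))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

lr, epochs, batch_size, eps = 0.5, 200, 32, 1e-9
loss_history = []

for epoch in range(epochs):
    perm = rng.permutation(n)                # shuffle at the start of every epoch
    epoch_loss = 0.0
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        xb, yb = X[idx], y[idx]
        m = len(idx)                         # the last batch can be smaller than batch_size

        # Forward pass
        h = np.tanh(xb @ W1 + b1)
        p = sigmoid(h @ W2 + b2)

        # Mean binary cross-entropy; eps keeps log() away from log(0)
        loss = -np.mean(yb * np.log(p + eps) + (1 - yb) * np.log(1 - p + eps))
        epoch_loss += loss * m

        # Backward pass, last layer first
        dz2 = (p - yb) / m                   # dL/dz2 for sigmoid + mean BCE (ignoring eps)
        dW2 = h.T @ dz2
        db2 = dz2.sum(axis=0, keepdims=True)
        dz1 = (dz2 @ W2.T) * (1 - h ** 2)    # tanh'(z) = 1 - tanh(z)^2
        dW1 = xb.T @ dz1
        db1 = dz1.sum(axis=0, keepdims=True)

        # SGD update: subtract the gradient
        W1 -= lr * dW1
        b1 -= lr * db1
        W2 -= lr * dW2
        b2 -= lr * db2

    loss_history.append(epoch_loss / n)      # mean training loss over the epoch
    if epoch % 50 == 0:
        print(f"epoch {epoch:3d}  loss {loss_history[-1]:.4f}")
```

Plotting loss_history against the epoch index gives the learning curve. Because the data is XOR-like, no linear model can separate it, which is why the hidden layer is needed.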

Professor’s hints

  • Always shuffle the data at the start of each epoch (np.random.permutation) to avoid the network seeing the same order every time.
  • The learning curve should generally decrease. If it goes up after initially going down, the learning rate may be too large.
  • The backward pass mirrors the forward pass in reverse order — backprop through the last layer first.
  • Numerical stability: always add a small \(\epsilon\) (like 1e-9) inside log() to avoid log(0) = -inf.
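
Two of the hints above can be made concrete in a few lines. The bce helper and the toy quadratic f(w) = w² are illustrative assumptions. On the quadratic, a too-large step makes the loss grow from the very first step. In a neural network, the same effect may only appear after some initial progress, as the hint describes.

```python
import numpy as np

def bce(p, y, eps=1e-9):
    """Mean binary cross-entropy with eps inside log() to avoid log(0)."""
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

p = np.array([0.0])   # the model predicts probability 0 ...
y = np.array([1.0])   # ... for a positive example
with np.errstate(divide="ignore"):
    naive = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
stable = bce(p, y)
print(naive, stable)  # inf, then about 20.72 (= -log(1e-9))

def gd_losses(lr, steps=5, w=1.0):
    """Gradient descent on f(w) = w**2; returns the loss after each step."""
    losses = []
    for _ in range(steps):
        w -= lr * 2 * w                      # the gradient of w**2 is 2w
        losses.append(w ** 2)
    return losses

print(gd_losses(0.1))   # each step multiplies the loss by (1 - 0.2)**2 = 0.64
print(gd_losses(1.1))   # each step multiplies the loss by (1 - 2.2)**2 = 1.44
```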

Common pitfalls

  • Forgetting to subtract the gradient (adding it instead makes the loss increase).
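  • Resetting momentum/Adam state between epochs. The moment estimates (and Adam's step counter used for bias correction) should be created once before training and carried across all epochs; zero them only when starting a new training run from scratch.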
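  • Using the wrong denominator for the gradient. If the loss is the mean over the batch, divide the gradient by the number of samples in the current batch, which can be smaller than the nominal batch size for the last mini-batch of an epoch.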
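
A quick way to catch the first and third pitfalls is to compare the hand-derived gradient with a finite-difference estimate on one mini-batch, then check which direction lowers the loss. The sketch assumes a linear model with a mean squared error on random toy data. The helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
Xb = rng.normal(size=(8, 3))    # one toy mini-batch: 8 samples, 3 features
yb = rng.normal(size=8)
w = rng.normal(size=3)

def mean_loss(w):
    return 0.5 * np.mean((Xb @ w - yb) ** 2)

def analytic_grad(w, divide_by_batch=True):
    g = Xb.T @ (Xb @ w - yb)
    return g / len(yb) if divide_by_batch else g

def numeric_grad(f, w, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(w)
    for i in range(len(w)):
        e = np.zeros_like(w)
        e[i] = h
        g[i] = (f(w + e) - f(w - e)) / (2 * h)
    return g

num = numeric_grad(mean_loss, w)
print(np.allclose(analytic_grad(w), num))                         # True: matches the mean loss
print(np.allclose(analytic_grad(w, divide_by_batch=False), num))  # False: 8x too large

lr = 0.1
print(mean_loss(w - lr * analytic_grad(w)) < mean_loss(w))  # True: subtracting lowers the loss
print(mean_loss(w + lr * analytic_grad(w)) > mean_loss(w))  # True: adding raises it
```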

Extra practice

  1. Warm-up: Implement forward pass + cross-entropy loss for one batch only, without the training loop.

  2. Coding: Add accuracy tracking to the training loop above: after each epoch, compute training accuracy (percentage of correct predictions).

  3. Challenge: Modify the training loop to use Adam instead of SGD. Track the loss and compare learning curves.

  4. Variant: Implement learning rate decay: multiply lr by 0.95 after each epoch. Observe the effect on convergence.

  5. Debug: Fix a training loop whose loss increases because the gradient is added to the weights instead of subtracted.

  6. Conceptual: In DQN, the “training data” changes over time as the agent collects new experience. How does the replay buffer address this? Why can’t you just use the last transition as a single-sample batch?

  7. Recall: What is the difference between a batch, a mini-batch, and an epoch? Why do we shuffle data at the start of each epoch?
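
For checking exercises 3 and 4, here is a minimal sketch of the standard Adam update (Kingma and Ba), using the usual defaults β1 = 0.9, β2 = 0.999 and ε = 1e-8. It is combined with the 0.95-per-epoch learning-rate decay. The Adam class is a hypothetical helper, not a library API. The demo fits a linear model to noiseless toy data as a stand-in for the MLP. Note that the optimizer, and therefore its state, is created once, before the epoch loop.

```python
import numpy as np

class Adam:
    """Hypothetical Adam optimizer for a list of NumPy arrays, updated in place."""
    def __init__(self, params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
        self.params = params
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [np.zeros_like(p) for p in params]   # first-moment (mean) estimates
        self.v = [np.zeros_like(p) for p in params]   # second-moment estimates
        self.t = 0                                     # step counter for bias correction

    def step(self, grads):
        self.t += 1
        for p, g, m, v in zip(self.params, grads, self.m, self.v):
            m *= self.beta1
            m += (1 - self.beta1) * g
            v *= self.beta2
            v += (1 - self.beta2) * g ** 2
            m_hat = m / (1 - self.beta1 ** self.t)     # bias-corrected estimates
            v_hat = v / (1 - self.beta2 ** self.t)
            p -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
y = X @ np.array([1.0, -2.0, 0.5])          # noiseless toy targets
w = np.zeros(3)

opt = Adam([w], lr=0.1)                      # created once; its state persists across epochs
lr0 = opt.lr
for epoch in range(30):
    opt.lr = lr0 * 0.95 ** epoch             # learning-rate decay (exercise 4)
    perm = rng.permutation(len(X))
    for start in range(0, len(X), 32):
        idx = perm[start:start + 32]
        grad = X[idx].T @ (X[idx] @ w - y[idx]) / len(idx)
        opt.step([grad])
print(w)
```

To swap this into the MLP loop, pass [W1, b1, W2, b2] to the constructor and call opt.step([dW1, db1, dW2, db2]) in place of the four SGD lines.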