Learning objectives
- State the gradient descent update rule and explain what each term means.
- Identify how the learning rate affects convergence: too small (slow convergence), too large (divergence), just right (fast convergence).
- Implement gradient descent from scratch in NumPy and plot the resulting loss curve.
Concept and real-world motivation
Imagine you are blindfolded on a hilly landscape and you want to reach the lowest valley. You cannot see the whole landscape, but you can feel the slope under your feet. Gradient descent says: take a small step in the direction that slopes downward, then repeat. The learning rate controls how large each step is.
Formally, if we want to minimise a loss \(L\) with respect to a parameter \(w\), we update:
\[w \leftarrow w - \alpha \nabla_w L\]
where \(\alpha\) is the learning rate and \(\nabla_w L\) is the gradient (the direction of steepest ascent). We subtract the gradient to go downhill. We repeat this update many times — each repetition is called an iteration or step.
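The update rule can be sketched in a few lines of plain Python. Here \(L(w) = w^2\) is a stand-in loss chosen for illustration, so \(\nabla_w L = 2w\):

```python
# Gradient descent on the stand-in loss L(w) = w**2, whose gradient is 2*w.
alpha = 0.1   # learning rate
w = 5.0       # initial parameter

for step in range(50):
    grad = 2 * w              # gradient of L at the current w
    w = w - alpha * grad      # the update rule: w <- w - alpha * grad_w L

print(w)  # w has moved very close to the minimiser w = 0
```

Each iteration applies exactly the formula above; everything else in a training loop is bookkeeping around this one line.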
The loss curve — a plot of \(L\) vs iteration — tells the whole story of training. A healthy loss curve drops fast early and plateaus smoothly. A diverging curve (loss explodes upward) means the learning rate is too large. A flat curve from the start means the learning rate is too small or the gradient is zero.
RL connection: Policy gradient ascent does the exact same thing, but in reverse — it maximises the expected return \(J(\theta)\) instead of minimising a loss:
\[\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)\]
REINFORCE, PPO, and A3C all implement this. The only difference from supervised gradient descent is the \(+\) sign. Mastering gradient descent here means policy gradient later is just a sign change.
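The sign flip can be seen in a toy sketch. Here \(J(\theta) = -(\theta-2)^2\) is a made-up concave objective standing in for expected return (real policy gradients estimate \(\nabla_\theta J\) from sampled trajectories; this sketch uses the exact gradient):

```python
# Gradient ASCENT on the toy objective J(theta) = -(theta - 2)**2.
# dJ/dtheta = -2 * (theta - 2); the maximum is at theta = 2.
alpha = 0.1
theta = -4.0

for step in range(60):
    grad = -2 * (theta - 2)
    theta = theta + alpha * grad  # note the + sign: ascent, not descent

print(theta)  # theta approaches the maximiser 2
```

Compare with the descent loop: the only change is `+ alpha * grad` instead of `- alpha * grad`.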
Illustration: The loss curve below shows typical gradient descent behaviour: rapid decrease in early iterations, then a gentle plateau as the minimum is approached.
How does the learning rate affect convergence speed? Try different values in the exercise below.
Exercise: Implement gradient descent to minimise \(f(x) = (x-3)^2\). Start at \(x=10\), use \(\alpha=0.1\), run 30 steps. Print \(x\) and \(f(x)\) at each step to watch convergence.
Professor’s hints
- The gradient of \((x-3)^2\) with respect to \(x\) is \(2(x-3)\). When \(x > 3\), the gradient is positive, so subtracting it decreases \(x\). When \(x < 3\), it is negative, so subtracting it increases \(x\). Either way, \(x\) moves toward 3.
- With \(\alpha=0.1\), each step shrinks the error \(x - 3\) to 80% of its previous value (the contraction factor is \(1 - 2\alpha = 0.8\) for this parabola). After 30 steps the error is \(7 \times 0.8^{30} \approx 0.009\), so \(x\) ends up within about 0.01 of 3.
- Use `print(f'step {step}: x={x:.6f}, loss={f(x):.8f}')` to see progress at each iteration.
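The per-step contraction factor is easy to verify numerically (a quick sanity check, separate from the exercise itself):

```python
# Verify the contraction factor for f(x) = (x - 3)**2 with alpha = 0.1.
# One update: x <- x - alpha * 2 * (x - 3), so the error (x - 3)
# is multiplied by 1 - 2 * alpha = 0.8 each step.
alpha = 0.1
x = 10.0
err_before = abs(x - 3)
x = x - alpha * 2 * (x - 3)
err_after = abs(x - 3)
print(err_after / err_before)  # 0.8, i.e. 1 - 2 * alpha
```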
Common pitfalls
- Wrong sign: Writing `x = x + lr * grad_f(x)` instead of `x = x - lr * grad_f(x)` causes gradient ascent — \(x\) runs away from the minimum. This is the most common gradient descent bug.
- Learning rate too large: For \(f(x) = (x-3)^2\), using \(\alpha > 1\) makes each update overshoot so badly that the error grows (at exactly \(\alpha = 1\), \(x\) oscillates forever without converging). Try `lr=2.0` to see this happen.
- Not tracking the loss: Always log the loss during training. Silent divergence (loss going to infinity) should fail loudly, not quietly.
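One way to make divergence fail loudly is a finiteness check inside the training loop. This is a hedged sketch: the threshold `1e12` and the choice of `RuntimeError` are arbitrary illustrative choices, not a standard convention:

```python
import math

def descend(lr, steps=30, x=10.0):
    """Gradient descent on f(x) = (x - 3)**2 that raises on divergence."""
    for step in range(steps):
        x = x - lr * 2 * (x - 3)
        loss = (x - 3) ** 2
        # Fail loudly instead of silently letting the loss explode.
        if not math.isfinite(loss) or loss > 1e12:
            raise RuntimeError(f"diverged at step {step}: loss={loss}")
    return x

descend(lr=0.1)   # converges quietly
# descend(lr=2.0)  # would raise RuntimeError within a few steps
```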
Worked solution
Full gradient descent implementation with output:
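The embedded code is not preserved here; the following sketch follows the exercise's specification (start at \(x=10\), \(\alpha=0.1\), 30 steps):

```python
def f(x):
    """The loss to minimise."""
    return (x - 3) ** 2

def grad_f(x):
    """Gradient of f with respect to x."""
    return 2 * (x - 3)

lr = 0.1
x = 10.0
for step in range(30):
    x = x - lr * grad_f(x)  # the gradient descent update
    print(f'step {step}: x={x:.6f}, loss={f(x):.8f}')
```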
The key line is `x = x - lr * grad_f(x)`. After 30 steps with `lr=0.1`, \(x\) converges to within about 0.01 of the minimiser at 3 (the error shrinks by a factor of 0.8 per step).
Extra practice
- Warm-up: For \(f(x) = (x-3)^2\) and starting point \(x=5\), compute the first two gradient steps by hand with \(\alpha=0.1\). What is \(x\) after step 1? After step 2?
- Coding: Minimise \(g(x) = x^4 - 4x^2 + x\) using gradient descent. The gradient is \(g’(x) = 4x^3 - 8x + 1\). Start at \(x=-2\), use \(\alpha=0.01\), run 200 steps. What is the minimum you find?
- Challenge: Implement mini-batch gradient descent on the linear regression problem from the previous page. Instead of using all 5 samples each step, use a random batch of 2. Compare convergence to full-batch gradient descent.
- Variant: Run gradient descent on \(f(x) = (x-3)^2\) with `lr=2.0`. Print \(x\) and \(f(x)\) at each step. What happens? Explain why the loss explodes.
- Debug: The gradient sign is wrong below — the parameter moves away from the minimum instead of toward it. Fix the update rule.
- Conceptual: Explain in your own words why gradient descent does not always find the global minimum of a loss function. Under what conditions is it guaranteed to find the global minimum?
- Recall: Write the gradient descent update rule from memory. Identify each symbol: \(w\), \(\alpha\), \(\nabla_w L\). What is the unit of \(\alpha\)?
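For the Debug item above, the original buggy snippet is not reproduced here; the following is a plausible stand-in exhibiting the same wrong-sign bug (the fix is left to the reader):

```python
def grad_f(x):
    return 2 * (x - 3)

lr = 0.1
x = 10.0
for step in range(30):
    x = x + lr * grad_f(x)  # BUG: wrong sign -- this is gradient ascent
print(x)  # x has run away from the minimum at 3
```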