Used in *Preliminary: PyTorch basics* and throughout the curriculum for DQN, policy gradients, actor-critic, PPO, and SAC. PyTorch's define-by-run style and transparent autograd make it a natural fit for custom RL loss functions.
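As a minimal illustration of what "custom RL loss function" means in practice, here is a hedged sketch of a REINFORCE-style policy-gradient loss; the tiny policy network and the dummy `states`, `actions`, and `returns` tensors are made up for the example:

```python
import torch
import torch.nn as nn

# Hypothetical tiny policy network: 4-feature state -> 2 action logits
policy = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

states = torch.randn(8, 4)           # batch of 8 states (dummy data)
actions = torch.randint(0, 2, (8,))  # actions taken (dummy data)
returns = torch.randn(8)             # returns G_t (dummy data)

log_probs = torch.log_softmax(policy(states), dim=-1)
chosen = log_probs[torch.arange(8), actions]  # log pi(a_t | s_t)
loss = -(chosen * returns).mean()             # REINFORCE objective
loss.backward()                               # autograd computes policy gradients
```

Because the loss is an ordinary tensor expression, autograd differentiates it with no extra machinery; this is the pattern the DQN, PPO, and SAC lessons build on.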
## Why PyTorch matters for RL

- **Tensors** — States, actions, and batches are tensors. `torch.tensor()`, `requires_grad=True`, and `.to(device)` are daily use.
- **Autograd** — Policy-gradient and value losses need gradients; `backward()` and `.grad` are central.
- **nn.Module** — Q-networks, policy networks, and critics are `nn.Module` subclasses; their parameters are collected for optimizers.
- **Optimizers** — `torch.optim.Adam`, `zero_grad()`, `loss.backward()`, `optimizer.step()`.
- **Device** — Move the model and data to the GPU with `.to(device)` for faster training.

## Core concepts with examples

### Tensors and gradients

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = x**2
y.backward()
print(x.grad)  # 4.0
```

### Batches and shapes

```python
# Batch of 32 states, 4 features each (e.g. CartPole)
states = torch.randn(32, 4)
# Linear layer: 4 -> 64
W = torch.randn(4, 64, requires_grad=True)
out = states @ W  # (32, 64)
```

### Simple MLP with nn.Module

```python
import torch.nn as nn

class QNetwork(nn.Module):
    def __init__(self, state_dim=4, n_actions=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.net(x)

q = QNetwork()
s = torch.randn(8, 4)  # batch of 8 states
q_vals = q(s)          # (8, 2)
```

### Training step (e.g. MSE loss)

```python
optimizer = torch.optim.Adam(q.parameters(), lr=1e-3)
targets = torch.randn(8, 2)
loss = nn.functional.mse_loss(q_vals, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

### Device and CPU/GPU

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q = q.to(device)
states = states.to(device)
```

## Worked examples

**Example 1 — Autograd (Exercise 1).** Create \(x = 3.0\) with `requires_grad=True`, compute \(y = x^3 + 2x\), call `y.backward()`, and verify `x.grad`.
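A quick sketch of how Example 1 can be checked by hand: analytically \(\frac{dy}{dx} = 3x^2 + 2\), so at \(x = 3\) the gradient should be 29:

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
y = x**3 + 2 * x      # y = 33.0
y.backward()          # populates x.grad with dy/dx
print(x.grad)         # tensor(29.)
```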
...