An alternative to PyTorch for implementing DQN, policy gradients, and other deep RL algorithms. The Keras API provides layers and optimizers; tf.GradientTape gives full control over custom loss functions (e.g. the policy-gradient loss or CQL).
Why TensorFlow matters for RL
- Keras API: tf.keras.Sequential, tf.keras.Model, layers (Dense, Conv2D). Quick prototyping of Q-networks and policies.
- Gradient tape: tf.GradientTape() records operations so you can compute gradients of any scalar with respect to trainable variables. Essential for policy gradient and custom losses.
- Optimizers: tf.keras.optimizers.Adam, apply_gradients.
- Device placement: GPU via tf.config when available (see the sketch after this list).
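A minimal sketch of checking device availability with tf.config; the explicit tf.device pin is optional, since TensorFlow places ops on a GPU automatically when one is visible:

```python
import tensorflow as tf

# List the GPUs TensorFlow can see; an empty list means CPU-only.
gpus = tf.config.list_physical_devices("GPU")
print("GPUs:", gpus)

# Optional explicit placement (illustrative only).
device = "/GPU:0" if gpus else "/CPU:0"
with tf.device(device):
    x = tf.random.normal((32, 4))
print(x.device)
```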
Core concepts with examples
Dense layers and Sequential model
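A minimal sketch of a Q-network built with tf.keras.Sequential, assuming a 4-dimensional state and 2 discrete actions (CartPole-sized, as in the exercises below):

```python
import tensorflow as tf

# Q-network: state in, one Q-value per action out.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),  # linear output: Q-values, not probabilities
])
model.summary()
```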
Forward pass and MSE loss
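A sketch of a forward pass and a hand-written MSE loss, with random tensors standing in for states and TD targets:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])

states = tf.random.normal((32, 4))    # batch of states
targets = tf.random.normal((32, 2))   # stand-in for TD targets
q_values = model(states)              # forward pass, shape (32, 2)
loss = tf.reduce_mean(tf.square(q_values - targets))  # MSE
print(float(loss))
```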
Training step with GradientTape
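A sketch of a full training step; the layer sizes and learning rate are illustrative choices:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

def train_step(states, targets):
    with tf.GradientTape() as tape:
        q_values = model(states)                              # recorded on the tape
        loss = tf.reduce_mean(tf.square(q_values - targets))
    grads = tape.gradient(loss, model.trainable_variables)    # d(loss)/d(params)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```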
Subclassing for custom models
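A sketch of the subclassing style, useful when call needs custom control flow:

```python
import tensorflow as tf

class QNetwork(tf.keras.Model):
    def __init__(self, num_actions=2):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(64, activation="relu")
        self.out = tf.keras.layers.Dense(num_actions)

    def call(self, inputs):
        return self.out(self.hidden(inputs))

model = QNetwork()
print(model(tf.random.normal((10, 4))).shape)  # (10, 2)
```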
Exercises
Exercise 1. Create a Sequential model with one hidden layer (64 units, ReLU) and output dimension 2. Build it with input_shape=(4,). Call model(tf.random.normal((10, 4))) and print the output shape. Then use model.summary() to inspect parameters.
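One possible solution sketch:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
out = model(tf.random.normal((10, 4)))
print(out.shape)   # (10, 2)
model.summary()    # Dense(64): 4*64+64 = 320 params; Dense(2): 64*2+2 = 130
```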
Exercise 2. In a GradientTape scope, compute \(y = x^2\) for x = tf.Variable(3.0), then call tape.gradient(y, x). Verify the gradient is 6.0. (Use a tf.Variable rather than tf.constant: the tape only watches variables by default, so a constant would give a None gradient.)
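A sketch:

```python
import tensorflow as tf

x = tf.Variable(3.0)            # variables are watched by the tape automatically
with tf.GradientTape() as tape:
    y = x ** 2
print(tape.gradient(y, x))      # tf.Tensor(6.0, shape=(), dtype=float32)
```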
Exercise 3. Implement a training step that: (1) takes states (32, 4) and targets (32, 2); (2) inside GradientTape, computes Q-values from your model and MSE loss; (3) computes gradients and applies them with an Adam optimizer. Run 50 steps with random data and print the loss every 10 steps.
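A sketch, with random data standing in for replay samples (since the targets are re-randomized each step, the loss hovers rather than converges):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
optimizer = tf.keras.optimizers.Adam()

def train_step(states, targets):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(states) - targets))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for step in range(50):
    loss = train_step(tf.random.normal((32, 4)), tf.random.normal((32, 2)))
    if step % 10 == 0:
        print(step, float(loss))
```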
Exercise 4. Implement a softmax policy: a small model that maps state (4,) to logits (2,). Given a batch of states, compute action probabilities with tf.nn.softmax(logits). Sample actions with tf.random.categorical(tf.math.log(probs), 1). Return both the sampled actions and the log-probabilities of those actions (using tf.math.log and gather).
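One way to structure it (the squeeze/gather details are one choice among several):

```python
import tensorflow as tf

policy = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),   # logits for 2 actions
])

def sample_actions(states):
    logits = policy(states)
    probs = tf.nn.softmax(logits)                           # (batch, 2)
    actions = tf.random.categorical(tf.math.log(probs), 1)  # (batch, 1), int64
    actions = tf.squeeze(actions, axis=-1)                  # (batch,)
    log_probs = tf.math.log(tf.gather(probs, actions, batch_dims=1))  # log pi(a|s)
    return actions, log_probs

acts, lps = sample_actions(tf.random.normal((8, 4)))
print(acts.shape, lps.shape)   # (8,) (8,)
```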
Exercise 5. Create a subclassed tf.keras.Model with two dense layers (64, ReLU) and output 2. Override call(self, inputs). Train it for 100 steps with random states and targets using GradientTape and Adam. Store the loss in a list and plot it (e.g. with matplotlib) to confirm it decreases.
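A sketch, reading "two dense layers (64, ReLU)" as two hidden 64-unit ReLU layers; the batch is held fixed so the loss can actually decrease (fresh random targets every step would not converge):

```python
import tensorflow as tf
import matplotlib.pyplot as plt

class Net(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.h1 = tf.keras.layers.Dense(64, activation="relu")
        self.h2 = tf.keras.layers.Dense(64, activation="relu")
        self.out = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.out(self.h2(self.h1(inputs)))

model = Net()
optimizer = tf.keras.optimizers.Adam()
states = tf.random.normal((32, 4))    # fixed batch
targets = tf.random.normal((32, 2))

losses = []
for _ in range(100):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(states) - targets))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    losses.append(float(loss))

plt.plot(losses)
plt.xlabel("step")
plt.ylabel("MSE loss")
plt.show()
```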
Exercise 6. Create a Variable x = tf.Variable(2.0) and, inside a GradientTape scope, compute y = x ** 2, then grad = tape.gradient(y, x). Verify grad is 4.0. In RL: GradientTape records ops so policy and value gradients can be computed for custom losses.
Exercise 7. Build a small model (4 → 64 → 2). In a loop, generate random (32, 4) states and (32, 2) targets, call your train step, and append the loss to a list. Plot the list with matplotlib. In RL: This mirrors the inner loop of DQN or actor-critic training.
Exercise 8. (Challenge) Implement a softmax policy that takes state (batch, 4) and returns (actions, log_probs). Use tf.random.categorical for sampling. Train with a dummy “loss” = -mean(log_probs) for 50 steps and confirm loss decreases (you are maximizing log-prob). In RL: This is the core of REINFORCE-style updates.
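A sketch of the challenge. Sampling is not differentiable, but the gradient only flows through the log-probabilities of the (fixed) sampled actions; maximizing them drives the policy toward determinism, so the loss should trend downward:

```python
import tensorflow as tf

policy = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
optimizer = tf.keras.optimizers.Adam()
states = tf.random.normal((32, 4))

for step in range(50):
    with tf.GradientTape() as tape:
        logits = policy(states)
        actions = tf.random.categorical(logits, 1)[:, 0]   # sampled; no gradient path
        log_probs = tf.gather(tf.nn.log_softmax(logits), actions, batch_dims=1)
        loss = -tf.reduce_mean(log_probs)                  # maximize log pi(a|s)
    grads = tape.gradient(loss, policy.trainable_variables)
    optimizer.apply_gradients(zip(grads, policy.trainable_variables))
    if step % 10 == 0:
        print(step, float(loss))
```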
Professor’s hints
- In RL: use tf.GradientTape() for the policy gradient and for any loss that is not a simple Keras built-in. Record the forward pass inside the tape, then call tape.gradient(loss, model.trainable_variables).
- Wrap the training step in @tf.function for speed after you have verified it works in eager mode; a sketch follows this list. Be careful: Python side effects (e.g. appending to a list) inside a tf.function may not run as expected.
- Keep the model and optimizer creation outside the training step so variables are reused. Create the tape inside the step so each step has a fresh tape.
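A sketch of the @tf.function pattern from the second hint; keep Python-side bookkeeping (loss lists, logging) outside the compiled function:

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
optimizer = tf.keras.optimizers.Adam()

@tf.function  # compiled to a graph on first call; verify in eager mode first
def train_step(states, targets):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(states) - targets))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss  # return values; append to a Python list outside the function

losses = [float(train_step(tf.random.normal((32, 4)), tf.random.normal((32, 2))))
          for _ in range(5)]
```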
Common pitfalls
- Using a Python float instead of a Variable for gradients: tape.gradient(y, x) requires x to be a tf.Variable (or a trainable model parameter). Constants are not watched by default, so tape.gradient returns None for them unless you call tape.watch(x) inside the tape (see the sketch after this list).
- Operations recorded outside the tape scope: only operations executed inside the with tf.GradientTape() as tape: block are recorded; a loss computed after the block yields no gradients. (Calling tape.gradient after the block is fine and is the usual pattern, but a non-persistent tape can be used only once.)
- Graph vs eager: in TensorFlow 2, eager execution is the default. If you use tf.function, ensure inputs are tensors or convert them with tf.convert_to_tensor; avoid passing Python lists whose shapes change between calls, since they trigger retracing.
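A sketch of the first pitfall and its fix with tape.watch:

```python
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    y = x ** 2
print(tape.gradient(y, x))   # None: constants are not watched

with tf.GradientTape() as tape:
    tape.watch(x)            # explicitly watch the constant
    y = x ** 2
print(tape.gradient(y, x))   # tf.Tensor(6.0, shape=(), dtype=float32)
```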
Docs: tensorflow.org/api_docs; the Keras docs cover the high-level API.