Intuiting Policy Gradient methods
Recently, I found it imperative to grok Policy Gradient (PG) methods. As much as I enjoy going down rabbit holes of adjacent techniques, which are abundant in RL, I have refrained here. The motivation is to think effectively about PG research in LLMs/foundation models.
Problem Setting
It comes naturally to me to pit anything that learns against the standard Supervised Learning framework.
In Supervised Learning, the world is static: given a fixed dataset, the task is to learn a parameterised function mapping inputs to targets. The gradient of the loss function provides a direct, unambiguous error signal. Reinforcement Learning breaks this assumption: here, an agent interacts with an environment that yields only sparse, indirect signals.
This interaction is formalised as a Markov Decision Process (MDP). At each timestep \(t\), we have: \(\begin{aligned} &\text{State:} \quad s_t \\ &\text{Action:} \quad a_t \sim \pi_\theta(a_t \mid s_t) \\ &\text{Reward:} \quad r_t = r(s_t, a_t) \\ &\text{Transition:} \quad s_{t+1} \sim P(s_{t+1} \mid s_t, a_t) \end{aligned}\)
\[\underbrace{\pi_\theta(a_t \mid s_t)}_{\text{Policy}} \qquad \underbrace{r(s, a)}_{\text{Reward function}} \qquad \underbrace{P(s' \mid s, a)}_{\text{Transition probability}}\]LLM analogy: \(s_t\) is the prompt plus all tokens generated so far, \(a_t\) is the next token to generate, \(\pi_\theta(a_t \mid s_t)\) is the model’s policy (the probability of generating token \(a_t\) given the current context \(s_t\)), \(\tau\) is the complete response sequence, and \(\theta\) are the model weights.
Hence, in the Policy Gradient setting, our true objective is similar to Supervised Learning: to maximise the performance of our parameterised policy. However, since we do not possess golden labels, we do this indirectly by maximising the expected return over trajectories, \(J(\theta)\):
\[J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)] = \sum_{\tau} P(\tau\mid \theta) R(\tau), \quad R(\tau) = \sum_{t=0}^T r_t \tag{1}\]where a trajectory \(\tau = (s_0, a_0, \dots, s_{T+1})\) is the sequence of states and actions generated up to horizon \(T\), and \(\pi_\theta\) is the policy parameterised by \(\theta\).
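To make the setup concrete, here is a minimal sketch (a toy tabular MDP with made-up dynamics, state/action counts, and a tabular softmax policy, all chosen purely for illustration) of what sampling a trajectory and computing \(R(\tau)\) look like:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy MDP: 3 states, 2 actions, horizon T.
N_STATES, N_ACTIONS, T = 3, 2, 5
P = rng.dirichlet(np.ones(N_STATES), size=(N_STATES, N_ACTIONS))  # P[s, a] = next-state distribution
reward = rng.normal(size=(N_STATES, N_ACTIONS))                   # r(s, a), made up
theta = np.zeros((N_STATES, N_ACTIONS))                           # tabular softmax policy parameters

def policy(s, theta):
    """pi_theta(. | s): softmax over the action logits of state s."""
    logits = theta[s]
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

def sample_trajectory(theta):
    """Roll out one trajectory and return its states, actions, and rewards."""
    s = int(rng.integers(N_STATES))                      # s_0 ~ rho_0 (uniform here)
    states, actions, rewards = [], [], []
    for t in range(T + 1):                               # t = 0, ..., T
        a = rng.choice(N_ACTIONS, p=policy(s, theta))    # a_t ~ pi_theta(. | s_t)
        states.append(s); actions.append(a); rewards.append(reward[s, a])
        s = int(rng.choice(N_STATES, p=P[s, a]))         # s_{t+1} ~ P(. | s_t, a_t)
    return states, actions, rewards

states, actions, rewards = sample_trajectory(theta)
R_tau = sum(rewards)                                     # R(tau) = sum_{t=0}^T r_t
```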
The policy \(\pi_\theta\) influences its own data distribution, creating a shifting landscape in which we hope the data improves as the policy does.
To optimise \(J(\theta)\), we use gradient ascent:
\[\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta) \tag{2}\]The trajectory probability in equation \((1)\) factorises as:
\[P(\tau \mid \theta) = \rho_0(s_0) \prod_{t=0}^{T} P(s_{t+1} \mid s_t, a_t) \pi_\theta(a_t \mid s_t) \tag{3}\]i.e. the probability of a trajectory is the product, at each timestep, of the environment’s transition probability and the policy’s action probability, beginning from the initial state distribution \(\rho_0\). This factorisation is key in the next section.
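A trivial numerical sketch of this factorisation, with made-up probabilities: in log space the trajectory probability is a sum, and only the policy terms carry any \(\theta\) dependence.

```python
import numpy as np

# log P(tau | theta) splits into environment terms and policy terms (equation (3)).
# Only the policy log-probs depend on theta, which the next section exploits.
log_rho0 = np.log(0.5)                                  # log rho_0(s_0), hypothetical value
log_transitions = np.log(np.array([0.7, 0.9, 0.4]))     # log P(s_{t+1} | s_t, a_t): no theta here
log_policy = np.log(np.array([0.6, 0.3, 0.8]))          # log pi_theta(a_t | s_t): depends on theta

log_p_tau = log_rho0 + log_transitions.sum() + log_policy.sum()
```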
Policy Gradient Theorem
However, there is a glaring issue with equation \((2)\). Computing \(\nabla_\theta J(\theta)\) directly from equation \((1)\) is infeasible because:
1. Exact Gradient Is Infeasible
Calculating the true gradient requires summing over all possible trajectories \(\tau\) in equation \((1)\). This is computationally intractable since the number of possible trajectories grows exponentially with trajectory length.
2. State Distribution Issue
Differentiating the objective would require differentiating the environment’s state distribution \(\rho_0(s_0)\) and transition probabilities \(P(s_{t+1}\mid s_t, a_t)\) in equation \((3)\). These environment dynamics are typically unknown or non-differentiable.
The Practical Solution
The policy gradient theorem provides an elegant solution:
- Analytical Form: Derive the policy gradient as an expectation over trajectories:
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ R(\tau) \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t\mid s_t) \right] \tag{4}\]
The final form of the Policy Gradient Theorem saves us from differentiating through the environment dynamics. One could even interpret it as a common stoic lesson:
Make the best use of what is in our power, and treat the rest in accordance with its nature. ~ Epictetus
See Appendix A for a complete treatment of how we arrive at equation \((4)\).
- Estimation: Estimate this expectation using sampled trajectories from agent-environment interactions.
REINFORCE
A classic Monte Carlo estimator of this expectation is REINFORCE. Specifically, we collect multiple trajectories and compute:
\[\hat{\nabla}_\theta J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t^{(i)}\mid s_t^{(i)}) R(\tau^{(i)}) \tag{5}\]where:
- \(m\) is the number of sampled trajectories
- \(\tau^{(i)} = (s_0^{(i)}, a_0^{(i)}, \dots, s_{T+1}^{(i)})\) is the \(i\)-th trajectory
- \(R(\tau^{(i)}) = \sum_{t=0}^{T} r_t^{(i)}\) is the total return for trajectory \(i\)
- \(a_t^{(i)}, s_t^{(i)}\) are the action and state at time \(t\) in trajectory \(i\)
Intuition: Each trajectory gives a noisy gradient estimate; averaging \(m\) of them reduces variance. The total reward \(R(\tau^{(i)})\) weights each trajectory’s impact. This method is unbiased but still noisy.
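To make equation \((5)\) concrete, here is a minimal sketch continuing the hypothetical toy setup from the earlier snippet (it reuses `np`, `policy`, `sample_trajectory`, and `theta`); for a tabular softmax policy the score \(\nabla_\theta \log \pi_\theta(a\mid s)\) has the closed form used below.

```python
def grad_log_pi(s, a, theta):
    """Score of the tabular softmax policy: d log pi_theta(a | s) / d theta.
    Only row s is non-zero: one_hot(a) - pi_theta(. | s)."""
    g = np.zeros_like(theta)
    g[s] = -policy(s, theta)
    g[s, a] += 1.0
    return g

def reinforce_gradient(theta, m=64):
    """Monte Carlo estimate of grad_theta J(theta), equation (5)."""
    grad = np.zeros_like(theta)
    for _ in range(m):
        states, actions, rewards = sample_trajectory(theta)
        R_tau = sum(rewards)                         # R(tau^(i)): the full return weights every step
        for s, a in zip(states, actions):
            grad += grad_log_pi(s, a, theta) * R_tau
    return grad / m

alpha = 0.1
theta = theta + alpha * reinforce_gradient(theta)    # one gradient-ascent step, equation (2)
```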
You can probably deduce this method is pretty naive.
Example: Raw Returns
| State | \(a_1\) reward | \(a_2\) reward |
|---|---|---|
| \(s_1\) | 12 | 7 |
| \(s_2\) | 2 | 5 |
For clarity, we treat these as one-step episodes, so “reward” and “return” coincide: \(R(\tau) = r\). A stochastic policy samples:
- \((s_1, a_2)\) (reward = 7) (sub-optimal)
- \((s_2, a_2)\) (reward = 5) (optimal)
The raw REINFORCE update is:
\[\hat{\nabla}_\theta J(\theta) = 7\,\nabla_\theta\log\pi_\theta(a_2 \mid s_1) + 5\,\nabla_\theta\log\pi_\theta(a_2 \mid s_2)\]Here, the wrong action in the high-reward state (\(s_1, a_2\)) receives a larger update than the correct action in the low-reward state, purely due to reward scale and MC randomness.
Variance Reduction
Most variants of the Policy Gradient formulation try to tame this variance. One of the most common remedies is a baseline: a function of the state, \(b(s_t)\), that is subtracted from the return. The gradient estimator becomes:
\[\hat{\nabla}_\theta J(\theta) = \frac{1}{m} \sum_{i=1}^{m} \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t^{(i)}\mid s_t^{(i)}) (R(\tau^{(i)}) - b(s_t^{(i)})) \tag{6}\]Let’s apply this to our example. A simple baseline is the average reward for each state:
- For \(s_1\): \(b(s_1) = \frac{12 + 7}{2} = 9.5\) (average of possible rewards)
- For \(s_2\): \(b(s_2) = \frac{2 + 5}{2} = 3.5\) (average of possible rewards)
For our sampled actions:
- \((s_1, a_2) \implies 7 - 9.5 = \mathbf{-2.5}\) (Correctly discourages below-average action)
- \((s_2, a_2) \implies 5 - 3.5 = \mathbf{+1.5}\) (Correctly encourages above-average action)
Intuition: Raw returns can be huge or tiny depending on the state. Subtracting \(b(s)\) recentres the reward around “what’s normal” for that state, reducing variance by making updates reflect surplus/deficit rather than absolute scale.
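The toy example in a few lines of Python, showing how the baseline flips the sign of the update weights (numbers as in the tables above):

```python
# One-step episodes, so R(tau) = r.
rewards = {"s1": {"a1": 12, "a2": 7}, "s2": {"a1": 2, "a2": 5}}
sampled = [("s1", "a2"), ("s2", "a2")]                  # the two sampled (state, action) pairs

# Raw REINFORCE weights on grad log pi(a|s): just the returns.
raw = {(s, a): rewards[s][a] for s, a in sampled}
# -> {('s1', 'a2'): 7, ('s2', 'a2'): 5}: the sub-optimal action gets the bigger push.

# Baseline b(s) = average reward in that state.
baseline = {s: sum(r.values()) / len(r) for s, r in rewards.items()}   # {'s1': 9.5, 's2': 3.5}
adjusted = {(s, a): rewards[s][a] - baseline[s] for s, a in sampled}
# -> {('s1', 'a2'): -2.5, ('s2', 'a2'): 1.5}: below-average discouraged, above-average encouraged.
```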
The Quest for the Optimal Baseline
The optimal baseline, \(b(s_t)\), should minimise the variance of the gradient estimator in equation \((6)\) without introducing bias. Any baseline that depends only on the state \(s_t\) is guaranteed to be unbiased, because its expected contribution to the gradient is zero: \(\mathbb{E}_{a_t \sim \pi_\theta(\cdot\mid s_t)}[\nabla_\theta \log \pi_\theta(a_t\mid s_t)b(s_t)] = 0\). See the proof in Appendix B.
Minimising this variance exactly is tricky, but we can simplify the problem as follows:
- We ignore correlations between timesteps, focusing on a single timestep’s gradient contribution.
- We treat this contribution’s gradient term as a scalar. We call this term \(g_t\).
The full derivation and intuition are provided in Appendix C. The practical result is that the optimal baseline is the state-value function:
\[b^*(s_t) = \frac{\mathbb{E}\left[ g_t^2\,R_t \mid s_t \right]}{\mathbb{E}\left[ g_t^2 \mid s_t \right]} \approx \mathbb{E}[R_t \mid s_t] = V^{\pi}(s_t) \tag{7}\]
Looking through the Advantage Function Lens
Given the action-value function: \(Q^{\pi}(s, a) = \mathbb{E}_{\pi}\left[ R_t \mid s_t = s, a_t = a \right]\)
In practice, we estimate \(Q^{\pi}(s, a)\) with the reward-to-go from a single trajectory, \(R_t^{(i)} = \sum_{k=t}^T \gamma^{k-t} r_k^{(i)}\), where \(\gamma\) is the discount factor. This uses only future rewards, as actions cannot affect the past.
We can define the advantage function as: \(A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)\) where \(V^{\pi}(s) = \mathbb{E}_{a \sim \pi(\cdot\mid s)}[Q^{\pi}(s, a)]\) is the expected value of the state under the policy.
Intuition: \(A^{\pi}(s, a)\) measures how much better (or worse) action \(a\) is compared to the average. Effectively, highlighting outlier actions.
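A small sketch of how the reward-to-go and the empirical advantage are typically computed from a single trajectory; the function names and \(\gamma = 0.99\) are illustrative, and the per-timestep values could come from a learned critic or a Monte Carlo average of returns per state:

```python
import numpy as np

def rewards_to_go(rewards, gamma=0.99):
    """R_t = sum_{k>=t} gamma^(k-t) r_k, computed right-to-left in O(T)."""
    rtg = np.zeros_like(rewards, dtype=float)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

def advantage_estimates(rewards, values, gamma=0.99):
    """A_hat(s_t, a_t) = R_t - V(s_t), with a value estimate supplied per timestep."""
    return rewards_to_go(rewards, gamma) - np.asarray(values, dtype=float)

# e.g. advantage_estimates([1.0, 0.0, 2.0], values=[2.5, 1.8, 2.0])
```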
The baseline-subtracted term from equation \((6)\) is exactly our empirical advantage estimate:
\[\hat{A}^{\pi}(s_t^{(i)}, a_t^{(i)}) = R_t^{(i)} - b(s_t^{(i)})\]Hence, the final policy gradient estimator uses the advantage function:
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^T \nabla_\theta \log \pi_\theta(a_t\mid s_t) A^{\pi}(s_t, a_t) \right] \tag{8}\]
Appendix
A. Policy Gradient Theorem Proof
The policy gradient theorem gets us over the direct gradient calculation hump. However, it is important to understand how we shed certain terms to get to the final form. Here’s the complete derivation:
1: Start from the objective definition
\[\nabla_\theta J(\theta) = \nabla_\theta \sum_{\tau} P(\tau\mid \theta) R(\tau)\]Why: We start with equation \((1)\), \(J(\theta) = \sum_{\tau} P(\tau\mid \theta) R(\tau)\), and take its gradient.
2: Linearity of differentiation
\[\nabla_\theta J(\theta) = \sum_{\tau} \nabla_\theta P(\tau\mid \theta) R(\tau)\]
3: Apply the log-derivative trick
\[\nabla_\theta J(\theta) = \sum_{\tau} P(\tau\mid \theta) \nabla_\theta \log P(\tau\mid \theta) R(\tau)\]Why: The basic log trick: \(\nabla_\theta \log f = \frac{\nabla_\theta f}{f}\)
4: Rewrite as expectation
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ R(\tau) \nabla_\theta \log P(\tau\mid \theta) \right]\]Why: The sum over trajectories weighted by their probabilities is equivalent to an expectation under the policy.
5: Factorise trajectory probability
Let’s bring back equation \((3)\) into focus, we expand the trajectory probability:
\[P(\tau \mid \theta) = \rho_0(s_0) \prod_{t=0}^{T} P(s_{t+1} \mid s_t, a_t) \pi_\theta(a_t \mid s_t)\]
6: Take logarithm of the product
\[\log P(\tau \mid \theta) = \log \rho_0(s_0) + \sum_{t=0}^{T} \log P(s_{t+1} \mid s_t, a_t) + \sum_{t=0}^{T} \log \pi_\theta(a_t \mid s_t)\]
7: Differentiate and eliminate environment terms
\[\nabla_\theta \log P(\tau \mid \theta) = \nabla_\theta \log \rho_0(s_0) + \sum_{t=0}^{T} \nabla_\theta \log P(s_{t+1} \mid s_t, a_t) + \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\]Since \(\rho_0(s_0)\) and \(P(s_{t+1}\mid s_t, a_t)\) don’t depend on policy parameters \(\theta\):
\[\nabla_\theta \log \rho_0(s_0) = 0\] \[\nabla_\theta \log P(s_{t+1}\mid s_t, a_t) = 0\]Therefore:
\[\nabla_\theta \log P(\tau\mid \theta) = \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t\mid s_t)\]Why: Only policy terms carry \(\theta\) dependence. Environment dynamics are constant with respect to policy parameters.
8: Substitute back to obtain final result
\[\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ R(\tau) \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t\mid s_t) \right]\]In this formulation we relinquish control over the environment. However, I believe high-agency humans do exert control over their environment to reach their goals.
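As a quick numerical sanity check (not part of the derivation): on a hypothetical one-step softmax bandit, where \(J(\theta)\) can be evaluated exactly, the score-function form of the gradient matches a finite-difference gradient of \(J\).

```python
import numpy as np

rng = np.random.default_rng(0)
r = np.array([1.0, 3.0, -0.5])            # r(a) for a made-up 3-armed bandit
theta = rng.normal(size=3)                # softmax logits

def pi(theta):
    p = np.exp(theta - theta.max())
    return p / p.sum()

def J(theta):
    return pi(theta) @ r                  # J(theta) = E_{a ~ pi}[r(a)]

# Score-function form (eq. 4), with the expectation over the single action
# computed exactly: sum_a pi(a) r(a) * d log pi(a) / d theta.
p = pi(theta)
score_grad = sum(p[a] * r[a] * (np.eye(3)[a] - p) for a in range(3))

# Finite-difference gradient of J for comparison.
eps = 1e-6
fd_grad = np.array([(J(theta + eps * np.eye(3)[i]) - J(theta - eps * np.eye(3)[i])) / (2 * eps)
                    for i in range(3)])

assert np.allclose(score_grad, fd_grad, atol=1e-6)   # the two agree
```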
B. Proof of Unbiasedness for State-Dependent Baselines
The proof relies on the fact that the expectation of the score function is zero.
Proof of Zero-Expectation Score:
\[1 = \int_{a_t} \pi_\theta(a_t\mid s_t)\, da_t\]Any probability distribution must integrate to 1.
\[\nabla_\theta(1) = \nabla_\theta \int_{a_t} \pi_\theta(a_t\mid s_t)\, da_t\] \[0 = \int_{a_t} \nabla_\theta \pi_\theta(a_t\mid s_t)\, da_t\]Differentiate both sides; the gradient of a constant is zero.
\[0 = \int_{a_t} \pi_\theta(a_t\mid s_t) \nabla_\theta \log \pi_\theta(a_t\mid s_t)\, da_t\]Apply the identity \(\nabla_\theta f = f \nabla_\theta \log f\) (same as Appendix A).
\[0 = \mathbb{E}_{a_t \sim \pi_\theta(\cdot\mid s_t)}[\nabla_\theta \log \pi_\theta(a_t\mid s_t)]\]The integral defines the expectation, proving the expected score is zero.
Unbiased Baseline Proof:
\[\begin{aligned} \mathbb{E}_{a_t \sim \pi_\theta(\cdot \mid s_t)}[\nabla_\theta \log \pi_\theta(a_t \mid s_t) b(s_t)] &= b(s_t) \mathbb{E}_{a_t \sim \pi_\theta(\cdot \mid s_t)}[\nabla_\theta \log \pi_\theta(a_t \mid s_t)] \\ &= b(s_t) \cdot 0 \\ &= 0 \end{aligned}\]Since \(b(s_t)\) is constant with respect to the expectation over actions, it can be factored out.
This proves that a state-dependent baseline introduces no bias; it acts purely as a control variate for variance reduction.
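A short numerical check of the same fact for a hypothetical softmax policy over four actions: the expected score is the zero vector, so scaling it by any state-dependent baseline still contributes nothing in expectation.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=4)                      # made-up pi_theta(. | s_t) over 4 actions
p = np.exp(logits - logits.max()); p /= p.sum()

# d log pi(a|s) / d logits for a softmax policy is one_hot(a) - p.
scores = np.eye(4) - p                           # row a holds the score vector for action a

# E_{a ~ pi}[score] = sum_a pi(a) * score(a) = 0.
expected_score = p @ scores
assert np.allclose(expected_score, 0.0)

b = 42.0                                         # arbitrary baseline value for this state
assert np.allclose(p @ (scores * b), 0.0)        # the baseline term vanishes in expectation
```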
C. Optimal Baseline Derivation
To find the variance-minimising baseline \(b(s_t)\) under the simplified model, we analyse the variance of a single gradient sample, \(Z_t = (R_t - b(s_t))g_t\), where \(g_t\) is treated as a scalar: \(g_t = \nabla_\theta \log \pi_\theta(a_t\mid s_t)\).
Step 1: Set Up the Variance Minimisation
The goal is to minimise \(\text{Var}[Z_t] = \mathbb{E}[Z_t^2] - (\mathbb{E}[Z_t])^2\). Since the baseline is unbiased, \(\mathbb{E}[Z_t]\) is a constant with respect to \(b(s_t)\), so we only need to minimise \(\mathbb{E}[Z_t^2]\). For a given state \(s_t\):
\[\min_{b(s_t)} \mathbb{E}_{a_t, R_t | s_t} \left[ g_t^2 (R_t - b(s_t))^2 \right]\]Why is \(\mathbb{E}[Z_t]\) constant?
Because \(\mathbb{E}[Z_t] = \mathbb{E}[g_t(R_t-b(s_t))] = \mathbb{E}[g_t R_t] - b(s_t)\,\mathbb{E}[g_t]\), and by the result in Appendix B, \(\mathbb{E}[g_t] = 0\). Thus, the baseline-dependent part vanishes, leaving a term independent of \(b(s_t)\).
Step 2: Differentiate and Solve
We find the minimum by differentiating with respect to \(b(s_t)\) and setting the result to zero:
\[\frac{\partial}{\partial b(s_t)} \mathbb{E}\left[g_t^2(R_t - b(s_t))^2 | s_t\right] = \mathbb{E}\left[-2g_t^2(R_t-b(s_t))|s_t\right] = 0\]This simplifies to \(\mathbb{E}[g_t^2 R_t \mid s_t] = b(s_t) \mathbb{E}[g_t^2 \mid s_t]\), which gives the optimal baseline:
\[b^*(s_t)=\frac{\mathbb{E}[g_t^2 R_t \mid s_t]}{\mathbb{E}[g_t^2\mid s_t]}\]Since \(b(s_t)\) is a function of state \(s_t\) only, it is constant with respect to the expectation over actions and rewards, conditioned on that state.
Step 3: One more approximation
This weighted baseline is still impractical. The final leap is to assume the policy sensitivity term, \(g_t^2\), can be factored out and cancelled.
Why is this reasonable?
While \(g_t\) itself varies by action, its expected squared norm, \(\mathbb{E}[g_t^2\mid s_t]\), can be seen as a measure of the policy’s overall “instability” at a state \(s_t\). By assuming this instability is not strongly correlated with the rewards \(R_t\), we can approximate \(\mathbb{E}[g_t^2 R_t \mid s_t] \approx \mathbb{E}[g_t^2\mid s_t]\mathbb{E}[R_t\mid s_t]\). It’s important to note that this is not an exact result, but rather a pragmatic workaround.
The sensitivity term now cancels out, leaving the simple expected return: \(b(s_t) \approx \frac{\mathbb{E}[g_t^2 \mid s_t] \mathbb{E}[R_t \mid s_t]}{\mathbb{E}[g_t^2 \mid s_t]} = \mathbb{E}[R_t \mid s_t] \equiv V^{\pi}(s_t)\)
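A rough Monte Carlo sketch (with made-up rewards and policy at a single state) comparing the total variance of the per-sample gradient under no baseline, the value baseline \(V(s)\), and the weighted baseline \(b^*\) from equation \((7)\), with \(g_t^2\) read as the squared norm of the score:

```python
import numpy as np

rng = np.random.default_rng(0)
n_actions, n_samples = 3, 200_000
logits = np.array([0.5, -0.2, 0.1])               # hypothetical policy at one state
p = np.exp(logits - logits.max()); p /= p.sum()
mean_r = np.array([1.0, 4.0, 2.0])                # E[R | a]; rewards observed with noise

a = rng.choice(n_actions, size=n_samples, p=p)    # a_t ~ pi
R = mean_r[a] + rng.normal(scale=0.5, size=n_samples)
g = np.eye(n_actions)[a] - p                      # score vectors, one per sample

def total_variance(b):
    """Total variance (trace of the covariance) of the per-sample gradient g * (R - b)."""
    z = g * (R - b)[:, None]
    return np.mean(np.sum(z**2, axis=1)) - np.sum(np.mean(z, axis=0) ** 2)

g_sq = np.sum(g**2, axis=1)
b_opt = np.mean(g_sq * R) / np.mean(g_sq)         # b*: squared-norm-weighted average return
v_state = p @ mean_r                              # V(s) = E[R | s]

for name, b in [("no baseline", 0.0), ("V(s)", v_state), ("b*", b_opt)]:
    print(f"{name:12s} b={b:6.3f}  total variance={total_variance(b):8.3f}")
# Expected ordering for this setup: no baseline > V(s) >= b*, with V(s) close to b*.
```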
Acknowledgments
Excellent resources I referenced to build my understanding:
- Daniel Takeshi’s blog post on Policy Gradient Fundamentals
- OpenAI Spinning Up
- RLHF Book: Policy Gradient Algorithms
- Lilian Weng’s Policy Gradient Methods
- Sutton, R. S., & Barto, A. G. (2018). Reinforcement Learning: An Introduction (2nd ed.). MIT Press.