This blog post introduces the continuous adjoint equations for diffusion models, an efficient way to calculate gradients through the sampling process of a diffusion model. We show how to design bespoke ODE/SDE solvers for the continuous adjoint equations and show that adjoint diffusion SDEs actually simplify to ODEs of the same form as the adjoint diffusion ODE.
Guided generation is an important problem within machine learning. Solutions to this problem let us steer the generative process towards some desired output, which is especially important for injecting creative control into generative models. While there are several forms of this problem, we focus on approaches that optimize the output of a generative model towards some goal defined by a guidance (or loss) function on the output. These approaches excel at steering the generative process to perform adversarial ML attacks, e.g., bypassing security features, attacking Face Recognition (FR) systems, &c.
More formally, suppose we have some generative model over \(\R^d\), \(\bsg_\theta: \R^z \times \R^c \to \R^d\), parameterized by \(\theta \in \R^m\), which takes an initial latent \(\bfz \in \R^z\) and conditional information \(\mathbf{c} \in \R^c\). Furthermore, assume we have a scalar-valued guidance function \(\mathcal{L}: \R^d \to \R\). Then the guided generation problem can be expressed as an optimization problem: \begin{equation} \label{eq:opt_init} \argmin_{\bfz, \mathbf{c}, \theta} \quad \mathcal{L}(\bsg_\theta(\bfz, \mathbf{c})). \end{equation} I.e., we wish to find the \(\bfz\), \(\mathbf{c}\), and \(\theta\) that minimize our guidance function. A very natural solution to this kind of problem is to perform gradient descent, using reverse-mode automatic differentiation to find the gradients.
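The following is a minimal sketch of this recipe. It runs gradient descent on the latent \(\bfz\) only and holds \(\mathbf{c}\) and \(\theta\) fixed. The generator and the guidance function are hypothetical stand-ins: a fixed linear map `A`, and a squared error towards a target `y`. The gradient that reverse-mode automatic differentiation would propagate back through \(\bsg_\theta\) is written out by hand.

```python
# Toy guided generation: find the latent z minimising L(g(z)) by gradient descent.
# Hypothetical stand-ins: g(z) = A z is a fixed linear "generator" (c and theta are
# held fixed) and L(x) = ||x - y||^2 is a squared-error guidance function.
A = [[2.0, 0.0],
     [1.0, 1.0]]
y = [1.0, 3.0]  # target the guidance function pulls the output towards


def generate(z):
    """Linear toy generator g(z) = A z."""
    return [sum(a_ij * z_j for a_ij, z_j in zip(row, z)) for row in A]


def guidance_loss(x):
    """Squared-error guidance L(x) = ||x - y||^2."""
    return sum((x_i - y_i) ** 2 for x_i, y_i in zip(x, y))


def grad_wrt_z(z):
    """dL/dz = A^T dL/dx, i.e. the gradient reverse-mode AD would propagate back."""
    x = generate(z)
    dl_dx = [2.0 * (x_i - y_i) for x_i, y_i in zip(x, y)]  # gradient at the output
    return [sum(A[i][j] * dl_dx[i] for i in range(len(A))) for j in range(len(z))]


def optimise_latent(z0, lr=0.1, steps=200):
    """Plain gradient descent on the latent."""
    z = list(z0)
    for _ in range(steps):
        grad = grad_wrt_z(z)
        z = [z_j - lr * g_j for z_j, g_j in zip(z, grad)]
    return z


z_star = optimise_latent([0.0, 0.0])
print(z_star, guidance_loss(generate(z_star)))
```

For a diffusion model, `generate` is an entire ODE solve. That is exactly what makes obtaining the gradient hard.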
In this blog post, we focus on a technique for finding the gradients for a very popular class of generative models known as diffusion models
First we give a brief introduction to diffusion models and score-based generative modeling. More comprehensive coverage can be found in Yang Song’s blog post and Lilian Weng’s blog post on this topic.
Diffusion models start with a diffusion process which perturbs the original data distribution \(p_{\textrm{data}}(\bfx)\) on \(\R^d\) into isotropic Gaussian noise \(\mathcal{N}(\mathbf 0, \mathbf I)\). This process can be modeled with an Itô Stochastic Differential Equation (SDE) of the form \begin{equation} \label{eq:ito_diffusion} \mathrm{d}\bfx_t = \underbrace{f(t)\bfx_t\; \mathrm dt}_{\textrm{Deterministic term $\approx$ an ODE}} + \underbrace{g(t)\; \mathrm d\mathbf{w}_t,}_{\textrm{Stochastic term}} \end{equation} where \(f, g\) are real-valued functions, \(\{\bfw_t\}_{t \in [0, T]}\) is the standard Wiener process on time \([0, T]\), and \(\mathrm d\bfw_t\) can be thought of as infinitesimal white noise. The drift coefficient \(f(t)\bfx_t\) is the deterministic part of the SDE and \(f(t)\bfx_t\;\mathrm dt\) can be thought of as the ODE term of the SDE. Conversely, the diffusion coefficient \(g(t)\) is the stochastic part of the SDE which controls how much noise is injected into the system.
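The forward process can be seen in action with the sketch below. It simulates Equation \eqref{eq:ito_diffusion} with the Euler-Maruyama scheme for a one-dimensional toy "dataset", a point mass at 2.0. We assume the variance-preserving choice \(f(t) = -\tfrac12\beta(t)\), \(g(t) = \sqrt{\beta(t)}\), with \(\beta(t)\) linear from 0.1 to 20. The constants `BETA_MIN`, `BETA_MAX` and `T` are illustrative assumptions, not values taken from this post.

```python
import math
import random

# Assumed VP-SDE hyperparameters (illustrative only).
BETA_MIN, BETA_MAX, T = 0.1, 20.0, 1.0


def beta(t):
    """Linear noise rate beta(t) on [0, T]."""
    return BETA_MIN + (t / T) * (BETA_MAX - BETA_MIN)


def f(t):
    """Drift scale of the VP SDE, f(t) = -beta(t) / 2."""
    return -0.5 * beta(t)


def g(t):
    """Diffusion coefficient of the VP SDE, g(t) = sqrt(beta(t))."""
    return math.sqrt(beta(t))


def simulate_forward(x0, rng, n_steps=500):
    """Euler-Maruyama for dx_t = f(t) x_t dt + g(t) dw_t from t = 0 to t = T."""
    dt = T / n_steps
    x = x0
    for i in range(n_steps):
        t = i * dt
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        x = x + f(t) * x * dt + g(t) * dw
    return x


rng = random.Random(0)
samples = [simulate_forward(2.0, rng) for _ in range(2000)]
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / (len(samples) - 1)
print(f"mean = {mean:.3f}, variance = {var:.3f}")  # should be close to 0 and 1
```

Whatever the starting point, the terminal samples look like draws from \(\mathcal{N}(0, 1)\). The data has been turned into noise.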
The solution to this SDE is a continuous collection of random variables \(\{\bfx_t\}_{t \in [0, T]}\) over the real interval \([0, T]\); these random variables trace stochastic trajectories over the time interval. Let \(p_t(\bfx_t)\) denote the marginal probability density function of \(\bfx_t\). Then \(p_0(\bfx_0) = p_{\textrm{data}}(\bfx)\) is the data distribution; likewise, for some sufficiently large \(T \in \R\), the terminal distribution \(p_T(\bfx_T)\) is close to some tractable noise distribution \(\pi(\bfx)\), called the prior distribution.
So far we have only covered how to destroy data by perturbing it with white noise; however, for sampling we need to be able to reverse this process to create data from noise. Remarkably, Anderson
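The SDE is linear in \(\bfx_t\), so the transition from \(\bfx_0\) to \(\bfx_t\) is Gaussian: \(\bfx_t \mid \bfx_0 \sim \mathcal{N}(\alpha_t \bfx_0, \sigma_t^2 \mathbf I)\). Here \(\alpha_t = \exp\big(\int_0^t f(\tau)\,\mathrm d\tau\big)\) and \(\sigma_t^2 = \alpha_t^2 \int_0^t g^2(\tau)/\alpha_\tau^2\,\mathrm d\tau\). The sketch below evaluates these quantities for a variance-preserving schedule. That schedule is assumed here as a hypothetical example: \(f(t) = -\beta(t)/2\), \(g^2(t) = \beta(t)\), with \(\beta\) linear from 0.1 to 20. Having \(\alpha_T \approx 0\) and \(\sigma_T \approx 1\) is what makes \(p_T\) close to \(\mathcal{N}(\mathbf 0, \mathbf I)\).

```python
import math

# Assumed VP-SDE hyperparameters (illustrative only): f(t) = -beta(t)/2, g(t)^2 = beta(t).
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def log_alpha(t):
    """int_0^t f(tau) dtau, in closed form for the linear beta schedule."""
    return -0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2)


def alpha(t):
    return math.exp(log_alpha(t))


def simpson(fn, a, b, n=1000):
    """Composite Simpson's rule with an even number n of sub-intervals."""
    h = (b - a) / n
    total = fn(a) + fn(b) + sum((4 if i % 2 else 2) * fn(a + i * h) for i in range(1, n))
    return total * h / 3


def sigma(t):
    """sigma_t from the linear-SDE formula sigma_t^2 = alpha_t^2 int_0^t g^2 / alpha^2."""
    integral = simpson(lambda tau: beta(tau) * math.exp(-2.0 * log_alpha(tau)), 0.0, t)
    return alpha(t) * math.sqrt(integral)


for t in (0.1, 0.5, 1.0):
    # For this VP schedule alpha_t^2 + sigma_t^2 = 1, and alpha_T ends up close to 0.
    print(f"t = {t}: alpha = {alpha(t):.4f}, sigma = {sigma(t):.4f}")
```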
To train a diffusion model, then, we just need to learn the score function via score-matching
Song et al.
One of the key benefits of expressing diffusion models in ODE form is that ODEs are easily reversible: by simply integrating forwards and backwards in time, we can encode images from \(p_0(\bfx_0)\) into \(p_T(\bfx_T)\) and back again. With a neural network, often a U-Net
Researchers have proposed many ways to perform guided generation with diffusion models. Outside of directly conditioning the noise-prediction network on additional latent information, Dhariwal and Nichol proposed classifier guidance
Outside of methods that require additional training of the diffusion model or of some external network, there are training-free methods, which we broadly categorize into the following two categories:
The second category of techniques is related to our initial problem statement in the introduction from Equation \eqref{eq:opt_init}. We reframe this problem for the specific case of diffusion ODEs.
Problem statement. Given the diffusion ODE in Equation \eqref{eq:empirical_pf_ode}, we wish to solve the following optimization problem: \begin{equation} \label{eq:problem_stmt_ode} \argmin_{\bfx_T, \bfz, \theta}\quad \mathcal{L}\bigg(\bfx_T + \int_T^0 f(t)\bfx_t + \frac{g^2(t)}{2\sigma_t}\bseps_\theta(\bfx_t, \bfz, t)\;\rmd t\bigg). \end{equation} N.B., without loss of generality we let \(\bseps_\theta(\bfx_t, \bfz, t)\) denote a noise-prediction network conditioned either directly on \(\bfz\) or as the classifier-free guidance model \(\tilde \bseps_\theta(\bfx_t, \bfz, t)\).
From this formulation it is readily apparent that the difficulty introduced by diffusion models, compared with, say, GANs or VAEs, is that we need to perform backpropagation through an ODE solve. Luckily, diffusion models are a type of Neural ODE
Solving an adjoint ODE backwards in time to calculate the gradients of an ODE is a widely used technique, initially proposed by Pontryagin et al.
We can write the diffusion ODE as a Neural ODE of the form: \(\begin{equation} \frac{\rmd \bfx_t}{\rmd t} = \bsf_\theta(\bfx_t, \bfz, t) := f(t)\bfx_t + \frac{g^2(t)}{2\sigma_t}\bseps_\theta(\bfx_t, \bfz, t). \end{equation}\) Then, assuming \(\bsf_\theta\) is continuous in \(t\) and uniformly Lipschitz in \(\bfx\),
While this formulation can calculate the desired gradients to solve the optimization problem, it fails to account for the unique construction of diffusion models, in particular the special form of \(f\) and \(g\). Recent work
In the diffusion model literature, the sampling process is usually run in reverse time, i.e., the initial noise is \(\bfx_T\) and the final sample is \(\bfx_0\). Due to this convention, solving the adjoint diffusion ODE backwards actually means integrating forwards in time. Thus, while diffusion models learn to compute \(\bfx_t\) from \(\bfx_s\) with \(s > t\), the adjoint diffusion ODE seeks to compute \(\bfa_\bfx(s)\) from \(\bfa_\bfx(t)\).
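To make the direction of integration concrete, here is a minimal scalar sketch of the continuous adjoint method. Everything in it is a hypothetical toy:

- the schedule \(\alpha_t = \cos(\pi t/2)\), \(\sigma_t = \sin(\pi t/2)\), for which \(f(t) = -\tfrac{\pi}{2}\tan(\pi t/2)\) and \(g^2(t)/(2\sigma_t) = \pi/(2\alpha_t)\);
- the "noise predictor" \(\epsilon(x, t) = 0.5\tanh(x)\);
- the loss \(\mathcal L(x) = (x - 0.3)^2\).

Sampling integrates the diffusion ODE from \(t = 0.9\) down to \(t = 0.1\), staying away from the endpoints where this schedule is singular. The adjoint then starts from \(\partial\mathcal L/\partial \bfx_0\) and is integrated forwards in time, re-tracing the state alongside it, to give \(\partial\mathcal L/\partial\bfx_T\). We compare the result against a finite-difference estimate.

```python
import math

# Hypothetical toy setup (scalar state; z and theta omitted):
#   alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2), so f(t) = -(pi/2) tan(pi t / 2)
#   and g^2(t) / (2 sigma_t) = pi / (2 alpha_t); "noise predictor" eps(x, t) = C tanh(x).
C = 0.5
T_START, T_END = 0.9, 0.1  # sample from t = 0.9 down to t = 0.1
Y = 0.3                    # target of the guidance loss L(x) = (x - Y)^2


def drift(x, t):
    """Diffusion ODE right-hand side: f(t) x + g^2(t) / (2 sigma_t) * eps(x, t)."""
    theta = 0.5 * math.pi * t
    f_t = -0.5 * math.pi * math.tan(theta)   # f(t)
    scale = 0.5 * math.pi / math.cos(theta)  # g^2(t) / (2 sigma_t)
    return f_t * x + scale * C * math.tanh(x)


def drift_dx(x, t):
    """Partial derivative of the drift w.r.t. x (the 1x1 Jacobian)."""
    theta = 0.5 * math.pi * t
    f_t = -0.5 * math.pi * math.tan(theta)
    scale = 0.5 * math.pi / math.cos(theta)
    return f_t + scale * C * (1.0 - math.tanh(x) ** 2)


def rk4(fn, y, t0, t1, n=2000):
    """Classic RK4 for dy/dt = fn(y, t) from t0 to t1 (t1 < t0 is allowed); y is a tuple."""
    dt = (t1 - t0) / n
    t = t0
    for _ in range(n):
        k1 = fn(y, t)
        k2 = fn(tuple(v + 0.5 * dt * k for v, k in zip(y, k1)), t + 0.5 * dt)
        k3 = fn(tuple(v + 0.5 * dt * k for v, k in zip(y, k2)), t + 0.5 * dt)
        k4 = fn(tuple(v + dt * k for v, k in zip(y, k3)), t + dt)
        y = tuple(v + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
                  for v, a, b, c, d in zip(y, k1, k2, k3, k4))
        t += dt
    return y


def sample(x_T):
    """Solve the diffusion ODE backwards in time, from T_START to T_END."""
    (x_0,) = rk4(lambda y, t: (drift(y[0], t),), (x_T,), T_START, T_END)
    return x_0


def loss(x_T):
    return (sample(x_T) - Y) ** 2


def adjoint_grad(x_T):
    """dL/dx_T from the continuous adjoint equation, integrated forwards in time."""
    x_0 = sample(x_T)
    a_0 = 2.0 * (x_0 - Y)  # initial adjoint a_x = dL/dx_0

    def augmented(y, t):
        x, a = y
        # Re-trace the state alongside the adjoint: dx/dt = drift, da/dt = -a * d(drift)/dx.
        return (drift(x, t), -a * drift_dx(x, t))

    _, a_T = rk4(augmented, (x_0, a_0), T_END, T_START)
    return a_T


x_T, delta = 1.0, 1e-5
finite_diff = (loss(x_T + delta) - loss(x_T - delta)) / (2.0 * delta)
print(adjoint_grad(x_T), finite_diff)  # the two estimates should agree closely
```

Re-tracing the state while integrating the adjoint is the usual augmented-state trick. It means no intermediate states from the sampling pass need to be stored.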
Recent work on efficient ODE solvers for diffusion models
The continuous adjoint equation for \(\bfa_\bfx(t)\) in Equation \eqref{eq:adjoint_ode} can be rewritten as \(\begin{equation} \label{eq:empirical_adjoint_ode} \frac{\mathrm d\bfa_\bfx}{\mathrm dt}(t) = -f(t)\bfa_\bfx(t) - \frac{g^2(t)}{2\sigma_t}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \bfx_t}. \end{equation}\)
Due to the gradient of the drift term in Equation \eqref{eq:empirical_adjoint_ode}, further manipulations are required to put the empirical adjoint probability flow ODE into a sufficiently “nice” form. We can transform this stiff ODE into a non-stiff form by applying the integrating factor \(\exp\big({\int_0^t f(\tau)\;\mathrm d\tau}\big)\) to Equation \eqref{eq:empirical_adjoint_ode}, which gives \(\begin{equation} \label{eq:empirical_adjoint_ode_IF} \frac{\mathrm d}{\mathrm dt}\bigg[e^{\int_0^t f(\tau)\;\mathrm d\tau} \bfa_\bfx(t)\bigg] = -e^{\int_0^t f(\tau)\;\mathrm d\tau} \frac{g^2(t)}{2\sigma_t}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \bfx_t}. \end{equation}\) Then, the exact solution at time \(s\) given time \(t < s\) is \(\begin{align} \bfa_\bfx(s) = \underbrace{\vphantom{\int_t^s}e^{\int_s^t f(\tau)\;\mathrm d\tau} \bfa_\bfx(t)}_{\textrm{linear}} - \underbrace{\int_t^s e^{\int_s^u f(\tau)\;\mathrm d\tau} \frac{g^2(u)}{2\sigma_u} \bfa_\bfx(u)^\top \frac{\partial \bseps_\theta(\bfx_u, \bfz, u)}{\partial \bfx_u}\;\rmd u}_{\textrm{non-linear}}. \label{eq:empirical_adjoint_ode_x} \end{align}\) With this transformation we can compute the linear term in closed form, thereby eliminating the discretization error in the linear term. However, we still need to approximate the non-linear term, which consists of a difficult integral involving the complex noise-prediction model. This is where the insight of Lu et al.
Proposition (Exact solution of adjoint diffusion ODEs). Given initial values \([\bfa_\bfx(t), \bfa_\bfz(t), \bfa_\theta(t)]\) at time \(t \in (0,T)\), the solution \([\bfa_\bfx(s), \bfa_\bfz(s), \bfa_\theta(s)]\) at time \(s \in (t, T]\) of adjoint diffusion ODEs in Equation \eqref{eq:adjoint_ode} is \(\begin{align} \label{eq:exact_sol_empirical_adjoint_ode_x} \bfa_\bfx(s) &= \frac{\alpha_t}{\alpha_s} \bfa_\bfx(t) + \frac{1}{\alpha_s}\int_{\lambda_t}^{\lambda_s} \alpha_\lambda^2 e^{-\lambda} \bfa_\bfx(\lambda)^\top \frac{\partial \bseps_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \bfx_\lambda}\;\rmd \lambda,\\ \label{eq:exact_sol_empirical_adjoint_ode_z} \bfa_\bfz(s) &= \bfa_\bfz(t) + \int_{\lambda_t}^{\lambda_s}\alpha_\lambda e^{-\lambda} \bfa_\bfx(\lambda)^\top \frac{\partial \boldsymbol\epsilon_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \bfz}\;\rmd\lambda,\\ \label{eq:exact_sol_empirical_adjoint_ode_theta} \bfa_\theta(s) &= \bfa_\theta(t) + \int_{\lambda_t}^{\lambda_s}\alpha_\lambda e^{-\lambda} \bfa_\bfx(\lambda)^\top \frac{\partial \boldsymbol\epsilon_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \theta}\;\rmd\lambda. \end{align}\)
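The step from Equation \eqref{eq:empirical_adjoint_ode_x} to this proposition is a change of variables from \(t\) to the log-SNR \(\lambda_t = \log(\alpha_t/\sigma_t)\). It uses \(e^{\int_s^u f(\tau)\,\mathrm d\tau} = \alpha_u/\alpha_s\) and \(\frac{g^2(u)}{2\sigma_u}\,\rmd u = -\sigma_u\,\rmd\lambda_u = -\alpha_u e^{-\lambda_u}\,\rmd\lambda_u\).

The sketch below checks this numerically for a hypothetical scalar case, using two assumptions:

- a cosine schedule \(\alpha_t = \cos(\pi t/2)\), \(\sigma_t = \sin(\pi t/2)\);
- a made-up smooth function `phi` standing in for the vector-Jacobian product.

Both forms of the non-linear term should agree to quadrature precision.

```python
import math

# Hypothetical schedule: alpha_t = cos(pi t/2), sigma_t = sin(pi t/2), so that
# lambda_t = log(alpha_t / sigma_t), t(lambda) = (2/pi) arctan(exp(-lambda)),
# f(t) = -(pi/2) tan(pi t/2) and g^2(t) = d sigma_t^2/dt - 2 f(t) sigma_t^2 = pi tan(pi t/2).


def alpha(t):
    return math.cos(0.5 * math.pi * t)


def sigma(t):
    return math.sin(0.5 * math.pi * t)


def g2(t):
    return math.pi * math.tan(0.5 * math.pi * t)


def lam(t):
    return math.log(alpha(t) / sigma(t))


def t_of_lam(l):
    """Inverse of lam for this schedule."""
    return 2.0 / math.pi * math.atan(math.exp(-l))


def simpson(fn, a, b, n=2000):
    """Composite Simpson's rule; also works when b < a."""
    h = (b - a) / n
    total = fn(a) + fn(b) + sum((4 if i % 2 else 2) * fn(a + i * h) for i in range(1, n))
    return total * h / 3


def phi(u):
    """Hypothetical stand-in for the scalar vector-Jacobian product at time u."""
    return 1.0 + u ** 2


def nonlinear_term_t(t, s):
    """-int_t^s (alpha_u / alpha_s) g^2(u) / (2 sigma_u) phi(u) du, written in time t."""
    return -simpson(lambda u: alpha(u) / alpha(s) * g2(u) / (2.0 * sigma(u)) * phi(u), t, s)


def nonlinear_term_lambda(t, s):
    """(1/alpha_s) int_{lambda_t}^{lambda_s} alpha^2 e^{-lambda} phi dlambda, written in lambda."""
    def integrand(l):
        u = t_of_lam(l)
        return alpha(u) ** 2 * math.exp(-l) * phi(u)
    return simpson(integrand, lam(t), lam(s)) / alpha(s)


t, s = 0.2, 0.7
print(nonlinear_term_t(t, s), nonlinear_term_lambda(t, s))
```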
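Now that we have a simplified formulation of the continuous adjoint equations, we can construct bespoke numerical solvers. To do this we approximate the integral term via a Taylor expansion, which we illustrate for Equation \eqref{eq:exact_sol_empirical_adjoint_ode_x}. For \(k \geq 1\), a \((k-1)\)-th order Taylor expansion of the scaled vector-Jacobian product about \(\lambda_t\) is \(\begin{equation} \alpha_\lambda^2\bfa_\bfx(\lambda)^\top \frac{\partial \bseps_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \bfx_\lambda} = \sum_{n=0}^{k-1} \frac{(\lambda - \lambda_t)^n}{n!} \frac{\rmd^n}{\rmd \lambda^n}\bigg[\alpha_\lambda^2\bfa_\bfx(\lambda)^\top \frac{\partial \bseps_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \bfx_\lambda}\bigg]_{\lambda = \lambda_t} + \mathcal{O}\big((\lambda - \lambda_t)^k\big). \end{equation}\)
For notational convenience we denote the \(n\)-th order derivative of the scaled vector-Jacobian product at \(\lambda_t\) as \(\begin{equation} \label{eq:app:vjp_def_x} \bfV^{(n)}(\bfx; \lambda_t) = \frac{\rmd^n}{\rmd \lambda^n}\bigg[\alpha_\lambda^2\bfa_\bfx(\lambda)^\top \frac{\partial \bseps_\theta(\bfx_\lambda, \bfz, \lambda)}{\partial \bfx_\lambda}\bigg]_{\lambda = \lambda_t}. \end{equation}\) Then, substituting our Taylor expansion into Equation \eqref{eq:exact_sol_empirical_adjoint_ode_x} and letting \(h = \lambda_s - \lambda_t\) denote the step size, we obtain a \(k\)-th order solver for the continuous adjoint equation for \(\bfa_\bfx(t)\): \(\begin{equation} \bfa_\bfx(s) = \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) + \frac{1}{\alpha_s} \sum_{n=0}^{k-1} \bfV^{(n)}(\bfx; \lambda_t) \int_{\lambda_t}^{\lambda_s} \frac{(\lambda - \lambda_t)^n}{n!} e^{-\lambda}\;\mathrm d\lambda + \mathcal{O}(h^{k+1}). \end{equation}\)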
Let’s break this down term by term.
Linear term. The linear term of the adjoint diffusion ODE can be calculated exactly using the ratio of the signal schedule, \(\alpha_t / \alpha_s\). Since \(\alpha_t \geq \alpha_s\) for \(t \leq s\), this ratio satisfies \(\alpha_t / \alpha_s \geq 1\). \(\begin{equation*} \bfa_\bfx(s) = {\color{orange}\underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) }_{\substack{\textrm{Linear term}\\\textbf{Exactly computed}}}} +\frac{1}{\alpha_s} \sum_{n=0}^{k-1} \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \bfV^{(n)}(\bfx; \lambda_t) }_{\substack{\textrm{Derivatives}\\\textbf{Approximated}}}\; \underbrace{ \int_{\lambda_t}^{\lambda_s} \frac{(\lambda - \lambda_t)^n}{n!} e^{-\lambda}\;\mathrm d\lambda }_{\substack{\textrm{Coefficients}\\\textbf{Analytically computed}}} + \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \mathcal{O}(h^{k+1}). }_{\substack{\textrm{Higher-order errors}\\\textbf{Omitted}}} \end{equation*}\)
Derivatives. The \(n\)-th order derivatives of the scaled vector-Jacobian product can be efficiently estimated using multi-step methods
Coefficients. The exponentially weighted integral can be analytically computed in closed form. \(\begin{equation*} \bfa_\bfx(s) = \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) }_{\substack{\textrm{Linear term}\\\textbf{Exactly computed}}} +\frac{1}{\alpha_s} \sum_{n=0}^{k-1} \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \bfV^{(n)}(\bfx; \lambda_t) }_{\substack{\textrm{Derivatives}\\\textbf{Approximated}}}\; {\color{orange}\underbrace{ \int_{\lambda_t}^{\lambda_s} \frac{(\lambda - \lambda_t)^n}{n!} e^{-\lambda}\;\mathrm d\lambda }_{\substack{\textrm{Coefficients}\\\textbf{Analytically computed}}}} + \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \mathcal{O}(h^{k+1}). }_{\substack{\textrm{Higher-order errors}\\\textbf{Omitted}}} \end{equation*}\)
Higher-order errors. The remaining higher-order error terms are discarded. If the step size \(h\) is sufficiently small, then these \(\mathcal{O}(h^{k+1})\) errors are negligible. \(\begin{equation*} \bfa_\bfx(s) = \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) }_{\substack{\textrm{Linear term}\\\textbf{Exactly computed}}} +\frac{1}{\alpha_s} \sum_{n=0}^{k-1} \underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \bfV^{(n)}(\bfx; \lambda_t) }_{\substack{\textrm{Derivatives}\\\textbf{Approximated}}}\; \underbrace{ \int_{\lambda_t}^{\lambda_s} \frac{(\lambda - \lambda_t)^n}{n!} e^{-\lambda}\;\mathrm d\lambda }_{\substack{\textrm{Coefficients}\\\textbf{Analytically computed}}} + {\color{orange}\underbrace{ \vphantom{\int_{\lambda_t}^{\lambda_s}} \mathcal{O}(h^{k+1}). }_{\substack{\textrm{Higher-order errors}\\\textbf{Omitted}}}} \end{equation*}\)
The exponentially weighted integral can be solved analytically by applying integration by parts \(n\) times
From this construction there are only two sources of error: the error in approximating the \(n\)-th order derivatives of the vector-Jacobian product, and the higher-order errors. Therefore, as long as we pick a sufficiently small step size, \(h\), and an appropriate order, \(k\), we can achieve accurate (enough) estimates of the gradients. The derivations of the solvers for \(\bfa_\bfz(t)\) and \(\bfa_\theta(t)\) are omitted for brevity but are analogous. The \(k\)-th order solvers resulting from this method are called AdjointDEIS-\(k\). In
Consider the case \(k=1\); this gives the following first-order solver.
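The coefficient integrals have a simple closed form. Substitute \(u = \lambda - \lambda_t\) and integrate by parts \(n\) times. This gives \(\int_{\lambda_t}^{\lambda_s} \frac{(\lambda - \lambda_t)^n}{n!} e^{-\lambda}\,\mathrm d\lambda = e^{-\lambda_t}\big(1 - e^{-h}\sum_{j=0}^{n} h^j/j!\big)\), with \(h = \lambda_s - \lambda_t\). The sketch below evaluates this formula and cross-checks it against Simpson's rule. It is plain Python, and the function names are our own.

```python
import math


def coefficient_closed_form(n, lam_t, lam_s):
    """int_{lam_t}^{lam_s} (lam - lam_t)^n / n! * e^{-lam} dlam via n-fold integration by parts.

    With u = lam - lam_t and h = lam_s - lam_t this equals
    e^{-lam_t} * (1 - e^{-h} * sum_{j=0}^{n} h^j / j!).
    """
    h = lam_s - lam_t
    partial_sum = sum(h ** j / math.factorial(j) for j in range(n + 1))
    return math.exp(-lam_t) * (1.0 - math.exp(-h) * partial_sum)


def coefficient_quadrature(n, lam_t, lam_s, steps=2000):
    """Composite Simpson's rule on the same integrand, as a numerical cross-check."""
    def integrand(lam):
        return (lam - lam_t) ** n / math.factorial(n) * math.exp(-lam)

    h = (lam_s - lam_t) / steps
    total = integrand(lam_t) + integrand(lam_s)
    total += sum((4 if i % 2 else 2) * integrand(lam_t + i * h) for i in range(1, steps))
    return total * h / 3


# The adjoint runs forward in t, so lambda decreases and h = lam_s - lam_t < 0.
lam_t, lam_s = 1.2, 0.7
for n in range(3):
    print(n, coefficient_closed_form(n, lam_t, lam_s), coefficient_quadrature(n, lam_t, lam_s))
```

For \(n = 0\) the coefficient is \(e^{-\lambda_t} - e^{-\lambda_s} = \frac{\sigma_s}{\alpha_s}(e^h - 1)\). Multiplied by \(\alpha_t^2/\alpha_s\), it gives the \(\sigma_s(e^h - 1)\alpha_t^2/\alpha_s^2\) factor of AdjointDEIS-1. For very small \(|h|\) and larger \(n\), the closed form suffers from floating-point cancellation, so a series expansion would be preferable there.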
AdjointDEIS-1. Given an initial augmented adjoint state \([\bfa_\bfx(t), \bfa_\bfz(t), \bfa_\theta(t)]\) at time \(t \in (0, T)\), the solution \([\bfa_\bfx(s), \bfa_\bfz(s), \bfa_\theta(s)]\) at time \(s \in (t, T]\) is approximated by \(\begin{align} \bfa_\bfx(s) &= \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) + \sigma_s (e^h - 1) \frac{\alpha_t^2}{\alpha_s^2}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \bfx_t},\nonumber\\ \bfa_\bfz(s) &= \bfa_\bfz(t) + \sigma_s (e^h - 1) \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \bfz},\nonumber\\ \bfa_\theta(s) &= \bfa_\theta(t) + \sigma_s (e^h - 1) \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \theta}. \label{eq:adjoint_deis_1_at} \end{align}\)
The vector-Jacobian product can be easily calculated using reverse-mode automatic differentiation provided by most modern ML frameworks. We illustrate an implementation of this first-order solver using PyTorch. For simplicity we omit the code for calculating \(\bfa_\theta\) as it requires more boilerplate code.
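As a framework-free sketch of the same update, the snippet below implements one AdjointDEIS-1 step for \(\bfa_\bfx\) and \(\bfa_\bfz\) in plain Python. It assumes:

- a hypothetical cosine schedule \(\alpha_t = \cos(\pi t/2)\), \(\sigma_t = \sin(\pi t/2)\);
- a toy noise predictor \(\bseps(\bfx, \bfz) = \tanh(\bfx + \bfz)\), whose vector-Jacobian products are written by hand.

With a real network, `eps_vjp` would instead call reverse-mode autodiff, e.g. `torch.autograd.grad` with `grad_outputs` set to \(\bfa_\bfx\).

```python
import math

# Hypothetical noise schedule for illustration: alpha_t = cos(pi t/2), sigma_t = sin(pi t/2),
# lambda_t = log(alpha_t / sigma_t), valid for t in (0, 1).


def alpha(t):
    return math.cos(0.5 * math.pi * t)


def sigma(t):
    return math.sin(0.5 * math.pi * t)


def lam(t):
    return math.log(alpha(t) / sigma(t))


def adjoint_deis_1_step(a_x, a_z, x_t, z, t, s, eps_vjp):
    """One AdjointDEIS-1 update of (a_x, a_z) from time t to a later time s > t.

    eps_vjp(x_t, z, t, v) must return the pair (v^T d eps/d x_t, v^T d eps/d z).
    x_t is the state at time t on the sampling trajectory.
    """
    ratio = alpha(t) / alpha(s)                    # alpha_t / alpha_s, the exact linear term
    coef = sigma(s) * math.expm1(lam(s) - lam(t))  # sigma_s (e^h - 1), h = lambda_s - lambda_t
    vjp_x, vjp_z = eps_vjp(x_t, z, t, a_x)
    new_a_x = [ratio * a + coef * ratio ** 2 * v for a, v in zip(a_x, vjp_x)]
    new_a_z = [a + coef * ratio * v for a, v in zip(a_z, vjp_z)]
    return new_a_x, new_a_z


def toy_eps_vjp(x_t, z, t, v):
    """VJPs of the hypothetical noise predictor eps(x, z) = tanh(x + z), applied elementwise."""
    diag = [1.0 - math.tanh(xi + zi) ** 2 for xi, zi in zip(x_t, z)]  # diagonal Jacobian
    vjp = [vi * di for vi, di in zip(v, diag)]
    return vjp, vjp  # d eps/dx and d eps/dz coincide for this toy model


a_x, a_z = adjoint_deis_1_step([1.0, -0.5], [0.0, 0.0], [0.2, 0.1], [0.3, -0.3],
                               t=0.3, s=0.32, eps_vjp=toy_eps_vjp)
print(a_x, a_z)
```

We use `math.expm1` for \(e^h - 1\) because \(h\) is small for fine step sizes, where computing `exp(h) - 1` directly loses precision.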
What about diffusion SDEs? In that case the problem statement in Equation \eqref{eq:problem_stmt_ode} becomes \(\begin{equation} \label{eq:problem_stmt_sde} \argmin_{\bfx_T, \bfz, \theta}\quad \mathcal{L}\bigg(\bfx_T + \int_T^0 f(t)\bfx_t + \frac{g^2(t)}{\sigma_t}\bseps_\theta(\bfx_t, \bfz, t)\;\rmd t + \int_T^0 g(t) \; \rmd \bar\bfw_t\bigg). \end{equation}\) The technical details of working with SDEs are beyond the scope of this post; however, we will highlight one of the key insights from our work
Suppose we have an SDE in the Stratonovich sense of the form \(\begin{equation} \label{eq:stratonovich_sde} \rmd \bfx_t = \bsf(\bfx_t, t)\;\rmd t + \bsg(t) \circ \rmd \bfw_t \end{equation}\) where \(\circ \rmd \bfw_t\) denotes integration in the Stratonovich sense and \(\bsf \in \mathcal{C}_b^{\infty, 1}(\R^d)\), i.e., \(\bsf\) is a continuous function to \(\R^d\) with infinitely many bounded derivatives w.r.t. the state and bounded first derivatives w.r.t. time. Likewise, let \(\bsg \in \mathcal{C}_b^1(\R^{d \times w})\) be a continuous function with bounded first derivatives. Lastly, let \(\bfw_t: [0,T] \to \R^w\) be a \(w\)-dimensional Wiener process. Then Equation \eqref{eq:stratonovich_sde} has a unique strong solution \(\bfx_t: [0, T] \to \R^d\).
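Here \(\bsg\) depends only on time (additive noise), so the Itô-to-Stratonovich correction term (in the scalar case, \(\tfrac12 \bsg\,\partial\bsg/\partial\bfx\)) vanishes. The Itô and Stratonovich readings of such an SDE therefore have the same solution. The sketch below illustrates this on a hypothetical scalar example with drift \(-x\) and diffusion \(1 + t\). An Itô Euler-Maruyama path and a Stratonovich path (stochastic Heun) are driven by the same Brownian increments, and they end up within discretization error of each other.

```python
import math
import random


def drift(x, t):
    """Hypothetical drift F(x, t) = -x."""
    return -x


def diffusion(t):
    """Hypothetical time-only diffusion coefficient G(t) = 1 + t (additive noise)."""
    return 1.0 + t


def solve_both(x0=1.0, T=1.0, n_steps=4000, seed=0):
    """Integrate with Ito Euler-Maruyama and Stratonovich Heun on the same Brownian path."""
    rng = random.Random(seed)
    dt = T / n_steps
    x_ito = x_strat = x0
    for i in range(n_steps):
        t = i * dt
        dw = rng.gauss(0.0, math.sqrt(dt))  # shared Brownian increment ~ N(0, dt)
        # Ito: Euler-Maruyama evaluates drift and diffusion at the left endpoint.
        x_ito = x_ito + drift(x_ito, t) * dt + diffusion(t) * dw
        # Stratonovich: stochastic Heun averages the predictor and corrector slopes.
        pred = x_strat + drift(x_strat, t) * dt + diffusion(t) * dw
        x_strat = (x_strat
                   + 0.5 * (drift(x_strat, t) + drift(pred, t + dt)) * dt
                   + 0.5 * (diffusion(t) + diffusion(t + dt)) * dw)
    return x_ito, x_strat


x_ito, x_strat = solve_both()
print(x_ito, x_strat)  # nearly identical for additive noise
```

With state-dependent noise, the two schemes would converge to different limits.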
We show in
Remark. While the adjoint state evolves with an ODE, the underlying state \(\bfx_t\) still evolves with a backwards-in-time SDE! This is why we chose Stratonovich over Itô: the Stratonovich integral is symmetric.
Now our diffusion SDE can be easily converted into Stratonovich form because its diffusion coefficient depends only on time. Moreover, because diffusion SDEs and diffusion ODEs share a derivation via the Kolmogorov equations, their drift terms differ only by a factor of 2 on the noise-prediction term. \(\begin{equation} {\color{orange}\underbrace{\rmd \bfx_t = \big(f(t)\bfx_t + {\color{cyan}2} \frac{g^2(t)}{2\sigma_t} \bseps_\theta(\bfx_t, \bfz, t)\big)\;\rmd t}_{\textrm{Diffusion ODE}}} + {\color{cyan}g(t)\circ\rmd\bar\bfw_t.} \end{equation}\) Furthermore, notice that the SDE has the form \(\begin{equation} \rmd \bfx_t = {\color{orange}\underbrace{f(t)\bfx_t + \frac{g^2(t)}{\sigma_t} \bseps_\theta(\bfx_t, \bfz, t)}_{= \bsf_\theta(\bfx_t,\bfz, t)}}\;\rmd t + g(t)\circ\rmd\bar\bfw_t, \end{equation}\) and then, by our result from Equation \eqref{eq:sde_is_ode}, the adjoint diffusion SDE evolves with the following ODE: \(\begin{equation} \frac{\rmd \bfa_\bfx}{\rmd t}(t) = -\bfa_\bfx(t)^\top \frac{\partial \bsf_\theta(\bfx_t, \bfz, t)}{\partial \bfx_t}. \end{equation}\)
As the only difference between \(\bsf_\theta\) for diffusion SDEs and ODEs is a factor of 2, we realize that:
We can use the same ODE solvers for adjoint diffusion SDEs!
The only caveat is the factor of 2. Therefore, we can modify the update equations in our code from above to solve adjoint diffusion SDEs.
This blog post gives a detailed introduction to the continuous adjoint equations. We discuss the theory behind them and why they are an appropriate tool for solving guided generation problems for diffusion models. This post serves as a summary of our recent NeurIPS paper:
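Here is a minimal sketch of the modified update for \(\bfa_\bfx\), again in plain Python. It is the first-order exponential-integrator step, with the non-linear term scaled by 2 when `sde=True`, i.e. \(\bfa_\bfx(s) = \frac{\alpha_t}{\alpha_s}\bfa_\bfx(t) + 2\sigma_s(e^h - 1)\frac{\alpha_t^2}{\alpha_s^2}\bfa_\bfx(t)^\top \frac{\partial \bseps_\theta(\bfx_t, \bfz, t)}{\partial \bfx_t}\). The cosine schedule and the function names are hypothetical. The caller supplies the vector-Jacobian product, evaluated at the state \(\bfx_t\) of the same SDE sample path.

```python
import math


# Hypothetical cosine schedule, used only for illustration.
def alpha(t):
    return math.cos(0.5 * math.pi * t)


def sigma(t):
    return math.sin(0.5 * math.pi * t)


def lam(t):
    return math.log(alpha(t) / sigma(t))


def adjoint_x_step(a_x, vjp_x, t, s, sde=False):
    """First-order update of a_x from t to s > t for the adjoint diffusion ODE or SDE.

    vjp_x is a_x(t)^T d eps/d x_t, evaluated at the state x_t of the sampled trajectory.
    The diffusion SDE drift uses g^2/sigma_t instead of g^2/(2 sigma_t), which doubles
    the non-linear term; the linear term alpha_t / alpha_s is unchanged.
    """
    scale = 2.0 if sde else 1.0
    ratio = alpha(t) / alpha(s)
    coef = scale * sigma(s) * math.expm1(lam(s) - lam(t))  # scale * sigma_s (e^h - 1)
    return [ratio * a + coef * ratio ** 2 * v for a, v in zip(a_x, vjp_x)]


# Same adjoint state and vector-Jacobian product: ODE update vs SDE update.
print(adjoint_x_step([1.0], [0.5], 0.3, 0.32), adjoint_x_step([1.0], [0.5], 0.3, 0.32, sde=True))
```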
For examples of this technique used in practice check out our full paper and concurrent work from our colleagues