Gradients for Time Scheduled Conditional Variables in Neural Differential Equations

A short derivation of the continuous adjoint equation for time scheduled conditional variables.


Introduction

The advent of large-scale diffusion models conditioned on text embeddings has allowed for creative control over the generative process. A recent and powerful technique is that of prompt scheduling, i.e., instead of passing a fixed prompt to the diffusion model, the prompt can be changed depending on the timestep. This concept was initially proposed by Doggettx in this reddit post, and the corresponding code changes to the Stable Diffusion repository can be seen here.

Examples of the prompt scheduling technique proposed by Doggettx.

More generally, we can view this as having the conditional information (in this case, text embeddings) scheduled w.r.t. time. Formally, assume we have a U-Net trained on the noise-prediction task $\epsilon_\theta(x_t, z, t)$ conditioned on a time-scheduled text embedding $z(t)$. The sampling procedure amounts to solving the probability flow ODE from time $T$ to time $0$:
$$\frac{\mathrm{d}x_t}{\mathrm{d}t} = f(t)\,x_t + \frac{g^2(t)}{2\sigma_t}\,\epsilon_\theta(x_t, z(t), t), \tag{1}$$
where $f, g$ define the drift and diffusion coefficients of a Variance Preserving (VP) type SDE.
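
To make the scheduling concrete, below is a minimal sketch of an Euler discretization of Eq. (1) in which the embedding passed to the noise-prediction network is looked up at every solver step. All names here are hypothetical stand-ins, not the Doggettx implementation: `eps_theta` for the trained U-Net, `z_schedule` for the prompt schedule, and `f`, `g`, `sigma` for the VP-SDE coefficients.

```python
def sample_prob_flow_ode(x_T, eps_theta, z_schedule, f, g, sigma, ts):
    """Euler solve of Eq. (1), dx/dt = f(t) x + g(t)^2 / (2 sigma(t)) eps_theta(x, z(t), t),
    integrated from t = T down to t = 0 with a time-scheduled embedding z(t).

    ts          -- decreasing time grid, ts[0] = T and ts[-1] = 0
    z_schedule  -- callable t -> text embedding z(t), e.g. piecewise constant
    eps_theta   -- (assumed) noise-prediction network eps_theta(x, z, t)
    f, g, sigma -- (assumed) VP-SDE drift, diffusion, and noise-scale coefficients
    """
    x = x_T
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur                  # negative step: integrating T -> 0
        z_t = z_schedule(t_cur)              # conditioning may change at every step
        dx_dt = f(t_cur) * x + (g(t_cur) ** 2) / (2 * sigma(t_cur)) * eps_theta(x, z_t, t_cur)
        x = x + dt * dx_dt                   # explicit Euler update
    return x
```

With a constant `z_schedule` this reduces to ordinary conditional sampling; the only change prompt scheduling introduces is the per-step lookup.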

Training-free guidance

A closely related area of active research has been the development of techniques which search for the optimal generation parameters.

More specifically, they attempt to solve the following optimization problem:
$$\underset{x_T,\, z,\, \theta}{\arg\min}\ \mathcal{L}\!\left(x_T + \int_T^0 f(t)\,x_t + \frac{g^2(t)}{2\sigma_t}\,\epsilon_\theta(x_t, z, t)\,\mathrm{d}t\right), \tag{2}$$
where $\mathcal{L}$ is a real-valued loss function on the output $x_0$.

Several recent works this year explore solving the continuous adjoint equations to find the gradients
$$\frac{\partial \mathcal{L}}{\partial x_t}, \quad \frac{\partial \mathcal{L}}{\partial z}, \quad \frac{\partial \mathcal{L}}{\partial \theta}. \tag{3}$$
These gradients can then be used in combination with gradient descent algorithms to solve the optimization problem. However, what if $z$ is scheduled and not constant w.r.t. time?
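
As a rough illustration of how such gradients are used, the outer loop is plain gradient descent on the conditioning. This is only a sketch: `sample_fn`, `adjoint_grad_fn`, and `loss_fn` are hypothetical stand-ins for an ODE sampler, a continuous-adjoint gradient solver such as those in the works above, and the loss $\mathcal{L}$.

```python
def optimize_conditioning(x_T, z_init, loss_fn, sample_fn, adjoint_grad_fn,
                          n_iters=50, lr=1e-2):
    """Sketch of guided generation: repeatedly solve the probability flow ODE,
    evaluate L(x_0), and descend along dL/dz obtained from the continuous
    adjoint equations. All helper callables are assumed, not real library APIs."""
    z = z_init.clone()
    for _ in range(n_iters):
        x_0 = sample_fn(x_T, z)                    # solve the ODE from T to 0
        print(float(loss_fn(x_0)))                 # monitor L(x_0)
        grad_z = adjoint_grad_fn(x_T, z, loss_fn)  # dL/dz via the continuous adjoint equations
        z = z - lr * grad_z                        # gradient descent step on the embedding
    return z
```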

Problem statement. Given
$$x_0 = x_T + \int_T^0 f(t)\,x_t + \frac{g^2(t)}{2\sigma_t}\,\epsilon_\theta(x_t, z(t), t)\,\mathrm{d}t, \tag{4}$$
and $\mathcal{L}(x_0)$, find:
$$\frac{\partial \mathcal{L}}{\partial z(t)}, \qquad \forall t \in [0, T]. \tag{5}$$

In an earlier blog post we showed how to find $\partial\mathcal{L}/\partial z$ by solving the continuous adjoint equations. How do the continuous adjoint equations change when the constant $z$ is replaced with a time-scheduled $z(t)$ in the sampling equation? What we will now show is that

We can simply replace $z$ with $z(t)$ in the continuous adjoint equations.

While this result is intuitive, it does require some technical details to show.

Gradients of time-scheduled conditional variables

It is well known that diffusion models are just a special type of neural differential equation, either a neural ODE or a neural SDE. As such, we will show that this result holds more generally for neural ODEs.

Theorem (Continuous adjoint equations for time-scheduled conditional variables). Suppose there exists a function $z : [0, T] \to \mathbb{R}^z$ which can be defined as a càdlàg¹ piecewise function, continuous on each subinterval of the partition $\Pi = \{0 = t_0 < t_1 < \cdots < t_n = T\}$ of $[0, T]$, and whose right derivatives exist for all $t \in [0, T]$. Let $f_\theta : \mathbb{R}^d \times \mathbb{R}^z \times [0, T] \to \mathbb{R}^d$ be continuous in $t$, uniformly Lipschitz in $y$, and continuously differentiable in $y$. Let $y : [0, T] \to \mathbb{R}^d$ be the unique solution of the ODE
$$\frac{\mathrm{d}y}{\mathrm{d}t}(t) = f_\theta(y(t), z(t), t), \tag{6}$$
with initial condition $y(0) = y_0$. Then $\partial\mathcal{L}/\partial z(t) := a_z(t)$, and there exists a unique solution $a_z : [0, T] \to \mathbb{R}^z$ to the following initial value problem, where $a_y(t) := \partial\mathcal{L}/\partial y(t)$ denotes the usual adjoint state:
$$a_z(T) = 0, \qquad \frac{\mathrm{d}a_z}{\mathrm{d}t}(t) = -a_y(t)\,\frac{\partial f_\theta(y(t), z(t), t)}{\partial z}. \tag{7}$$
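
As a numerical reading of the theorem, here is a sketch that integrates $y$, $a_y$, and $a_z$ jointly from $T$ back to $0$ with Euler steps, obtaining the contractions $a_y\,\partial f_\theta/\partial y$ and $a_y\,\partial f_\theta/\partial z$ as autograd vector-Jacobian products. The adjoint ODE for $a_y$ is the standard one from the earlier post; `f_theta` and `z_schedule` are hypothetical callables, and `a_y_T` stands for $\partial\mathcal{L}/\partial y(T)$.

```python
import torch

def solve_adjoints(y_T, a_y_T, f_theta, z_schedule, ts):
    """Euler solve, backward in time from t = T to t = 0, of the coupled system
        dy/dt   =  f_theta(y, z(t), t)
        da_y/dt = -a_y  d f_theta / d y        (standard continuous adjoint)
        da_z/dt = -a_y  d f_theta / d z        (Eq. 7, with a_z(T) = 0)
    using autograd vector-Jacobian products for the contractions.

    ts -- decreasing time grid with ts[0] = T and ts[-1] = 0
    Returns the trajectory of a_z on the grid.
    """
    y, a_y = y_T.clone(), a_y_T.clone()
    a_z = torch.zeros_like(z_schedule(ts[0]))      # initial condition a_z(T) = 0
    a_z_path = [a_z.clone()]
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur                        # negative step: T -> 0
        y_in = y.detach().requires_grad_(True)
        z_in = z_schedule(t_cur).detach().requires_grad_(True)
        f_val = f_theta(y_in, z_in, t_cur)
        # vector-Jacobian products a_y df/dy and a_y df/dz
        vjp_y, vjp_z = torch.autograd.grad(f_val, (y_in, z_in), grad_outputs=a_y)
        y   = y   + dt * f_val.detach()            # re-integrate the state backwards
        a_y = a_y - dt * vjp_y                     # Euler step of da_y/dt = -a_y df/dy
        a_z = a_z - dt * vjp_z                     # Euler step of da_z/dt = -a_y df/dz
        a_z_path.append(a_z.clone())
    return a_z_path
```

Re-integrating $y$ backwards alongside the adjoints, rather than storing the forward trajectory, is the usual constant-memory trade-off of adjoint methods; it introduces some numerical error relative to backpropagating through a stored solve.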

Why càdlàg?

In practice, $z(t)$ is often a discrete set $\{z_k\}_{k=1}^n$, where $n$ corresponds to the number of discretization steps the numerical ODE solver takes. While the proof is easier for a continuously differentiable function $z(t)$, we opt for this construction for the sake of generality. We choose a càdlàg piecewise function, a relatively mild assumption, to ensure that we can define the augmented state on each continuous interval of the piecewise function in terms of the right derivative.
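
For concreteness, one possible construction (a sketch; the names `make_cadlag_schedule`, `z_values`, and `breakpoints` are illustrative) of a càdlàg, piecewise-constant schedule from the discrete set $\{z_k\}_{k=1}^n$ is a right-continuous lookup over the solver's grid; inside each interval the right derivative is then simply $0$.

```python
import torch

def make_cadlag_schedule(z_values, breakpoints):
    """Right-continuous (cadlag) piecewise-constant schedule built from a discrete
    set {z_k}: z(t) = z_values[k] for t in [breakpoints[k], breakpoints[k+1]).

    z_values    -- tensor of shape (n, z_dim), one embedding per interval
    breakpoints -- increasing 1-D tensor of length n + 1 covering [0, T]
    """
    def z_schedule(t):
        # index of the interval containing t (right-continuous at each breakpoint)
        k = int(torch.searchsorted(breakpoints, torch.as_tensor(float(t)), right=True)) - 1
        k = max(0, min(k, z_values.shape[0] - 1))   # clamp t = T into the last interval
        return z_values[k]
    return z_schedule
```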

In the remainder of this blog post we will provide the proof of this result. Our proof technique is an extension of the one used by Patrick Kidger (Appendix C.3.1) to prove the existence of the solution to the continuous adjoint equations for neural ODEs.

Proof. Recall that $z(t)$ is a piecewise function of time with partition of the time domain $\Pi$. Without loss of generality, consider some time interval $\pi = [t_{m-1}, t_m]$ for some $1 \le m \le n$. Consider the augmented state defined on the interval $\pi$:
$$\frac{\mathrm{d}}{\mathrm{d}t}\begin{bmatrix} y \\ z \end{bmatrix}(t) = f_{\mathrm{aug}} = \begin{bmatrix} f_\theta(y_t, z_t, t) \\ z'(t) \end{bmatrix}, \tag{8}$$
where $z'(t) : [0, T] \to \mathbb{R}^z$ denotes the right derivative of $z$ at time $t$. Let $a_{\mathrm{aug}}$ denote the augmented adjoint state
$$a_{\mathrm{aug}}(t) := \begin{bmatrix} a_y & a_z \end{bmatrix}(t). \tag{9}$$
Then the Jacobian of $f_{\mathrm{aug}}$ is defined as
$$\frac{\partial f_{\mathrm{aug}}}{\partial [y, z]} = \begin{bmatrix} \dfrac{\partial f_\theta(y, z, t)}{\partial y} & \dfrac{\partial f_\theta(y, z, t)}{\partial z} \\ 0 & 0 \end{bmatrix}. \tag{10}$$
As the state $z(t)$ evolves with $z'(t)$ on the interval $[t_{m-1}, t_m]$ in the forward direction, the derivative of this component of the augmented vector field w.r.t. $[y, z]$ is clearly $0$, as it depends only on time. Remark: as the bottom row of the Jacobian of $f_{\mathrm{aug}}$ is all zeros and $f_\theta$ is continuous in $t$, we can consider the evolution of $a_{\mathrm{aug}}$ over the whole interval $[0, T]$ rather than just a partition of it. The evolution of the augmented adjoint state on $[0, T]$ is then given as
$$\frac{\mathrm{d}a_{\mathrm{aug}}}{\mathrm{d}t}(t) = -\begin{bmatrix} a_y & a_z \end{bmatrix}(t)\,\frac{\partial f_{\mathrm{aug}}}{\partial [y, z]}(t). \tag{11}$$
Therefore, $a_z(t)$ is a solution to the initial value problem:
$$a_z(T) = 0, \qquad \frac{\mathrm{d}a_z}{\mathrm{d}t}(t) = -a_y(t)\,\frac{\partial f_\theta(y(t), z(t), t)}{\partial z}. \tag{12}$$

Next we show that there exists a unique solution to this initial value problem. As $y$ is continuous and $f_\theta$ is continuously differentiable in $y$, it follows that $t \mapsto \frac{\partial f_\theta}{\partial y}(y(t), z(t), t)$ is a continuous function on the compact set $[t_{m-1}, t_m]$; as such, it is bounded by some $L > 0$. Likewise, for $a_y \in \mathbb{R}^d$ the map $(t, a_y) \mapsto -a_y\,\frac{\partial f_\theta}{\partial [y, z]}(y(t), z(t), t)$ is Lipschitz in $a_y$ with Lipschitz constant $L$, and this constant is independent of $t$. Therefore, by the Picard-Lindelöf theorem the solution $a_{\mathrm{aug}}(t)$ exists and is unique. ∎
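
As a sanity check of the result, here is a self-contained toy sketch (the vector field, loss, and dimensions are purely illustrative and not from the post) that steps the adjoint system of Eq. (12) backward along a stored Euler trajectory and compares the resulting per-interval gradients $\partial\mathcal{L}/\partial z_k$ against backpropagation through the unrolled solve; for a piecewise-constant schedule with matching discretizations, the two agree up to floating-point error.

```python
import torch

torch.manual_seed(0)

# Toy setup: a smooth vector field f_theta(y, z, t), Lipschitz in y, differentiable in y and z.
d, z_dim, n_steps, T = 2, 3, 8, 1.0
W_y = torch.randn(d, d) * 0.3
W_z = torch.randn(z_dim, d) * 0.3

def f_theta(y, z, t):
    return torch.tanh(y @ W_y + z @ W_z) + 0.1 * t * y

# Piecewise-constant (cadlag) schedule: one embedding z_k per Euler step.
z_vals = torch.randn(n_steps, z_dim, requires_grad=True)
ts = torch.linspace(0.0, T, n_steps + 1)
h = T / n_steps
y0 = torch.randn(d)

# Reference gradient: backprop through the unrolled forward Euler solve of Eq. (6).
y = y0
for k in range(n_steps):
    y = y + h * f_theta(y, z_vals[k], ts[k])
loss = (y ** 2).sum()
grad_ref = torch.autograd.grad(loss, z_vals)[0]

# Adjoint gradient: replay the stored states backward, accumulating a_z over each interval.
with torch.no_grad():
    ys = [y0]
    for k in range(n_steps):
        ys.append(ys[-1] + h * f_theta(ys[-1], z_vals[k], ts[k]))

a_y = 2 * ys[-1]                         # a_y(T) = dL/dy(T) for L = ||y(T)||^2
grad_adj = torch.zeros_like(z_vals)
for k in reversed(range(n_steps)):
    y_in = ys[k].detach().requires_grad_(True)
    z_in = z_vals[k].detach().requires_grad_(True)
    f_val = f_theta(y_in, z_in, ts[k])
    vjp_y, vjp_z = torch.autograd.grad(f_val, (y_in, z_in), grad_outputs=a_y)
    grad_adj[k] = h * vjp_z              # increment of a_z over step k, i.e. dL/dz_k
    a_y = a_y + h * vjp_y                # backward Euler step of da_y/dt = -a_y df/dy

print(torch.allclose(grad_ref, grad_adj, atol=1e-5))   # True
```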


If you found this useful and would like to cite this post in academic context, please cite this as:

Blasingame, Zander W. (Dec 2024). Gradients for Time Scheduled Conditional Variables in Neural Differential Equations. https://zblasingame.github.io.

or as a BibTeX entry:

@article{blasingame2024gradients-for-time-scheduled-conditional-variables-in-neural-differential-equations,
  title   = {Gradients for Time Scheduled Conditional Variables in Neural Differential Equations},
  author  = {Blasingame, Zander W.},
  year    = {2024},
  month   = {Dec},
  url     = {https://zblasingame.github.io/blog/2024/cadlag-conditional/}
}

Footnotes

  1. French: continue à droite, limite à gauche.

References

  1. High-Resolution Image Synthesis With Latent Diffusion Models
    Rombach, R., Blattmann, A., Lorenz, D., Esser, P. and Ommer, B., 2022. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 10684-10695.
  2. Hierarchical Text-Conditional Image Generation with CLIP Latents
    Ramesh, A., Dhariwal, P., Nichol, A., Chu, C. and Chen, M., 2022. arXiv e-prints, pp. arXiv:2204.06125. DOI: 10.48550/arXiv.2204.06125
  3. Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding
    Saharia, C., Chan, W., Saxena, S., Li, L., Whang, J., Denton, E.L., Ghasemipour, K., Gontijo Lopes, R., Karagol Ayan, B., Salimans, T., Ho, J., Fleet, D.J. and Norouzi, M., 2022. Advances in Neural Information Processing Systems, Vol 35, pp. 36479--36494. Curran Associates, Inc.
  4. Denoising Diffusion Implicit Models
    Song, J., Meng, C. and Ermon, S., 2021. International Conference on Learning Representations.
  5. AdjointDPM: Adjoint Sensitivity Method for Gradient Backpropagation of Diffusion Probabilistic Models
    Pan, J., Liew, J.H., Tan, V., Feng, J. and Yan, H., 2024. The Twelfth International Conference on Learning Representations.
  6. AdjointDEIS: Efficient Gradients for Diffusion Models
    Blasingame, Z.W. and Liu, C., 2024. The Thirty-eighth Annual Conference on Neural Information Processing Systems.
  7. Implicit Diffusion: Efficient Optimization through Stochastic Sampling
    Marion, P., Korba, A., Bartlett, P., Blondel, M., De Bortoli, V., Doucet, A., Llinares-Lopez, F., Paquette, C. and Berthet, Q., 2024. arXiv preprint arXiv:2402.05468.
  8. On Neural Differential Equations
    Kidger, P., 2022. arXiv preprint arXiv:2202.02435.