34) Differentiation#
Last time#
Gradient descent
Nonlinear models
Today#
Computing derivatives
Numeric (finite differences)
1.1 Symbolic
1.2 Hand-coded (analytic)
1.3 Algorithmic (automatic): Zygote.jl
Recap on Finite Differences
Ill-conditioned optimization
1. Computing derivatives#
We know the definition of the difference quotient from Calculus:
\[ f'(x) = \lim_{h\to 0} \frac{f(x+h) - f(x)}{h} . \]
On a computer we cannot take the limit, so we approximate \(f'(x)\) by the quotient at some finite \(h\).
How should we choose \(h\)?
Taylor series#
Classical accuracy analysis assumes that functions are sufficiently smooth, meaning that derivatives exist and Taylor expansions are valid within a neighborhood. In particular,
\[ f(x+h) = f(x) + f'(x)\, h + \frac{f''(x)}{2} h^2 + O(h^3) . \]
Substituting into the difference quotient shows that the one-sided approximation has truncation error \(O(h)\):
\[ \frac{f(x+h) - f(x)}{h} = f'(x) + \frac{f''(x)}{2} h + O(h^2) . \]
The big-\(O\) notation is meant in the limit \(h\to 0\). Specifically, a function \(g(h) \in O(h^p)\) (sometimes written \(g(h) = O(h^p)\)) when there exists a constant \(C\) such that
\[ |g(h)| \le C h^p \]
for all sufficiently small \(h\).
Rounding error#
We have an additional source of error, rounding error, which comes from not being able to compute \(f(x)\) or \(f(x+h)\) exactly, nor subtract them exactly. Suppose that we can, however, compute these functions with a relative error on the order of \(\epsilon_{\text{machine}}\). This leads to
\[ \tilde f(x) = f(x)(1 + \epsilon_1), \qquad \tilde f(x+h) = f(x+h)(1 + \epsilon_2), \qquad |\epsilon_1|, |\epsilon_2| \lesssim \epsilon_{\text{machine}} . \]
Tedious error propagation#
\[ \left| \frac{\tilde f(x+h) - \tilde f(x)}{h} - \frac{f(x+h) - f(x)}{h} \right| \le \frac{\big(|f(x+h)| + |f(x)|\big)\, \epsilon_{\text{machine}}}{h} \approx \frac{2 |f(x)|\, \epsilon_{\text{machine}}}{h} , \]
where we have assumed that \(h \ge \epsilon_{\text{machine}}\). This rounding error becomes large relative to \(f'\) as \(h\) shrinks (we are concerned with relative error, after all), while the truncation error grows with \(h\); the experiment after the list below makes the trade-off concrete.
What can be problematic:
\(f\) is large compared to \(f'\)
\(x\) is large
\(h\) is too small
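A quick numerical check of this trade-off (the test function \(\sin\), the point \(x=1\), and the step sizes are illustrative choices): as \(h\) decreases, the error of the one-sided difference first shrinks (truncation dominates) and then grows again (rounding dominates).
# Truncation vs. rounding: error of the one-sided difference as h shrinks
fd(f, x, h) = (f(x + h) - f(x)) / h
for h in 10.0 .^ (-1:-2:-15)
    err = abs(fd(sin, 1.0, h) - cos(1.0))
    println("h = $h   error = $err")
end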
Automatic step size selection#
Reference: Numerical Optimization.
Walker and Pernice
Dennis and Schnabel
diff(f, x; h=1e-8) = (f(x+h) - f(x)) / h
function diff_wp(f, x; h=1e-8)
    """Diff using Walker and Pernice (1998) choice of step"""
    h *= (1 + abs(x))
    (f(x+h) - f(x)) / h
end
x = 1000
diff_wp(sin, x) - cos(x)
-4.139506408429305e-6
1.1 Symbolic differentiation#
using Pkg
Pkg.add("Symbolics")
using Symbolics
@variables x
Dx = Differential(x)
y = sin(x)
Dx(y)
expand_derivatives(Dx(y))
Awesome, what about products?
y = x
for _ in 1:2
    y = cos(y^pi) * log(y)
end
expand_derivatives(Dx(y))
The size of these expressions can grow exponentially (a phenomenon known as expression swell).
1.2 Hand-coding (analytic) derivatives#
function f(x)
    y = x
    for _ in 1:2
        a = y^pi
        b = cos(a)
        c = log(y)
        y = b * c
    end
    y
end
f(1.9), diff_wp(f, 1.9)
function df(x, dx)
    y = x
    dy = dx
    for _ in 1:2
        a = y^pi
        da = pi * y^(pi - 1) * dy
        b = cos(a)
        db = -sin(a) * da
        c = log(y)
        dc = dy / y
        y = b * c
        dy = db * c + b * dc
    end
    y, dy
end
df(1.9, 1)
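The same forward propagation can be automated by overloading arithmetic on value/derivative pairs (dual numbers). The sketch below is illustrative and not part of the lecture code: the `Dual` type implements only the operations used by `f` above.
struct Dual
    val::Float64    # primal value
    deriv::Float64  # derivative propagated alongside the value
end
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.deriv * b.val + a.val * b.deriv)
Base.:^(a::Dual, p::Real) = Dual(a.val^p, p * a.val^(p - 1) * a.deriv)
Base.cos(a::Dual) = Dual(cos(a.val), -sin(a.val) * a.deriv)
Base.log(a::Dual) = Dual(log(a.val), a.deriv / a.val)

z = f(Dual(1.9, 1.0))   # seed dx = 1 at x = 1.9
(z.val, z.deriv)        # should match df(1.9, 1)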
We can go the other way#
We can differentiate a composition \(h(g(f(x)))\) via the chain rule,
\[ \frac{\operatorname{d}}{\operatorname{d}x} h\big(g(f(x))\big) = \frac{dh}{dg} \frac{dg}{df} \frac{df}{dx} . \]
What we’ve done above is called “forward mode”, and amounts to placing the parentheses in the chain rule like
\[ \operatorname{d}h = \frac{dh}{dg} \left( \frac{dg}{df} \left( \frac{df}{dx} \operatorname{d}x \right) \right), \]
propagating derivatives from the input \(x\) forward to the output. The expression means the same thing if we rearrange the parentheses,
\[ \operatorname{d}h = \left( \left( \frac{dh}{dg} \right) \frac{dg}{df} \right) \frac{df}{dx} \operatorname{d}x , \]
which we can compute in reverse order, starting from the output and accumulating adjoints back toward the input.
A reverse mode example#
function g(x)
    a = x^pi
    b = cos(a)
    c = log(x)
    y = b * c
    y
end
(g(1.4), diff_wp(g, 1.4))
function gback(x, y_)
    a = x^pi
    b = cos(a)
    c = log(x)
    y = b * c
    # backward pass
    c_ = y_ * b
    b_ = c * y_
    a_ = -sin(a) * b_
    x_ = 1/x * c_ + pi * x^(pi-1) * a_
    x_
end
gback(1.4, 1)
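As a sanity check (not part of the original notebook), the adjoint returned by `gback` should agree with the finite-difference estimate up to truncation and rounding error:
gback(1.4, 1) - diff_wp(g, 1.4)   # expect a small difference, roughly on the order of 1e-8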
1.3 Automatic differentiation: Zygote.jl#
using Pkg
Pkg.add("Zygote")
import Zygote
Zygote.gradient(f, 1.9)
g(x) = exp(x) + x^2
@code_llvm Zygote.gradient(g, 1.9)
How does Zygote work?#
square(x) = x^2
@code_llvm square(1.5)
@code_llvm Zygote.gradient(square, 1.5)
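Zygote performs a source-to-source transformation of Julia’s SSA-form intermediate representation: for each call it generates the value together with a pullback that maps output adjoints to input adjoints. A small illustration using Zygote’s `pullback` API (the expected results for `square` at 1.5 are noted in the comment):
val, back = Zygote.pullback(square, 1.5)   # value and pullback closure
val, back(1.0)                             # (2.25, (3.0,)): d(x^2)/dx = 2x = 3 at x = 1.5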
Kinds of algorithmic differentiation#
Source transformation: Fortran code in, Fortran code out
Duplicates compiler features, usually incomplete language coverage
Produces efficient code
Operator overloading: C++ types
Hard to vectorize
Loops are effectively unrolled/inefficient
Just-in-time compilation: tightly coupled with compiler
JIT lag
Needs dynamic language features (JAX) or tight integration with compiler (Zygote, Enzyme)
Some sharp bits in Python’s JAX
Forward or reverse?#
It all depends on the shape.
If you have one input, many outputs: use forward mode
“One input” can also mean a single direction, i.e., a directional derivative
If you have many inputs, one output: use reverse mode (see the sketch after this list)
Will need to traverse execution backwards (“tape”)
Hierarchical checkpointing
What about square cases (same number of input/output)? Forward is usually a bit more efficient.
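For the many-inputs, one-output case, a single reverse pass yields the entire gradient. The scalar loss and the 1000-dimensional input below are illustrative:
loss(c) = sum(abs2, c) / 2          # one scalar output
c = randn(1000)                     # many inputs
Zygote.gradient(loss, c)[1] ≈ c     # the gradient of ½‖c‖² is c itself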
Can you differentiate an algorithm?#
Examples:
Optimization: Input \(c\), output \(x\) such that \(f(x,c)=0\) (see the note after this list)
Finding eigenvalues/eigenvectors: Input \(A\), output \(\lambda\) such that \(Ax = \lambda x\) for some nonzero vector
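For the first example, one standard approach is implicit differentiation of the condition \(f(x,c)=0\): assuming \(\partial f/\partial x\) is invertible at the solution, the implicit function theorem gives
\[ \frac{dx}{dc} = -\left( \frac{\partial f}{\partial x} \right)^{-1} \frac{\partial f}{\partial c} , \]
so the solver’s output can be differentiated without differentiating the iteration itself.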
2. Recap on Finite Differences#
Finite Differences#
To define derivatives, we use Taylor’s expansion:
\[ f(x_i + \Delta x) = f(x_i) + \Delta x\, f'(x_i) + \frac{\Delta x^2}{2} f''(x_i) + \frac{\Delta x^3}{6} f'''(x_i) + O(\Delta x^4) . \]
Similarly,
\[ f(x_i - \Delta x) = f(x_i) - \Delta x\, f'(x_i) + \frac{\Delta x^2}{2} f''(x_i) - \frac{\Delta x^3}{6} f'''(x_i) + O(\Delta x^4) . \]
Consider a uniformly discretized domain with nodes \(x_i = x_0 + i\, \Delta x\), where \(h \equiv \Delta x\) is the width of the uniform subintervals.
We can define the first-order derivative at the point \(x_i\) using the forward difference:
\[ f'(x_i) \approx \frac{f(x_{i+1}) - f(x_i)}{\Delta x} ; \]
this is a first-order approximation.
Similarly, we can define the first-order derivative at the point \(x_i\) using the backward difference:
\[ f'(x_i) \approx \frac{f(x_i) - f(x_{i-1})}{\Delta x} ; \]
this is also a first-order approximation.
If we use one point to the right of \(x_i\) and one point to the left of \(x_i\), we obtain the centered difference approximation of the first-order derivative at \(x_i\):
\[ f'(x_i) \approx \frac{f(x_{i+1}) - f(x_{i-1})}{2 \Delta x} ; \]
this is a second-order approximation.
Thus we note that the centered difference approximates the first derivative with respect to \(x\) more accurately than either of the one-sided differences: the error is \(O(\Delta x^2)\) versus \(O(\Delta x)\).
We can now define a second-order derivative at the point \(x_i\) using a centered difference formula:
\[ f''(x_i) \approx \frac{f(x_{i+1}) - 2 f(x_i) + f(x_{i-1})}{\Delta x^2} , \]
which is also second-order accurate.
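A short check of these convergence orders (the test function \(\sin\), the point \(x=1\), and the step sizes are illustrative): halving \(\Delta x\) should roughly halve the forward-difference error and quarter the centered-difference error.
u(x) = sin(x)
du(x) = cos(x)
x0 = 1.0
for Δx in (0.1, 0.05, 0.025)
    fwd = (u(x0 + Δx) - u(x0)) / Δx              # first-order accurate
    ctr = (u(x0 + Δx) - u(x0 - Δx)) / (2Δx)      # second-order accurate
    println("Δx = $Δx: forward error = $(abs(fwd - du(x0))), centered error = $(abs(ctr - du(x0)))")
end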
3. Ill-conditioned optimization#
Recall that the computation of the gradient of the loss function \(L\) requires the Jacobian, denoted by \(J\), of the model \(f\) differentiated with respect to the constants \(c\). For the least-squares loss \(L(c) = \frac{1}{2} \lVert f(x,c) - y \rVert^2\),
\[ g(c) = \nabla_c L = J^T \big( f(x,c) - y \big), \qquad J = \frac{\partial f}{\partial c} . \]
We can find the constants \(c\) for which \(g(c) = 0\) using a Newton method. The Hessian
\[ \nabla_c g = J^T J + \sum_i \big( f_i(x,c) - y_i \big) \nabla_c^2 f_i \]
requires the second derivative of \(f\), which can cause problems.
Newton-like methods for optimization#
We want to solve
\[ g(c) = \nabla_c L(c) = 0 . \]
Update
\[ c \gets c - H^{-1} g(c) , \]
where \(H\) approximates the Hessian \(\nabla_c g\):
Gauss-Newton: \(H = J^T J\)
Levenberg-Marquardt: \(H = J^T J + \alpha I\)
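A minimal sketch of one such step, reusing Zygote for the Jacobian; the toy model, synthetic data, and damping parameter \(\alpha\) below are illustrative choices, not from the lecture.
using LinearAlgebra   # for the identity I
import Zygote

model(x, c) = c[1] .* exp.(c[2] .* x)                    # toy nonlinear model
xdata = collect(range(0, 1, length=20))
ydata = 2 .* exp.(-1.5 .* xdata) .+ 0.01 .* randn(20)    # synthetic observations

function lm_step(c, α)
    r = model(xdata, c) .- ydata                         # residual f(x,c) - y
    J = Zygote.jacobian(c -> model(xdata, c), c)[1]      # Jacobian with respect to c
    H = J' * J + α * I                                   # Levenberg-Marquardt approximation
    c - H \ (J' * r)                                     # Newton-like update
end

lm_step([1.0, -1.0], 1e-2)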
Outlook on optimization#
The optimization problem can be solved using a Newton method. It can be onerous to implement the needed derivatives.
The Gauss-Newton method is often more practical than Newton while being faster than gradient descent, though it lacks robustness.
The Levenberg-Marquardt method provides a sort of middle-ground between Gauss-Newton and gradient descent.
Many globalization techniques are used for models that possess many local minima.
One pervasive approach is stochastic gradient descent, in which small batches (e.g., 1, 10, or 20 observations) are selected randomly from the corpus of observations (500 in the example with many realizations that we saw earlier), and a gradient descent step is applied to that reduced set. This helps to escape saddle points and weak local minima.
Among expressive models \(f(x,c)\), some may converge much more easily than others. Having a good optimization algorithm is essential for nonlinear regression with complicated models, especially those with many parameters \(c\).
Classification is a problem very similar to regression, but the observations \(y\) are discrete; thus, in this case
models \(f(x,c)\) must have discrete output
the least squares loss function is not appropriate.
Reading: Why momentum really works