On rooted trees and differentiation (Daniel Worrall, daniel.e.w.worrall at gmail.com, 2023-11-22)

<h1 id="introduction">Introduction</h1>
<p>The chain rule lies at the heart of the backpropagation algorithm in deep learning. Unbeknownst to many though, the chain rule for higher order derivatives boasts a wealth of beautiful mathematical structure touching the theory of special rooted trees, group theory, combinatorics of integer partitions, order theory, and many others. I’ve been meaning to write this post for a long time, but in the last year work has been quite busy. I’m glad I can finally share with you the beautiful maths connecting special rooted trees and differentiation.</p>
<h3 id="the-chain-rule">The chain rule</h3>
<p>We start with a composition of functions
\begin{align}
\textbf{z} = f(g(\textbf{x}))
\end{align}
where $f$ and $g$ are vector-in vector-out functions. We can introduce an intermediate variable $\textbf{y} = g(\textbf{x})$ so that $\textbf{z} = f(\textbf{y})$. The derivative of $\textbf{z}$ with respect to $\textbf{x}$ is then
\begin{align}
\frac{\partial \textbf{z}}{\partial \textbf{x}} = \frac{\partial \textbf{z}}{\partial \textbf{y}} \frac{\partial \textbf{y}}{\partial \textbf{x}}.
\end{align}
In any contemporary machine learning masters course, this is about as far as we go. Couple the chain rule with dynamic programming and you get the backpropagation algorithm and forward-mode differentiation. And for most practitioners, we do not even need to know as much. With the advent of packages like <a href="https://jax.readthedocs.io/en/latest/notebooks/quickstart.html">JAX</a> all this machinery is hidden away. Well not today!</p>
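To make the Jacobian form of the chain rule concrete, here is a small numerical sanity check (my own sketch, not from the post; the functions $f$ and $g$ are arbitrary stand-ins), comparing the Jacobian of the composition against the product of the individual Jacobians, both estimated by finite differences:

```python
import numpy as np

# Hypothetical example functions: g maps R^3 -> R^2, f maps R^2 -> R^2.
def g(x):
    return np.array([np.sin(x[0]) * x[1], x[2] ** 2])

def f(y):
    return np.array([y[0] + y[1], y[0] * y[1]])

def jacobian(fn, x, eps=1e-6):
    # Central finite differences, one column per input dimension.
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((fn(x + e) - fn(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)

x = np.array([0.3, -1.2, 0.7])
# Chain rule: J_{f.g}(x) = J_f(g(x)) @ J_g(x)
lhs = jacobian(lambda x: f(g(x)), x)
rhs = jacobian(f, g(x)) @ jacobian(g, x)
assert np.allclose(lhs, rhs, atol=1e-4)
```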
<p>Now while vector notation is neat, it’s actually really unhelpful when we wish to do calculus. Each Jacobian in the above expression is a matrix and I always forget how to order the rows and columns properly. Furthermore, the following is going to involve a lot of vector derivatives, matrix derivatives, and higher order tensor derivatives, which can all be very unwieldy, so to ease notation we shall adopt index notation instead. As we shall see, switching up our notation frequently is going to help our understanding and aid our ability to generalize.</p>
<p>So using $z^i$ to denote the $i$th component of a vector $\textbf{z}$, we could write
\begin{align}
\frac{\partial z^i}{\partial x^j} = \sum_{\alpha} \frac{\partial z^i}{\partial y^\alpha} \frac{\partial y^\alpha}{\partial x^j}.
\end{align}
As a second notational step, we are going to denote differentiation of a function $h$ with respect to the $\alpha$th dimension of its input as $h_\alpha$. Notice we do not need to make reference to $y$ in this notation, since it is understood at we differentiate with respect to the input of $f$, however we might wish to label it. So
\begin{align}
\frac{\partial f^i}{\partial y^\alpha} = f^i_\alpha
\end{align}
The chain rule is then just
\begin{align}
\frac{\partial f^i}{\partial x^j} = \sum_{\alpha} f^i_\alpha g^\alpha_j.
\end{align}
Notice how there is one $\alpha$ on the bottom and one $\alpha$ on the top. For this reason, as one final notational convenience, we will switch to Einstein notation, where we implicitly sum over repeated indices in upper–lower pairs, so the chain rule is
\begin{align}
\frac{\partial f^i}{\partial x^j} = f^i_\alpha g^\alpha_j.
\end{align}
I have always found this notation both very elegant and parsimonious. Back in my PhD, before automatic differentiation was commonplace in machine learning, I would often use this notation to work out gradients, because it is both uncluttered and unconfusing.</p>
<p>You may have noticed that I am using Greek letters for the dummy variables we sum over. This is just a choice mainly for me to remember what we are summing over. With this highly compressed notation, let’s write the $2$nd derivative of $f^i$ with respect to $x$. It’s
\begin{align}
\frac{\partial^2 f^i}{\partial x^j \partial x^k} = f^i_{\alpha \beta} g^\alpha_j g^\beta_k + f^i_{\alpha} g^\alpha_{jk}.
\end{align}
The 3rd derivative is
\begin{align}
\frac{\partial^3 f^i}{\partial x^j \partial x^k \partial x^\ell} &= f^i_{\alpha \beta \gamma} g^\alpha_j g^\beta_k g^\gamma_\ell + f^i_{\alpha \beta} g^\alpha_{j\ell} g^\beta_k + f^i_{\alpha \beta} g^\alpha_{j} g^\beta_{k\ell} + f^i_{\alpha \beta} g^\alpha_{jk} g^\beta_\ell + f^i_{\alpha} g^\alpha_{jk\ell} \newline
&= f^i_{\alpha \beta \gamma} g^\alpha_j g^\beta_k g^\gamma_\ell + 3 \cdot f^i_{\alpha \beta} g^\alpha_j g^\beta_{k\ell} + f^i_{\alpha} g^\alpha_{jk\ell}
\end{align}
These expressions get very unwieldy for higher order derivatives. Let’s try the fourth!
\begin{align}
\frac{\partial^4 f^i}{\partial x^j \partial x^k \partial x^\ell \partial x^m} &= f^i_{\alpha \beta \gamma \delta} g^\alpha_j g^\beta_k g^\gamma_\ell g^\delta_m + 6 \cdot f^i_{\alpha \beta \gamma} g^\alpha_j g^\beta_k g^\gamma_{\ell m} + 3 \cdot f^i_{\alpha \beta} g^\alpha_{j\ell} g^\beta_{km}
+ 4 \cdot f^i_{\alpha \beta} g^\alpha_{j} g^\beta_{k \ell m} + f^i_{\alpha} g^\alpha_{jk\ell m}.
\end{align}
OK, what is going on? This is tedious and confusing, and it is not obvious whether there is any structure here. In fact there is a very simple structure, and we can derive all of the above with some simple rules involving <em>special labeled rooted trees</em>. To make the connection, we make two observations. Each derivative is a sum of factors of the form $f^i_{\alpha\beta…}g^\alpha_{j…}g^\beta_{k…} \cdots$ where there is a:</p>
<ol>
<li>single term in $f^i_{\alpha\beta…}$ with multiple subscripts,</li>
  <li>multiple terms in $g^\alpha_{j…}$ where each $g$ has a single superscript and potentially many subscripts.</li>
</ol>
<p>We are going to replace each term in $f$ or $g$ with parts of a special rooted tree.</p>
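Before moving on to trees, the derivative formulas above can be sanity-checked numerically. A quick scalar sketch, assuming $f = \sin$ and $g = \exp$ (my choices, picked only because all their derivatives have closed forms), comparing the 2nd- and 3rd-order expressions against finite differences of $f(g(x))$:

```python
import numpy as np

x = 0.4
g, dg = np.exp(x), np.exp(x)               # g and every derivative of g at x
fp, fpp, fppp = np.cos(g), -np.sin(g), -np.cos(g)  # f', f'', f''' at g(x)

d2_formula = fpp * dg**2 + fp * dg                       # f'' g'^2 + f' g''
d3_formula = fppp * dg**3 + 3 * fpp * dg * dg + fp * dg  # f''' g'^3 + 3 f'' g' g'' + f' g'''

h = lambda t: np.sin(np.exp(t))
e = 1e-3
# Central finite-difference stencils for the 2nd and 3rd derivatives.
d2_fd = (h(x + e) - 2 * h(x) + h(x - e)) / e**2
d3_fd = (h(x + 2*e) - 2*h(x + e) + 2*h(x - e) - h(x - 2*e)) / (2 * e**3)

assert abs(d2_formula - d2_fd) < 1e-4
assert abs(d3_formula - d3_fd) < 1e-3
```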
<h1 id="special-labeled-rooted-trees">Special labeled rooted trees</h1>
<p>We begin by drawing the simplest tree $f^i$ as</p>
<p align="center">
<img src="/media/2023/aod_1.svg" />
</p>
<p>This is just a root node of a tree—hence special labeled <em>rooted</em> tree. Every time we differentiate $f^i$ we will draw a branch emanating from the root node. In other words, for every subscript of $f^i$ we draw a branch. The first derivative $f^i_{\alpha} g^\alpha_j$ we thus draw as</p>
<p align="center">
<img src="/media/2023/aod_2.svg" />
</p>
<p>This is simple enough. Note, we shall also label the nodes with the subscript of the attached branch—in this case $j$—so that we can keep track of what branch corresponds to what algebraïc terms. Hence special <em>labeled</em> rooted tree. We didn’t write $i$ by the root node, since it is not a <em>sub</em>script. In fact, since $i$ only ever appears in the superscript of $f$, we could drop it entirely, leaving $f$ as a vector-in scalar-out mapping, which we choose to do from now on.</p>
<p>Now what about the factor $f_{\alpha\beta} g^\alpha_j g^\beta_k$? It has two branches emanating from the root as</p>
<p align="center">
<img src="/media/2023/aod_3.svg" />
</p>
<p>What if $g$ has multiple subscripts? Well, we then extend the branch by one node per subscript of $g$, so $f_{\alpha} g^\alpha_{jk}$ and $f_{\alpha\beta} g^\alpha_{jk}g^\beta_\ell$ look like</p>
<p align="center">
<img src="/media/2023/aod_4.svg" />
</p>
<p>This notation is a little weird at first, but as expressions get longer and more cumbersome, the tree representations become easier to handle. Now we have everything we need to differentiate the tree representation of our function $f(g(\textbf{x}))$. The $1$st derivative of $f$ is $f_\alpha g^\alpha_j$, which is a single branched tree</p>
<p align="center">
<img src="/media/2023/aod_5.svg" />
</p>
<p>I have drawn the new branch in red to emphasize it. Differentiating again yields $f_{\alpha \beta} g^\alpha_j g^\beta_k + f_{\alpha} g^\alpha_{jk}$, so</p>
<p align="center">
<img src="/media/2023/aod_6.svg" />
</p>
<p>What just happened? When differentiating $f_\alpha g^\alpha_j$, which in the literature is called an <em>elementary differential</em>, we applied the product rule and made two copies of $f_\alpha g^\alpha_j$. To the first copy we differentiated the $f_{\alpha}$ term, adding a new subscript $\beta$ and an extra $g^\beta_k$ branch to the root. To the second copy we differentiated the $g^\alpha_j$ term, raising it to a $2$nd order derivative, and thus extending the already existing $g^\alpha_j$ branch to the length-$2$ branch $g^\alpha_{jk}$.</p>
<p>We can easily see how this technique generalizes to higher order factors. We apply the product rule and make as many copies of our special labeled rooted tree as there are terms in the factor. To the first copy we add a branch corresponding to differentiating $f$, and to the remaining copies we extend each of the existing branches, one by one. Let’s apply this technique to differentiate again, either adding a new branch to the root or extending existing branches. This yields</p>
<p align="center">
<img src="/media/2023/aod_7.svg" />
</p>
<p>Now, noticing that the middle three trees are topologically the same, with permuted labels, we can rewrite this, but we need to strip the labels. This results in</p>
<p align="center">
<img src="/media/2023/aod_8.svg" />
</p>
<p>which corresponds to the expression $f_{\alpha \beta \gamma} g^\alpha_j g^\beta_k g^\gamma_\ell + 3 \cdot f_{\alpha \beta} g^\alpha_j g^\beta_{k\ell} + f_{\alpha} g^\alpha_{jk\ell}$ that we derived earlier! These new label-less trees are referred to as simply as <em>special rooted trees</em>. In maths-speak, a special rooted tree is an representative of the equivalence class of special labeled rooted trees.</p>
<h1 id="aside-where-does-that-3-come-from">Aside: Where does that 3 come from?</h1>
<p>That 3 we see popping up in front is the <em>cardinality</em> of the equivalence class: the total number of valid labelings of the tree. Without getting too distracted: for a labeling to be valid, labels need to increase moving away from the root, so</p>
<p align="center">
<img src="/media/2023/aod_9.svg" />
</p>
<p>is an invalid labeling, assuming we have chosen alphabetical ordering of labels. On the surface, it’s not very obvious why the coefficients that precede the elementary differentials in higher derivative expressions would naturally be the number of valid labelings. But staring at the diagram of how we differentiate special labeled rooted trees, we see that each row essentially generated all possible special rooted labeled trees. So all possible labelings of each special rooted labeled tree are enumerated. And hence these coefficients have a very beautiful origin.</p>
<p>For those with a background in combinatorics, you will probably be quick to realize that there is a bijection between special labeled rooted trees and partitions of sets. We can associate each of the following 4-node trees with partitions of the set $\{j, k, \ell\}$</p>
<p align="center">
<img src="/media/2023/aod_10.svg" />
</p>
<p>Each branch in the diagram is a grouping of letters into a subset. While each branch has to be ordered alphabetically from its root, there is only one such valid ordering, so the subset can just be left unordered. We could go deeper into partitions of sets, but Wikipedia is your friend here.</p>
<h1 id="back-to-differentiation">Back to differentiation</h1>
<p>I would say the tree representation is much easier to parse than the algebraïc representation, which, mind you, is still shorthand for
\begin{align}
\frac{\partial^3 f}{\partial y^\alpha \partial y^\beta \partial y^\gamma}\frac{\partial g^\alpha}{\partial x^j}\frac{\partial g^\beta}{\partial x^k}\frac{\partial g^\gamma}{\partial x^\ell} + 3\frac{\partial^2 f}{\partial y^\alpha \partial y^\beta}\frac{\partial g^\alpha}{\partial x^j}\frac{\partial^2 g^\beta}{\partial x^k \partial x^\ell} + \frac{\partial f}{\partial y^\alpha}\frac{\partial^3 g^\alpha}{\partial x^j \partial x^k \partial x^\ell}.
\end{align}
What would be the expression for the $5$th order derivative?</p>
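One way to answer this without drawing any trees is to lean on the bijection with set partitions: enumerate the partitions of the five labels and group them by block sizes. A sketch (my own code, not the author's):

```python
from itertools import combinations
from collections import Counter

def set_partitions(s):
    """Enumerate all partitions of the iterable s into non-empty blocks."""
    s = list(s)
    if not s:
        yield []
        return
    first, rest = s[0], s[1:]
    # The block containing `first` is {first} plus any subset of the rest.
    for k in range(len(rest) + 1):
        for others in combinations(rest, k):
            block = frozenset((first,) + others)
            remaining = [e for e in rest if e not in others]
            for part in set_partitions(remaining):
                yield [block] + part

def coefficients(labels):
    # Group partitions by their multiset of block sizes; the count of each
    # group is the coefficient of the corresponding elementary differential.
    return dict(Counter(tuple(sorted(len(b) for b in p))
                        for p in set_partitions(labels)))

# 3rd derivative (labels j, k, l): coefficients 1, 3, 1, Bell(3) = 5 in total.
assert coefficients("jkl") == {(1, 1, 1): 1, (1, 2): 3, (3,): 1}
# 4th derivative: 1, 6, 3, 4, 1, matching the expression derived earlier.
assert coefficients("jklm") == {(1, 1, 1, 1): 1, (1, 1, 2): 6, (2, 2): 3,
                                (1, 3): 4, (4,): 1}
# 5th derivative: seven elementary differentials, Bell(5) = 52 labelings.
fifth = coefficients("jklmn")
assert sum(fifth.values()) == 52
print(fifth)
```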
<p>So we can study higher order derivatives of compositions of functions via special rooted trees! This process of adding and extending branches can be applied recursively very easily and a list of the first few special rooted trees looks like</p>
<p align="center">
<img src="/media/2023/aod_11.svg" />
</p>
<p>The theory of rooted trees goes very deep. We have only considered the <em>special</em> variety, for which branching can only occur at the root node. People have gone on to define entire algebras over rooted trees, with operations such as multiplication and addition. This comes in handy when studying order conditions of Runge-Kutta solvers and renormalization in quantum field theory. I personally think this area is extremely beautiful, and I am all the happier to have a quick trick for deriving expressions for higher order derivatives of composed functions.</p>

Dual numbers (Daniel Worrall, 2021-08-09)

<h1 id="dual-numbers-i">Dual numbers I</h1>
<p><strong>TL;DR</strong>: There is a generalisation of the complex numbers where $i^2=0$ instead of $i^2=-1$. Functions extended to this <em>dual number</em> system have the curious property that we can read off their derivatives (at a point $x$) if we evaluate them at the dual number $x + i$. This has implications for automatic differentiation frameworks.</p>
<p>Writing this post was a real treat. It’s about <em>dual numbers</em>. Dual numbers are a bit strange, to say the least, and at first they seem like an abstract mathematical fancy, but as you will see they serve quite a useful purpose in the realm of automatic differentiation. We’re going to start by reviewing what we know about the complex numbers. It will turn out that by tweaking them a wee bit we end up with the dual numbers, which, as mentioned, have some strikingly elegant properties when it comes to evaluating derivatives on computation graphs.</p>
<h3 id="complex-numbers">Complex numbers</h3>
<p>The complex numbers $z \in \mathbb{C}$ are typically expressed in split real-imaginary form $z = a + ib$, where $a, b \in \mathbb{R}$ are real numbers and $i$ is the <em>imaginary unit</em>. In high school and as a young engineering undergraduate student these were the bane of my life. $i$ has this weird property that $i^2 = -1$. Apart from that though, the complex numbers seem to act just like the reals under algebraic manipulation, so if $z = a + ib$ and $y = c + i d$, then</p>
\[z y = (a + ib) (c + id) = ac + iad + ibc + \color{red}{i^2}bd = (ac - bd) + i(ad + bc).\]
<p>Now why did I have to learn to perform these mundane manipulations? Well the complex numbers have these beautiful geometric properties that connect them with the trigonometric functions. Since (periodic and analytic) functions can be expanded in a trigonometric basis, it turned out that we could study just about any function of interest in the complex domain and usually it was simpler to do so.</p>
<h3 id="hypercomplex-numbers-complex-double-and-dual">Hypercomplex numbers: complex, double, and dual</h3>
<p>But do we necessarily have to demand that $i^2 = -1$? Well, no. In fact, allowing $i^2$ to equal other values opens up a garden of delights. If we set $i^2 = 1$ we have the <em>double numbers</em>, also known as the <em>split-complex numbers</em>, and if we set $i^2 = 0$ (making sure that $i \neq 0$), then we have the <em>dual numbers</em>. It turns out that all three number systems, for $i^2=-1$, $i^2=1$, and $i^2=0$, are cases of <a href="https://en.wikipedia.org/wiki/Hypercomplex_numbers">hypercomplex numbers</a>. The multiplication above, redone for dual numbers, is</p>
\[z y = (a + ib) (c + id) = ac + iad + ibc + \color{red}{i^2}bd = ac + i(ad + bc).\]
<p>The term $i^2bd$ falls away since $i^2=0$, as we defined. A mathematical object for which $\underbrace{ii \cdots i}_{k \text{ times}} = 0$ (with no lower power vanishing) is called <em>nilpotent</em> of degree $k$.</p>
<h3 id="taylor-expansions-and-exact-linearisation">Taylor expansions and exact linearisation</h3>
<p>Let’s do some computations and see why dual numbers are useful. We are going to take a function $f$ defined on the real domain and stick dual numbers $z = x + iy$ into it. This may seem cowboyish, and it is, but it will lead us somewhere very satisfying. Now how do we evaluate a function at a point $x + iy$? Well, if the function is <a href="https://en.wikipedia.org/wiki/Analytic_function">analytic</a> (it locally equals its own Taylor series, a condition a touch stronger than smoothness), then we can just use a Taylor expansion, so</p>
\[\begin{aligned}
f(x + iy) &= f(x) + iy f'(x) + \underbrace{\frac{(iy)^2}{2!}}_{=0} f''(x) + \underbrace{\frac{(iy)^3}{3!}}_{=0} f'''(x) + \cdots \newline
&= f(x) + iy f'(x).
\end{aligned}\]
<p>We dropped the second-order and higher-order terms because they all contain factors of $i^2$, which we have defined to be zero. This is marvellous! By evaluating $f$ at the point $x + iy$, we recover its exact linearisation. No need for big-$\mathcal{O}$ notation and hand-waving about $iy$ being ‘small enough’. Furthermore, if we set $y=1$, then we can read off the derivative of $f$ as the dual component (analogous to the imaginary component) of $f(x + i)$.</p>
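This property is easy to play with in code. Here is a minimal sketch of dual numbers as a Python class with overloaded arithmetic (my own illustration; the names <code class="language-plaintext highlighter-rouge">Dual</code> and <code class="language-plaintext highlighter-rouge">dsin</code> are hypothetical, not from any library):

```python
import math

class Dual:
    """A dual number x + i*y with i**2 = 0 (a minimal sketch, not a full type)."""
    def __init__(self, x, y=0.0):
        self.x, self.y = x, y

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.x + other.x, self.y + other.y)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # (a + ib)(c + id) = ac + i(ad + bc), since the i^2 term vanishes.
        return Dual(self.x * other.x, self.x * other.y + self.y * other.x)
    __rmul__ = __mul__

def dsin(z):
    # Lift sin to dual numbers via its Taylor expansion: sin(x + iy) = sin(x) + iy cos(x).
    return Dual(math.sin(z.x), z.y * math.cos(z.x))

# Evaluate f(x) = x^2 sin(x) at the dual point x + i to read off f'(x).
x = 1.3
w = Dual(x, 1.0) * Dual(x, 1.0) * dsin(Dual(x, 1.0))
assert abs(w.x - x**2 * math.sin(x)) < 1e-12                         # primal: f(x)
assert abs(w.y - (2*x*math.sin(x) + x**2*math.cos(x))) < 1e-12       # tangent: f'(x)
```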
<p>Regarding terminology, for a dual number $x+iy$ it is common to call $x$ the <em>primal</em>, since it represents the primary component of the computation; $y$ is the <em>tangent</em>, a nod to the fact that it represents a derivative, which lives in a tangent space; and $i$ is the <em>tag</em>, which is an odd name, but it will make sense in a following post when I discuss higher-order derivatives.</p>
<h3 id="computation-graphs-and-the-chain-rule">Computation graphs and the chain rule</h3>
<p>In modern machine learning, we like to build composable functions and optimise all the parameters using automatic differentiation. Automatic differentiation is just a souped-up version of the chain rule. Let’s see how dual numbers pair with the chain rule. First of all a recap of the chain rule:</p>
\[\frac{\mathrm{d}}{\mathrm{d} x} f(g(x)) = \color{red}{f'(g(x)) g'(x)}.\]
<p>Now with dual numbers</p>
\[f(g(x + i)) = f(g(x) + ig'(x)) = f(g(x)) + i\color{red}{f'(g(x)) g'(x)}.\]
<p>So we see indeed that the tangent component of $f(g(x + i))$ is indeed the correct derivative, had we used the chain rule! How would we code this? How do we even represent dual numbers?</p>
<h3 id="automagic-dual-number-based-differentation">Automagic dual number-based differentation</h3>
<p>What do we need to implement dual number-based automatic differentiation? First we need a dictionary of composable atomic functions $f_1, f_2, f_3, …$, typically called <em>primitives</em>, from which we can build a computation graph. All deep learning libraries contain them. For instance, think of PyTorch’s <code class="language-plaintext highlighter-rouge">torch.nn.functional.relu()</code>. Next we are going to require all of their derivative functions. Just like in standard backprop, we always have these at hand. Typically we may be used to defining a function with separate <code class="language-plaintext highlighter-rouge">forward()</code> and <code class="language-plaintext highlighter-rouge">backward()</code> methods for the evaluation and the derivative separately. For instance, consider the tangent function $f(x) = \tan(x)$, which has derivative $f’(x) = 1 + \tan^2(x)$:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">Tan</span><span class="p">:</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">+</span> <span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
</code></pre></div></div>
<p>Differentiation by dual numbers works by overloading the arguments to a function $f$, using dual numbers instead of real ones. One route is operator overloading, where the dual arithmetic is baked into a number type. The alternative, which we use here, is to define a separate function <code class="language-plaintext highlighter-rouge">dtan()</code> with the desired properties; this method is called <em>source code transformation</em>. We can represent the dual number $z = x + iy$ as a tuple $(x, y)$. Then</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">dtan</span><span class="p">(</span><span class="n">z</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">z</span>
<span class="k">return</span> <span class="n">Tan</span><span class="p">().</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span> <span class="o">*</span> <span class="n">Tan</span><span class="p">().</span><span class="n">backward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div></div>
<p>That’s it. At the start of your computational graph, just specify $z=(x,1)$ and away you go! A more modular way to implement this would be to specify a dual method that can operate on any function, not just <code class="language-plaintext highlighter-rouge">Tan()</code>. This would have form</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">dual</span><span class="p">(</span><span class="n">primitive</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">df</span><span class="p">(</span><span class="n">z</span><span class="p">):</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">z</span>
<span class="k">return</span> <span class="n">primitive</span><span class="p">.</span><span class="n">forward</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span> <span class="o">*</span> <span class="n">primitive</span><span class="p">.</span><span class="n">backward</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">df</span>
</code></pre></div></div>
<p>Then <code class="language-plaintext highlighter-rouge">dtan</code> is equivalent to <code class="language-plaintext highlighter-rouge">dual(Tan())</code>. For instance</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">dtan</span> <span class="o">=</span> <span class="n">dual</span><span class="p">(</span><span class="n">Tan</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">dtan</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
<span class="o">>>></span> <span class="p">(</span><span class="o">-</span><span class="mf">2.185039863261519</span><span class="p">,</span> <span class="mf">5.774399204041917</span><span class="p">)</span>
<span class="k">print</span><span class="p">((</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="o">+</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">))</span>
<span class="o">>>></span> <span class="p">(</span><span class="o">-</span><span class="mf">2.185039863261519</span><span class="p">,</span> <span class="mf">5.774399204041917</span><span class="p">)</span>
</code></pre></div></div>
<p>Now the real test is to check whether composition works. Let’s take the derivative of $f(x) = \tan(\tan(x))$, where</p>
\[\frac{\mathrm{d} f}{\mathrm{d} x} = (1 + \tan^2(\tan(x)))\cdot(1 + \tan^2(x)).\]
<p>Evaluating this at $x=2$, in code this is</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">dtan</span><span class="p">(</span><span class="n">dtan</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))))</span>
<span class="o">>>></span> <span class="p">(</span><span class="mf">1.417928575505387</span><span class="p">,</span> <span class="mf">17.383952637114582</span><span class="p">)</span>
<span class="k">print</span><span class="p">((</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="mi">2</span><span class="p">)),</span> <span class="p">(</span><span class="mi">1</span><span class="o">+</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="mi">2</span><span class="p">))</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="mi">1</span><span class="o">+</span><span class="n">np</span><span class="p">.</span><span class="n">tan</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">)))</span>
<span class="o">>>></span> <span class="p">(</span><span class="mf">1.417928575505387</span><span class="p">,</span> <span class="mf">17.383952637114582</span><span class="p">)</span>
</code></pre></div></div>
<p>It really is amazing how embarrassingly simple this technique is. Notice how the order of gradient call executions is the same as the order of the forward pass. As such, this is an incarnation of so-called <em>forward-mode differentiation</em>, to be contrasted with <em>reverse-mode differentiation</em>, which machine learners tend to call <em>backprop</em>.</p>
<p>Why don’t deep learners use this method then, if it’s so simple? The answer is that for a scalar loss with high-dimensional inputs, a single forward pass only propagates one tangent direction; accumulating the full gradient requires one pass per input dimension, or equivalently carrying a prohibitively large Jacobian, whereas reverse mode backpropagates the entire gradient vector in a single pass.</p>
<h3 id="next-steps">Next steps</h3>
<p>The method above is sweet and simple, but it has some failings when we wish to do slightly more complicated things, such as evaluate higher-order derivatives. In the next post, I’ll show a slightly more sophisticated way to implement dual number-based differentiation, which side-steps these issues.</p>
<h3 id="references">References</h3>
<ul>
<li>https://www.mdpi.com/2075-1680/8/3/77/html</li>
<li>http://blog.jliszka.org/2013/10/24/exact-numeric-nth-derivatives.html</li>
<li>https://en.wikipedia.org/wiki/Dual_number</li>
<li>https://encyclopediaofmath.org/wiki/Double_and_dual_numbers</li>
<li>https://en.wikipedia.org/wiki/Automatic_differentiation#Automatic_differentiation_using_dual_numbers</li>
</ul>

Reversible optimisers (Daniel Worrall, 2020-12-20)

<p>This post touches on a curious property of some common optimisers used by the machine learning community: <em>reversibility</em>.</p>
<p>I tend to hate reading through lengthy introductions, so let’s just dive in with an example. Take gradient descent with momentum, this has the following form
\begin{align}
\mu_{t+1} &= \alpha \mu_t + \nabla_{x} f(x_{t}) \newline
x_{t+1} &= x_t - \lambda \mu_{t+1}.
\end{align}
Here $x_t$ denotes the optimisation variable, or <em>position</em>, $x$ at time $t$, $\mu$ is the associated <em>momentum</em>, and $0 < \alpha < 1$ and $\lambda > 0$ are metaparameters, which govern the dynamics of the descent trajectory. I use the term <em>meta</em>parameters, instead of <em>hyper</em>parameters, to distinguish that they are part of the optimiser and not the model, even though some would nowadays say that the optimiser is in fact part of the model, implicitly regularising it.</p>
<p>Anyway, interestingly we can reverse these equations, given the state $[x_{t+1}, \mu_{t+1}]$ as
\begin{align}
x_t &= x_{t+1} + \lambda \mu_{t+1} \newline
\mu_{t} &= \frac{1}{\alpha} \left ( \mu_{t+1} - \nabla_{x} f(x_{t}) \right).
\end{align}
This seemingly arbitrary property is useful from a practical standpoint.</p>
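A quick sketch verifying this reversibility numerically (the values of $\alpha$, $\lambda$, and the objective $f(x) = \frac{1}{2}\|x\|^2$ are my own assumptions for illustration):

```python
import numpy as np

grad = lambda x: x                  # gradient of f(x) = 0.5 * ||x||^2
alpha, lam, steps = 0.9, 0.1, 20    # assumed metaparameter values

x, mu = np.array([1.0, -2.0]), np.zeros(2)
x0, mu0 = x.copy(), mu.copy()

for _ in range(steps):              # forward: momentum SGD
    mu = alpha * mu + grad(x)
    x = x - lam * mu

for _ in range(steps):              # reverse: reconstruct from the end state alone
    x = x + lam * mu
    mu = (mu - grad(x)) / alpha

assert np.allclose(x, x0) and np.allclose(mu, mu0)
```

Note that in floating point the division by $\alpha$ amplifies rounding error on every reversed step, which is exactly the numerical-stability issue discussed later in this post.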
<h3 id="memory-efficiency">Memory efficiency</h3>
<p>An oft-lauded property of reversible systems is that we do not have to store intermediate computations, since they can be reconstructed from the system’s end-state. Typically, for reverse-mode differentiation (i.e. backpropagation) to work, we have to store all the intermediate activations from the forward pass of a network. The memory complexity therefore scales linearly with the size of the computation graph. If we can dynamically reconstruct intermediate activations during the backward pass, then we instantly convert this linear memory complexity into a constant, which (in theory) enables us to build infinitely deep networks.</p>
<h3 id="momentum-is-additive-coupling">Momentum is additive coupling</h3>
<p>Indeed, if you look a little closer at the momentum equations, then you may spot that they resemble an <a href="https://arxiv.org/pdf/1410.8516.pdf">additive coupling layer</a>. Here we have that a state, split into two parts $x$ and $\mu$ (to mimic the momentum optimiser notation), is reversible with the following computation graph
\begin{align}
\mu_{t+1} &= \mu_t + g(x_t) \newline
x_{t+1} &= x_t + h(\mu_{t+1})
\end{align}
To make a direct comparison, $g(x) = \nabla_x f(x)$ and $h(\mu) = -\lambda \mu$. The one slight discrepancy is the factor of $\alpha$, but we can sweep that under the rug. The reverse equations for the additive coupling layer are
\begin{align}
x_{t} &= x_{t+1} - h(\mu_{t+1}) \newline
\mu_{t} &= \mu_{t+1} - g(x_t).
\end{align}</p>
<div style="text-align:center"><img src="/images/coupling.png" width="50%" /></div>
<p><em>Source: <a href="https://arxiv.org/pdf/1902.02729.pdf">Reversible GANs for Memory-efficient Image-to-Image Translation</a>. This diagramme represents the additive coupling layer in its computation graph form. LEFT: forward pass. RIGHT: reverse pass. To link up the notation $x_1 = \mu_{t}$, $x_2 = x_{t}$, $y_1 = \mu_{t+1}$, $y_2 = x_{t+1}$, $g = \texttt{NN}_1$, and $h=\texttt{NN}_2$</em></p>
<h3 id="case-study">Case study</h3>
<p>Specifically in the case of optimisers, I was pointed towards this paper <a href="https://arxiv.org/pdf/1502.03492.pdf">Gradient-based Hyperparameter Optimization with Reversible Learning</a> (2015) by <a href="https://dougalmaclaurin.com/">Dougal Maclaurin</a>, <a href="http://www.cs.toronto.edu/~duvenaud/">David Duvenaud</a>, and <a href="https://www.cs.princeton.edu/~rpa/">Ryan Adams</a>. The authors exploited the reversibility property of SGD with momentum to train the optimiser metaparameters themselves. First they run the optimiser for an arbitrary number of steps, say 100 iterations. This defines an optimisation trajectory $x_0, x_1, x_2, …, x_{99}$. Now the clever part is that you can view the unrolled optimisation trajectory as a computation graph in itself. They compute a loss at the end of the trajectory and then backpropagate it through the trajectory in reverse to obtain gradients with respect to the optimiser’s metaparameters.</p>
<div style="text-align:center"><img src="/images/reversibility.png" width="50%" /></div>
<p><em>Source: <a href="https://arxiv.org/pdf/1502.03492.pdf">Gradient-based Hyperparameter Optimization with Reversible Learning</a>. The authors optimise metaparameters by backpropagating along optimisation roll outs. This is made possible with the reversibility of momentum-based SGD, to cap memory-complexity.</em></p>
<p>Could we not do this already, such as in <a href="https://arxiv.org/abs/1606.04474">Learning to learn by gradient descent by gradient descent</a> (Andrychowicz et al., 2016)? Well yes, but the crucial point is that you would usually have to store all the intermediate states $\{[x_t, \mu_t]\}_{t=0}^{99}$, which is costly memory-wise. Exploiting the reversibility property, this memory explosion falls away. Indeed there are issues with numerical stability of the inverse, which the paper dives into, but the principle is elegant.</p>
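<p>To make the memory argument concrete, here is a toy sketch (plain Python, my own names, not the authors’ code): run momentum SGD for 100 steps on a quadratic, keep only the final $(x, \mu)$ pair, and then walk the whole trajectory backwards, reconstructing every earlier state with constant memory:</p>

```python
ALPHA, LAM, STEPS = 0.9, 0.01, 100
grad = lambda x: 2.0 * x   # gradient of f(x) = x^2

def forward(x, mu, steps):
    for _ in range(steps):            # note: no per-step state is stored
        mu = ALPHA * mu + grad(x)
        x = x - LAM * mu
    return x, mu

def backward(x, mu, steps):
    for _ in range(steps):            # reconstruct earlier states on the fly
        x = x + LAM * mu
        mu = (mu - grad(x)) / ALPHA   # blows up if ALPHA == 0
    return x, mu

xT, muT = forward(5.0, 0.0, STEPS)
x0, mu0 = backward(xT, muT, STEPS)
# Recovery is exact only in infinite precision: each division by ALPHA < 1
# amplifies rounding error, hence the numerical stability issues.
assert abs(x0 - 5.0) < 1e-6 and abs(mu0) < 1e-6
```

<p>Note the backward loop also shows why $\alpha = 0$ (plain gradient descent) is not reversible: the division by $\alpha$ is undefined, and the information held in $\mu$ is discarded at each step.</p>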
<h3 id="adam">Adam</h3>
<p>So what other optimisers are reversible? Let’s consider <a href="https://arxiv.org/pdf/1412.6980.pdf">Adam</a>, where
\begin{align}
\mu_{t+1} &= \beta_1 \mu_t + (1-\beta_1) \nabla_{x} f(x_{t}) \newline
\nu_{t+1} &= \beta_2 \nu_t + (1-\beta_2) (\nabla_{x} f(x_{t}))^2 \newline
x_{t+1} &= x_t - \lambda \frac{\mu_{t+1}}{\sqrt{\nu_{t+1}} + \epsilon}.
\end{align}
Given $x_{t+1}$, $\mu_{t+1}$ and $\nu_{t+1}$, we can easily reconstruct $x_t$ from the last line and from there, we can compute the gradient and recover $\mu_{t}$ and $\nu_{t}$. In maths
\begin{align}
x_{t} &= x_{t+1} + \lambda \frac{\mu_{t+1}}{\sqrt{\nu_{t+1}} + \epsilon} \newline
\mu_{t} &= \frac{1}{\beta_1} \left ( \mu_{t+1} - (1-\beta_1) \nabla_{x} f(x_{t}) \right ) \newline
\nu_{t} &= \frac{1}{\beta_2} \left ( \nu_{t+1} - (1-\beta_2) (\nabla_{x} f(x_{t}))^2 \right).
\end{align}
So Adam is reversible. We actually missed out the bias correction steps
\begin{align}
\mu_{t+1} &\gets \mu_{t+1} / (1 - \beta_1^{t+1}) \newline
\nu_{t+1} &\gets \nu_{t+1} / (1 - \beta_2^{t+1}).
\end{align}
You can verify for yourself that these are reversible too: to undo them, simply multiply back by $(1 - \beta_1^{t+1})$ and $(1 - \beta_2^{t+1})$.</p>
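<p>A hedged sketch of the above in Python (my own function names, a single scalar parameter, and bias correction omitted to match the equations):</p>

```python
import math

B1, B2, LAM, EPS = 0.9, 0.999, 0.001, 1e-8
grad = lambda x: 2.0 * x   # gradient of f(x) = x^2

def adam_step(x, mu, nu):
    g = grad(x)
    mu = B1 * mu + (1 - B1) * g
    nu = B2 * nu + (1 - B2) * g * g
    x = x - LAM * mu / (math.sqrt(nu) + EPS)
    return x, mu, nu

def adam_reverse(x, mu, nu):
    # First undo the x update, then recompute the gradient at the
    # recovered point to peel mu and nu back one step each.
    x_prev = x + LAM * mu / (math.sqrt(nu) + EPS)
    g = grad(x_prev)
    mu_prev = (mu - (1 - B1) * g) / B1
    nu_prev = (nu - (1 - B2) * g * g) / B2
    return x_prev, mu_prev, nu_prev

x1, mu1, nu1 = adam_step(2.0, 0.5, 0.25)
x0, mu0, nu0 = adam_reverse(x1, mu1, nu1)
assert math.isclose(x0, 2.0) and math.isclose(mu0, 0.5) and math.isclose(nu0, 0.25)
```

<p>As with momentum, the divisions by $\beta_1$ and $\beta_2$ in the reverse pass mildly amplify floating-point error, so repeated application drifts in finite precision.</p>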
<h3 id="do-we-need-reversibility-in-optimisers">Do we need reversibility in optimisers?</h3>
<p>Well, no. In fact, in some ways, we would rather do without it. Optimisers are supposed to be many-to-one mappings. Starting from an infinity of initial conditions, we should converge to the global minimum of a convex function. This means we should discard information about initialisation along the way. To put it as Maclaurin et al. do:</p>
<blockquote>
<p>[O]ptimization moves a system from a high-entropy initial state to a low-entropy (hopefully zero entropy) optimized final state.</p>
</blockquote>
<p>It turns out that if you set $\alpha = 0$ in the momentum method (that is, if you just run plain gradient descent), then the updates are not reversible. I think this may also be true for <a href="https://www.cs.toronto.edu/~fritz/absps/momentum.pdf">Nesterov accelerated momentum</a>, and <a href="http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf">RMSProp</a>, which I couldn’t make reversible (I call this <em>proof by fatigue</em>). So I’m left wondering: is reversibility just some curious extra property that can be useful sometimes, but is completely arbitrary when it comes to doing optimisation? Or is there some deeper meaning to it? Is it just some artifact of how we think of optimisation, in terms of balls rolling down hills? Maybe more interestingly, what does the lack of reversibility for standard gradient descent and Nesterov entail? Could this be another reason why Nesterov works better than classical momentum? Could we measure the information loss somehow? And if we could, what would this mean?</p>
<h1 id="on-the-invention-of-randomness">On the ‘invention’ of randomness (2019-12-15)</h1>
<p><img src="/media/jaynes-himself.jpg" alt="The legend himself" height="25%" width="25%" style="float: right;margin-left: 20px;margin-top: 7px;" /></p>
<p>Recently in AMLAB we started a Jaynes reading group. <a href="https://en.wikipedia.org/wiki/Edwin_Thompson_Jaynes">E T Jaynes’</a> posthumous book and general all-round cult classic <em>Probability Theory: The Logic of Science</em> is the focus of our study. After having lectured a Bayesian statistics course for the last two years, I felt fairly confident in my understanding of the subject matter. It seems that while I am at ease with performing computations, I am far from grasping Bayesianism from a conceptual, and some might say doctrinal, standpoint. And after a couple of conversations with others in a similar situation to me, it seems I may not be alone.</p>
<p>Now this book is littered with gems and Jaynes’ colourful writing style is literary gold, but I want to focus on a small snippet, which got me thinking hard about my understanding of what it means to be <em>random</em>. <strong>Jaynes essentially claims that randomness simply does not exist</strong>. It is a human invention; the claim can be found in Digression 3.8.1 of the book. Let’s step through the logic.</p>
<h1 id="my-truth-versus-your-truth">My truth versus your truth</h1>
<p>In true statistical tradition, we are going to consider coin tosses. Let’s assume that randomness does exist. What do I mean by randomness? I mean that when I throw the coin in the air it will land heads or tails in a 100% unpredictable fashion. Some intrinsically indeterminate process will drive the coin to come to rest in a state, independent of how it left my fingers. In this unpredictable world an observer would be unable to make any sure judgements about the outcome. She may assume the prior probability of the coin landing heads or tails is 50%. This is not some weird quantum-y line of reasoning: the coin will land either heads or tails, but we just cannot say beforehand. Some would argue that maintaining a uniform probability distribution over the outcome of the coin toss is really the best she can do. It is a reflection of the truth of the world she lives in.</p>
<p>Let’s now pretend that the world is 100% deterministic. Say I flip a coin and I happen to know its mass, moment of angular inertia, air resistance, initial momentum, etc. In this perfect world, an observer gifted with this knowledge would be able to predict with 100% accuracy whether it lands heads or tails. Even if the computations are intractable and we are reduced to brute-forcing over possible futures, we could at least agree that the outcome is calculable in principle. No randomness here. Now let’s add a twist. Let’s say that we do not tell the observer any of this privileged information about the physics of the coin-toss. In that case, the observer would be reduced to educated guesses about the outcome of the toss. Because she’s a good Bayesian, she will continue to assume the prior probability of the coin landing heads or tails is 50%, but now her motivations are different. Or are they?</p>
<p>Maintaining a probability distribution over the outcome of the coin toss is not just a cynically non-committal statement. It is a full and honest acknowledgement of her ignorance. She has not claimed that the world itself is random in any way—in fact she very well knows it’s deterministic—but she is merely asserting that its physical parameters are unknown to her. Whichever way the coin lands, the outcome will be a surprise to her, because she is unable to predict it with her imperfect knowledge.</p>
<p>This seemed a bit weird to me. In moving from the synthetically random world to the deterministic one, it would seem that modelled randomness is really just a statement of observer ignorance, rather than any innate property of the universe around us. The fascinating point for me is that from the point-of-view of the observer, <em>it does not matter whichever world she lives in (the inherently random or the inherently deterministic), in both cases she is forced to use the same mathematical reasoning!</em> To make a linguistic distinction between the intrinsic randomness of the outside world and the observer’s perceived randomness, we will term them <em>intrinsic randomness</em> and <em>epistemic randomness</em>. Epistemic randomness is the uncertainty I have over the outside world because I simply do not have enough information. It seems that epistemic randomness is uncontroversial. Intrinsic randomness on the other hand is this very much controversial concept that Nature itself is indeterminate and unpredictable.</p>
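<p>The coin-flipping thought experiment fits in a few lines of toy code (entirely my own construction, not from Jaynes): a coin whose outcome is a pure function of its initial spin. The informed observer predicts with certainty; the ignorant observer can honestly do no better than a uniform prior, even though nothing in this world is intrinsically random:</p>

```python
def coin_outcome(initial_half_turns):
    # A fully deterministic 'coin': heads iff the spin count is even.
    return "heads" if initial_half_turns % 2 == 0 else "tails"

# The informed observer knows the physics and the initial condition,
# so she predicts perfectly (epistemic certainty):
assert coin_outcome(42) == "heads"
assert coin_outcome(17) == "tails"

# The ignorant observer only knows that spins are equally often odd or
# even, so her honest forecast is 50/50 (epistemic randomness):
spins = range(1000)
p_heads = sum(coin_outcome(s) == "heads" for s in spins) / len(spins)
assert p_heads == 0.5
```

<p>The 50/50 figure here is purely a statement about the second observer’s information, not about the coin.</p>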
<p>Now this kind of thinking leads to a vast array of new questions. Could an observer ever determine whether the world she lives in is intrinsically random or just deterministic, given that her tools to analyse both are the same? If intrinsic randomness does not exist, is it then some kind of invention? What does this mean for the interpretation of stochastically derived quantities such as aleatoric and epistemic uncertainty? What about free will? (This last question is a little cliché, but people seem to like talking about free will). <strong>I plan to devote a follow up blog to aleatoric and epistemic uncertainties.</strong></p>
<h1 id="my-gripe-with-jaynes">My gripe with Jaynes</h1>
<p>Those acquainted with Jaynes’ writings will be all too familiar with his grandiloquent rhetorical style. He really seems to believe that intrinsic randomness does not exist, that it is a sort of what he calls <em>mind projection fallacy</em>.</p>
<blockquote>
<p>The belief that ‘randomness’ is some kind of real property existing in Nature is a form of the mind projection fallacy which says, in effect, ‘I don’t know the detailed causes – therefore – Nature does not know them.’</p>
</blockquote>
<p>Of course <em>he</em> would put (intrinsic) <em>randomness</em> in quotation marks. Mind projection fallacies, a very Jaynesian invention, are assertions wherein an observer states that how they see the world really is reality. My reality is your reality and everyone else’s too; that someone should disagree on the nature of Nature itself is simple idiocy. Now my gripe with Jaynes’ stance is that he himself appears to be stuck in a mind projection fallacy of his own. His assertion that intrinsic randomness does not exist is a projection of his reality onto the reader. Just for fun here is another, and I daresay somewhat salacious, Jaynes quote:</p>
<blockquote>
<p>For some, declaring a problem to be ‘randomized’ is an incantation with the same purpose and effect as those uttered by an exorcist to drive out evil spirits; i.e. it cleanses their subsequent calculations and renders them immune to criticism. We agnostics often envy the True Believer, who thus acquires so easily that sense of security which is forever denied to us.</p>
</blockquote>
<h3 id="quantum-weirdness">Quantum weirdness</h3>
<p><strong>Disclaimer: in this next bit I talk about physics, but be under no illusions, I am far from knowledgeable on this subject.</strong></p>
<p>So is the world deterministic? As far as I know, the main source of randomness in the world arises from the depths of the sub-atomic quantum world. The quantum world is very strange in that very small objects, called particles, such as electrons and neutrons, are described probabilistically and not using our everyday Newtonian world-view. Particles are seen to exhibit <em>wave-particle duality</em>—a behaviour where they act like waves, being able to be in multiple locations at once—until they are observed, at which point we witness a bizarre effect called <em>wavefunction collapse</em>, where they assume a definite location in space. Wavefunction collapse has always been an incomprehensible phenomenon to me. Why should the act of observation change the nature of the underlying physics?</p>
<p>The central apparatus for modelling particles is a <em>wavefunction</em>, a function extending over all space and time. The squared modulus of the wavefunction is equal to the probability that a particle will be observed at a specific location and time if a measurement is taken. One of the big problems in quantum mechanics is understanding how to interpret the wavefunction. Is it a fundamental physical object? If so, it is very strange indeed, since in the classical setting where we model big objects, we never observe an object to be in two places at once. In the quantum world, this happens by necessity.</p>
<p>One school of thinking, in fact the one I learnt in secondary school (that’s high school for everyone else), is that the wavefunction is a fundamental object of nature. It is part of reality. This is an unsettling idea, but if true it would indicate that our world is very weird indeed. Einstein famously rejected this interpretation claiming “God does not play dice”. This is the (in)famous <em>Copenhagen interpretation</em> of quantum mechanics, which as it turns out is not universally accepted or rejected by all physicists. Some detractors indeed do subscribe to deterministic <em>hidden variable theories</em>, as I believe did Jaynes, in which random quantum effects are just disturbances caused by so-called hidden variables, unobserved by us humans (so far). But I think that one flavour of hidden variable theories, called local hidden variable theories, has been ruled out already by what is known as the Bell test experiments. A much more promising route is <em>QBism</em>, pronounced <em>cubism</em> like the artistic movement, championed by Christopher Fuchs. QBism was originally called Quantum Bayesianism, but apparently many hardcore Bayesians have pointed out it is not strictly Bayesian and QBism sounds cooler anyway. In this theory particles occupy one position at any one time, the wavefunction is just an expression of observer ignorance, and wavefunction collapse is analogous to making a measurement and updating our beliefs. Mystery resolved…if it’s true.</p>
<p>So whether intrinsic randomness exists is an open and potentially unanswerable question. Even to determine whether intrinsic randomness is possible would require us to step outside the rôle of observer—a purely unscientific practice. All observers are bound to make world inferences via the methods of epistemic randomness and are thus at risk of succumbing to Jaynesian mind projection fallacies. Some might go as far as to say that if something is unverifiable by experiment it cannot <em>exist</em>. Enter die-hard scientific methodists (theological pun intended). On a side note, the question of existence seems to be an argument I get into a lot nowadays with post-structuralists from the arts and humanities, who always tell me that scientists are mistaken in believing that what they study exists.</p>
<h2 id="inevitably-after-reading-a-story-about-flipping-coins-i-seem-to-have-found-myself-in-quite-the-bind-if-i-have-come-to-understand-anything-it-is-that-you-can-never-be-safe-in-thinking-you-know-anything">Inevitably after reading a story about flipping coins, I seem to have found myself in quite the bind. If I have come to understand anything, it is that you can never be safe in thinking you know anything.</h2>