This repository contains JAX code accompanying my internal talk titled "Representation Theory and
$SO(3)$ in GNNs".
**Note:** This is intended to be a not-too-technical introduction to representation theory, and only assumes basic linear algebra.
**Warning:** When reading up on representation theory for
Groups in mathematics can be quite abstract, so let’s just relate the rules for something to be a “group” to familiar rotations.
Closure. We are happy with the concept that if we rotate a real-life object around an axis, then rotate it around a different axis, we could have gotten the same end result through a single rotation about a single axis. [Proof: grab a nearby object - it will be useful for this entire section.] That is, composing two rotations gives another rotation. So if I construct a set that includes all possible rotations (and you forgive my loose language since there are uncountably many rotations), then under the binary operation of composition, two elements in my set give another element in my set.
Identity. Another intuitive fact is that there is only one rotation that does nothing: the “don’t bother” operation. [Proof: try thinking of another one.]
Inverse. We are also happy that any rotation that we do has an inverse rotation that undoes that rotation, yielding the “don't bother” operation. [Proof: play the footage of you rotating said object in reverse.]
Associativity. With these observations, we can almost say that what I have described above is a group, since a group requires a set of elements together with a binary operation that is closed, has an identity, and has inverses. The one remaining requirement is associativity, and composition of rotations gets this for free: rotating by $a$, then $b$, then $c$ gives the same result whether you think of it as ($a$ then $b$) followed by $c$, or $a$ followed by ($b$ then $c$).
A group is specified by the set of elements and a table specifying the result of applying the binary operation on every ordered pair of elements.
Note that one thing a group does not require is commutativity, meaning that the order in which we compose two elements can matter: rotating an object about one axis and then another generally gives a different result from doing the two rotations in the opposite order.
The name of the group I have been describing is $SO(3)$: the group of rotations in three-dimensional space.
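Although matrices only enter the story in the next paragraph, it is easy to check these axioms numerically. Here is a quick JAX sketch using $3 \times 3$ rotation matrices (the helper names `rot_z` and `rot_x` are just for illustration):

```python
import jax.numpy as jnp

def rot_z(theta):
    """Anti-clockwise rotation by theta about the z-axis."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(theta):
    """Anti-clockwise rotation by theta about the x-axis."""
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

a, b = rot_z(0.3), rot_x(1.1)
ab = a @ b

# Closure: the composition is again a rotation (orthogonal, determinant +1).
print(jnp.allclose(ab @ ab.T, jnp.eye(3), atol=1e-6), jnp.allclose(jnp.linalg.det(ab), 1.0))

# Identity: the "don't bother" rotation is the identity matrix.
print(jnp.allclose(a @ jnp.eye(3), a))

# Inverse: undoing a rotation is just its transpose (the footage played backwards).
print(jnp.allclose(a @ a.T, jnp.eye(3), atol=1e-6))

# And no commutativity in general: z-then-x differs from x-then-z.
print(jnp.allclose(a @ b, b @ a))  # False
```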
It turns out that one can satisfy the specification of any group by replacing group elements with certain (non-singular) matrices, and letting the binary operation be matrix multiplication.
The mapping from the set of elements to matrices is called a representation. (A point about language worth repeating to avoid confusion: the representation is the entire mapping from group elements to matrices - a single matrix is not a “representation” of a single group element, even though that sounds like a reasonable English statement.)
For
We have not yet specified the size of these matrices, because that depends on the exact details of the group at hand.
In the land of representations, an important divide exists between those that are reducible, and the privileged few that are irreducible.
You may have noticed from the definitions above that if you have a representation, it is easy to construct another representation by stacking matrices along the block diagonal. For example, the representation

$$D(g) = \begin{pmatrix} D_1(g) & 0 \\ 0 & D_2(g) \end{pmatrix}$$

will satisfy the right group relations, provided $D_1$ and $D_2$ are themselves representations.
A representation is reducible if the matrices of all group elements can simultaneously be brought to block-diagonal form through a basis change. (Simultaneously is an important word here, because the basis change has to be consistent across all group elements - otherwise each matrix in a representation could get its own basis change and we would have a useless definition.) Predictably, irreducible representations are representations that are not reducible: in some ways they are like “atoms” of a group. (One could still have two different-looking irreps that are actually related by a change of basis, but we’d still call them both irreps. Later on, we will avoid this confusion by picking a physically-motivated basis.)
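Here is a tiny JAX illustration of the block-diagonal construction (a sketch, using two copies of the familiar $3 \times 3$ rotation representation as the blocks):

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

def rot_z(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

# A (very) reducible 6x6 representation: two copies of the 3x3 one on the block diagonal.
def big_rep(R):
    return block_diag(R, R)

A, B = rot_z(0.4), rot_x(1.3)
# The group law survives the stacking: D(A) D(B) == D(A B).
print(jnp.allclose(big_rep(A) @ big_rep(B), big_rep(A @ B), atol=1e-6))
```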
If one has studied linear algebra in science or engineering, it is not surprising that the group of rotations can be represented by matrices, because one has already been exposed to the familiar $3 \times 3$ rotation matrices.
We know that we can represent an anti-clockwise rotation by an angle $\theta$ about the $z$-axis with the matrix

$$R_z(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta & 0 \\ \sin\theta & \cos\theta & 0 \\ 0 & 0 & 1 \end{pmatrix},$$

and likewise for the $x$- and $y$-axes.
One can verify that in fact all rotations about a given axis can be obtained by exponentiating a single generator matrix, e.g. $R_z(\theta) = \exp(\theta L_z)$ with $L_z = \frac{\mathrm{d}R_z}{\mathrm{d}\theta}\big|_{\theta = 0} = \begin{pmatrix} 0 & -1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 0 \end{pmatrix}$, and likewise for $L_x$ and $L_y$.
Notice also that these generator matrices are all antisymmetric, and that they satisfy the commutation relations $[L_x, L_y] \equiv L_x L_y - L_y L_x = L_z$ (plus cyclic permutations).
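Both claims are easy to check numerically; a minimal JAX sketch (with the generators written out explicitly as above):

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Generators of rotations about the x-, y- and z-axes (all antisymmetric).
L_x = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
L_y = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
L_z = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

def commutator(a, b):
    return a @ b - b @ a

# The commutation relations [L_x, L_y] = L_z and cyclic permutations.
print(jnp.allclose(commutator(L_x, L_y), L_z))
print(jnp.allclose(commutator(L_y, L_z), L_x))
print(jnp.allclose(commutator(L_z, L_x), L_y))

# Exponentiating a generator recovers the finite rotation matrix.
theta = 0.7
R_z = jnp.array([[jnp.cos(theta), -jnp.sin(theta), 0.0],
                 [jnp.sin(theta), jnp.cos(theta), 0.0],
                 [0.0, 0.0, 1.0]])
print(jnp.allclose(expm(theta * L_z), R_z, atol=1e-6))
```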
Now for the cool bit, where we make the jump from mundane $3 \times 3$ rotation matrices to representations of other sizes. Near the identity, the structure of the group is captured entirely by the commutation relations between its generators.
Hence, if I have some other set of three larger matrices that also satisfy such commutation relations, they must also be generators in some larger representation! (A proper derivation of why this is necessary might rely on some less obvious lemmas, but this motivates why such commutation relations would be sufficient to create a new representation.)
(One can also verify that the operator
Funnily enough, these commutation relations can be satisfied by matrices of many shapes, not just $3 \times 3$ ones: for every non-negative integer $\ell$ there is a set of three $(2\ell + 1) \times (2\ell + 1)$ matrices satisfying them, giving representations of dimension $1, 3, 5, 7, \dots$.
So far, we have only discussed rotations about the Cartesian axes, rather than a generic axis.
One way to rotate around a generic axis $\hat{n}$ by an angle $\theta$ is to exponentiate the corresponding combination of generators, $\exp\!\big(\theta\,(n_x L_x + n_y L_y + n_z L_z)\big)$; another is to compose rotations about the Cartesian axes, e.g. via three Euler angles.
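As a quick sanity check of the generator route (a sketch; `rotation_about` is a hypothetical helper, not something from the repository):

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Generators as before (antisymmetric 3x3 matrices).
L_x = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
L_y = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
L_z = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

def rotation_about(axis, theta):
    """Rotation by theta about the unit vector along `axis`, via the generators."""
    n = axis / jnp.linalg.norm(axis)
    return expm(theta * (n[0] * L_x + n[1] * L_y + n[2] * L_z))

axis = jnp.array([1.0, 2.0, 2.0])
n = axis / jnp.linalg.norm(axis)
R = rotation_about(axis, 0.8)

# The rotation leaves its own axis untouched, and is orthogonal with det +1.
print(jnp.allclose(R @ n, n, atol=1e-6))
print(jnp.allclose(R @ R.T, jnp.eye(3), atol=1e-6), jnp.allclose(jnp.linalg.det(R), 1.0))
```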
We are now ready to introduce spherical harmonics.
Thus far, the basis we were using for our matrices was the Cartesian one, and each of the three axes was on an equal footing.
It is more convenient, however, to choose a single rotation axis as a reference.
We choose this to be the $z$-axis.
How do we make the $z$-axis special? We move to a basis in which the generator of rotations about the $z$-axis takes a particularly simple form, so that basis vectors can be labelled by how they behave under rotations about $z$.
Note that
Fixing this basis lets us instantiate the Wigner D-matrix, which is the representation written out explicitly, entry by entry, as a function of the rotation (conventionally parameterised by Euler angles). The spherical harmonics $Y_{\ell m}$ are then the corresponding basis functions on the sphere, transforming among themselves under the degree-$\ell$ Wigner D-matrix, which fixes their functional form (to quite unwieldy expressions, at least according to my beauty standards).
(Choosing a different “special axis” from the start would just be an orthogonal transformation of the Wigner D-matrix, which would then modify the functional form of the spherical harmonics.)
❔Interlude: why is it $\boldsymbol{Y}(R \cdot \mathbf{r}) = \boldsymbol{R}^{-1} \boldsymbol{Y}(\mathbf{r})$ ?
At first glance, this looks different to the equivariance relation for functions like GNNs.
There, equivariant GNNs satisfy $f(R \cdot \mathbf{r}) = \boldsymbol{R}\, f(\mathbf{r})$, with no inverse in sight.
Why? A self-consistent definition of rotations acting on signals should compose properly: rotating a signal by $R_1$ and then by $R_2$ must be the same as rotating it by $R_2 R_1$ in one go. If we naively defined the rotated signal as $[R \cdot \boldsymbol{Y}](\mathbf{r}) = \boldsymbol{Y}(R\,\mathbf{r})$, then applying $R_1$ followed by $R_2$ would evaluate the original signal at $R_1 R_2\, \mathbf{r}$. Yet, on the RHS we should have $R_2 R_1\, \mathbf{r}$: the order has flipped. Putting the inverse inside the argument, $[R \cdot \boldsymbol{Y}](\mathbf{r}) = \boldsymbol{Y}(R^{-1}\mathbf{r})$, restores the correct order; rewriting $\boldsymbol{Y}(R \cdot \mathbf{r}) = \boldsymbol{R}^{-1} \boldsymbol{Y}(\mathbf{r})$ in terms of this proper action gives $[R \cdot \boldsymbol{Y}](\mathbf{r}) = \boldsymbol{R}\, \boldsymbol{Y}(\mathbf{r})$, which matches the GNN-style equivariance after all.
One can also intuit this graphically, which I plan to add as a figure.
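If you prefer code to symbol-pushing, here is a small numerical version of the same argument (a sketch; the test signal `f` is arbitrary):

```python
import jax.numpy as jnp

def rot_z(t):
    c, s = jnp.cos(t), jnp.sin(t)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def rot_x(t):
    c, s = jnp.cos(t), jnp.sin(t)
    return jnp.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

# An arbitrary (non-symmetric) scalar signal on R^3.
def f(r):
    return r[0] + 2.0 * r[1] ** 2 + jnp.sin(r[2])

def act_naive(R, g):    # [R . g](r) = g(R r)
    return lambda r: g(R @ r)

def act_inverse(R, g):  # [R . g](r) = g(R^{-1} r); a rotation's inverse is its transpose
    return lambda r: g(R.T @ r)

R1, R2, r = rot_z(0.5), rot_x(1.2), jnp.array([0.3, -0.7, 1.1])

# Applying R1 and then R2 should equal applying the single rotation R2 R1.
print(jnp.allclose(act_naive(R2, act_naive(R1, f))(r), act_naive(R2 @ R1, f)(r)))      # False
print(jnp.allclose(act_inverse(R2, act_inverse(R1, f))(r), act_inverse(R2 @ R1, f)(r)))  # True
```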
One final thing: Because rotations don’t affect the distance of a point in space to the origin, if we want the spherical harmonics to be easily normalisable, it’s better to treat them as functions of the spherical angles only, i.e. a function on the unit sphere
In the same way that Fourier series basis functions become increasingly fine with higher momentum (e.g. $\sin(kx)$ wiggles more rapidly as $k$ grows), spherical harmonics with larger $\ell$ capture increasingly fine angular detail on the sphere.
Here, we plot a surface whose spherical polar coordinates satisfy $r = |Y_{\ell m}(\theta, \phi)|$, and the colour of the surface gives the sign of $Y_{\ell m}(\theta, \phi)$ at that point.
(Note, the real “real spherical harmonics” are not simply the real components of the complex spherical harmonics, but I was originally lazy when making the Gradio demo below, and things broke when I updated them to the real real spherical harmonics. The shapes look pretty similar in any case.)
Recall that equivariant graph neural networks (GNNs) are GNNs that behave sensibly when their inputs are transformed according to a group operation e.g. rotations.
Consider two GNNs both taking as input a point cloud
The point is, if I input an upside down cat (implemented through a rotation operator acting on the input point cloud), the outputs of these GNNs should respond to that rotation in a sensible, predictable way, rather than changing arbitrarily.
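In code, the property we are after boils down to a simple check like the following sketch (here `model` and the output representation `D` are placeholders, not anything from the repository):

```python
import jax
import jax.numpy as jnp

def is_equivariant(model, D, R, points, atol=1e-5):
    """Check that rotating the inputs and rotating the outputs agree for one rotation R.

    model: maps an (N, 3) point cloud to a feature vector.
    D:     maps a 3x3 rotation matrix to the matrix acting on that feature vector.
    """
    rotated_points = points @ R.T     # rotate every point in the cloud
    lhs = model(rotated_points)       # feed the model the upside-down cat
    rhs = D(R) @ model(points)        # rotate the output for the upright cat
    return jnp.allclose(lhs, rhs, atol=atol)

# Toy example: the "model" that sums its points is equivariant, with D(R) = R.
cloud = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
R = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])  # 90 degrees about z
print(is_equivariant(lambda x: x.sum(axis=0), lambda rot: rot, R, cloud))  # True
```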
The first architecture to use spherical harmonics as the building block for constructing equivariant GNNs was Tensor Field Networks (TFN), which acts on point clouds (treated as fully connected graphs). This inspired many other works, arguably the second most famous example being the SE(3) Transformer, which acts on general graphs (not just fully connected ones) and simply adds an attention mechanism during the message-passing steps.
I previously introduced some easy ways to make reducible representations from irreducible ones: taking the Kronecker sum (i.e. putting things on the block diagonal) like $D^{(\ell_1)} \oplus D^{(\ell_2)}$.
We can also take the Kronecker product of two representations, like $D^{(\ell_1)} \otimes D^{(\ell_2)}$.
This resulting representation will either be reducible or irreducible. Without loss of generality we can say that

$$D^{(\ell_1)} \otimes D^{(\ell_2)} = C^{\top} \left( \bigoplus_{\ell = |\ell_1 - \ell_2|}^{\ell_1 + \ell_2} D^{(\ell)} \right) C,$$

where $C$ is an orthogonal change-of-basis matrix.
We can of course take the
Let’s take
I have claimed that this will be the Kronecker sum of some spherical harmonics multiplied by a matrix
Letting
That paragraph was rather unfortunate, but the mnemonic is easy: $\ell_1 \otimes \ell_2 = |\ell_1 - \ell_2| \oplus (|\ell_1 - \ell_2| + 1) \oplus \dots \oplus (\ell_1 + \ell_2)$.
The Clebsch-Gordan coefficients are nothing but the elements of the boring change-of-basis matrix $C$ above.
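One cheap way to convince yourself of the $1 \otimes 1 = 0 \oplus 1 \oplus 2$ case, without constructing $C$ or any Wigner D-matrices, is to compare traces, which a change of basis cannot alter. A sketch using the standard character formula $\chi_\ell(\theta) = \sum_{m=-\ell}^{\ell} \cos(m\theta)$ for a rotation by angle $\theta$:

```python
import jax.numpy as jnp

def rot_z(theta):
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def character(l, theta):
    """Trace of the degree-l irrep for a rotation by angle theta."""
    m = jnp.arange(-l, l + 1)
    return jnp.sum(jnp.cos(m * theta))

theta = 0.9
R = rot_z(theta)

# Trace of the Kronecker product D^(1) x D^(1) ...
lhs = jnp.trace(jnp.kron(R, R))
# ... equals the trace of the Kronecker sum D^(0) + D^(1) + D^(2).
rhs = character(0, theta) + character(1, theta) + character(2, theta)
print(jnp.allclose(lhs, rhs, atol=1e-5))  # True
```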
Good News: The basic ideas of TFN are easy to understand once we’re happy with representations, spherical harmonics and tensor products. Using real spherical harmonics also means that we mostly don’t need to use complex number floating point operations.
Bad News: Dealing with latent features that must all be treated differently in the neural network gets finicky, especially when dealing with multiple channels. This means that notation in TFN has quite a few indices floating about, and keeping track of weights can be slightly annoying.
Essentially, latent node features are collections of vectors that each transform according to an irrep of some degree $\ell$ (possibly with several channels per degree), and messages between nodes are built by combining neighbours’ features with spherical harmonics of the relative displacement vectors through tensor products, weighted by learned functions of the distance between the nodes.
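As a rough sketch of what that data layout can look like (illustrative shapes only; the repository’s actual conventions may differ):

```python
import jax.numpy as jnp

num_nodes, channels = 16, 8

# One array per angular momentum l, each with 2l + 1 components per channel.
features = {
    0: jnp.zeros((num_nodes, channels, 1)),   # scalars
    1: jnp.zeros((num_nodes, channels, 3)),   # vectors
    2: jnp.zeros((num_nodes, channels, 5)),   # l = 2 features
}

# Any operation (message passing, nonlinearities, ...) has to respect the fact
# that the last dimension transforms under a different irrep for each key.
for l, x in features.items():
    assert x.shape[-1] == 2 * l + 1
```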
If you’re interested in an unfinished, unpolished, undocumented barebones implementation of TFN in JAX from someone who’s never used JAX before, then boy do I have the repository for you…
Specifically, the Tetris example shows how to construct an equivariant Tetris shape classifier that only gets trained on one orientation of each shape.
Also of interest may be the TFNLayer
module (in layers.py
), and functions for calculating spherical harmonics and tensor products in spherical.py
and tensor\_product.py
, respectively.
One cute part of this repository is that reasonably-efficient JAX implementations of spherical harmonics are computed on the fly (without being hardcoded in, as in `e3nn`) through metaprogramming.
This happens by using the computer algebra of SymPy to generate simplified real spherical harmonics in terms of Cartesian coordinates, which can then be compiled into JAX functions.
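For a flavour of how that can work, here is a minimal sketch (not the repository’s actual code; it stays in spherical angles rather than rewriting to Cartesian coordinates, and uses the real part of the complex $Y_{\ell m}$ rather than proper real spherical harmonics):

```python
import sympy as sp
import jax

def make_real_sph_harm(l: int, m: int):
    """Return a JAX-callable (theta, phi) -> Re[Y_lm(theta, phi)]."""
    theta, phi = sp.symbols("theta phi", real=True)
    # expand_func turns Ynm into an explicit cos/sin expression for concrete l, m.
    expr = sp.expand_func(sp.Ynm(l, m, theta, phi))
    # The repository uses proper real spherical harmonics; taking the real part
    # here just keeps the sketch short.
    expr = sp.simplify(sp.re(expr))
    # modules="jax" asks SymPy to emit jax.numpy operations (needs a recent SymPy).
    return sp.lambdify((theta, phi), expr, modules="jax")

y20 = jax.jit(make_real_sph_harm(2, 0))
print(y20(0.3, 1.2))  # also works on arrays of angles
```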
(To me this is quite a bit simpler than the
What’s not so cute is how spherical harmonics are recomputed many times by individual neural network layers, even though they could be reused (and a similar story holds for Clebsch-Gordan coefficients). At least this recalculation makes it easier to read for pedagogical purposes, but I may update this in the future to make it more efficient.
In some ways TFN is beautiful. In other ways, it is quite ugly.
When implementing TFN, one of the ugliest things is the fact that each feature with different angular momentum has a different number of components, which means one has to be careful with how they mix together. (Indeed, in my implementation I have kept different feature vectors as separate elements in a dictionary to avoid the headache.) Concatenating everything in one big tensor that can be efficiently operated on requires having a very intricate indexing scheme (which I gave up on for this talk).
Another non-beautiful thing is that when converting displacement vectors to spherical harmonics, one always has
As for something that is more than just aesthetics, swapping from Cartesian components to spherical harmonics and performing tensor products for the large-$\ell$ representations adds up to a lot of floating point operations.
Having to store all the
This is why Passaro and Zitnick’s improvement is very cool!
They simplify every aspect of TFN by noticing it is better to have the arbitrary “special” axis not be arbitrary, and rather have it match the axis along which messages are being sent (i.e. the displacement vector between neighbouring nodes).
This makes everything much more sparse and efficient.
This has already been implemented in some modern architectures like EquiformerV2, and will probably soon replace TFN and SE(3)-Transformer-style layers as the default building blocks for equivariant GNNs.