Understanding how AI learns is not an easy task. When faced with difficult problems, we often reach for complex methods, something I actively try to avoid in my scientific work (for example, I almost never use PCA). Often, though, the simplest solutions work best. This post describes an interpretability technique that is simple, effective, and, I strongly suspect, applicable to architectures beyond the one it was developed for. The post thus serves both as a description of the method and as a call for researchers to apply the technique to their own models.
The primary objective of most current AI interpretability research is to ascertain a causal relationship between individual units in deep neural networks and distinct, meaningful features in the output. This blog post discusses one method of uncovering such relationships that I call disentanglement with extreme values.
My first venture into the world of interpretability came when I wanted to understand how unsupervised generative AI can learn the sounds of human language. Humans acquire language early in life, absorbing the sounds of speech that surround us. Speech sounds are continuous, measurable physical signals, yet we convert them into mental representational units. In essence, we transform the continuous physical world into discrete representational units within our brains. Consider the word ‘spit,’ which comprises a sequence of four sounds — [s], [p], [ɪ], and [t] — each contributing to its distinct meaning. I wanted to know whether AI can represent continuous space with discrete units in a way similar to how humans do. After all, we used to refer to this kind of AI as representation learning.
So I trained a Generative Adversarial Network (GAN) on recordings of speech sounds. GANs are a very interesting architecture, but what matters here is that the Generator network takes a small vector z of uniformly distributed variables and transforms it into data (pictures or sounds).
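To make the setup concrete, here is a minimal sketch of such a generator, assuming PyTorch and a toy WaveGAN-style architecture; the layer sizes and names are my illustrative assumptions, not the actual model from the paper.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100  # the model described in the post has 100 latent variables

class ToyGenerator(nn.Module):
    """Heavily simplified WaveGAN-style generator: latent vector z -> waveform."""
    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 256 * 16)
        self.deconv = nn.Sequential(
            nn.ConvTranspose1d(256, 128, kernel_size=25, stride=4, padding=11, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=25, stride=4, padding=11, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 1, kernel_size=25, stride=4, padding=11, output_padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        x = self.fc(z).view(z.size(0), 256, 16)
        return self.deconv(x)  # shape: (batch, 1, 1024) -- a short toy waveform

# During training, z is drawn uniformly from (-1, 1):
generator = ToyGenerator()
z = torch.empty(8, LATENT_DIM).uniform_(-1.0, 1.0)
audio = generator(z)
```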
In the resulting paper, I found that a few of the 100 latent variables z correlate with the sound [s] in the generated outputs. For instance, a LASSO logistic regression analysis showed that negative values of the eleventh latent variable (z11) are associated with the presence of [s], while positive values are associated with its absence.
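A hedged sketch of that kind of analysis, assuming scikit-learn and that each generated output has already been annotated for the presence of [s] (by a human or a phone classifier); the random placeholder labels below stand in for real annotations.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
Z = rng.uniform(-1, 1, size=(5000, 100))   # latent vectors used for generation
has_s = rng.integers(0, 2, size=5000)      # placeholder: real [s] annotations go here

# L1-penalized (LASSO) logistic regression: presence of [s] ~ z
lasso = LogisticRegression(penalty="l1", solver="liblinear", C=0.1)
lasso.fit(Z, has_s)

# Latent variables with large surviving coefficients are candidate "[s]-units";
# a negative coefficient means more-negative values predict [s].
coefs = lasso.coef_.ravel()
for i in np.argsort(-np.abs(coefs))[:5]:
    print(f"z{i + 1}: coefficient = {coefs[i]:+.3f}")  # z11 corresponds to index 10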
This was nice, but correlations are only correlations, and I wanted to dig deeper to find a causal relationship. To do so, I intervened on the eleventh latent variable directly, setting z11 to -1 during inference, and generated a large dataset to observe the effects. The manipulation produced a modest increase in the proportion of outputs containing the [s] sound. But this still did not fully satisfy me.
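In code, the intervention might look like the following sketch, reusing the toy generator from above; contains_s is a hypothetical detector standing in for human annotation or a phone classifier.

```python
import torch

def contains_s(audio):
    # Placeholder for a real [s]-detector (human annotation or a phone
    # classifier); the random answer below only keeps the sketch runnable.
    return bool(torch.rand(1) < 0.5)

def generate_with_fixed_unit(generator, n, unit_index, value, latent_dim=100):
    """Sample z uniformly, then clamp one unit to a fixed value (the intervention)."""
    z = torch.empty(n, latent_dim).uniform_(-1.0, 1.0)
    z[:, unit_index] = value
    with torch.no_grad():
        return generator(z)

def s_rate(audio_batch):
    return sum(contains_s(a) for a in audio_batch) / len(audio_batch)

# Baseline: all latent variables sampled freely.
with torch.no_grad():
    baseline = s_rate(generator(torch.empty(1000, 100).uniform_(-1.0, 1.0)))

# Intervention: z11 (index 10, since the post counts from 1) fixed at -1.
intervened = s_rate(generate_with_fixed_unit(generator, 1000, unit_index=10, value=-1.0))
print(f"P([s]) baseline: {baseline:.3f}; with z11 = -1: {intervened:.3f}")
```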
Upon further analysis, I discovered that the relationship between z11 and the likelihood of [s] appearing in the output was linear, even though the data had initially been analyzed with a non-linear regression model.
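One way to check that linearity, reusing the helpers sketched above: sweep z11 across its training range, estimate the rate of [s] at each value, and fit a line.

```python
import numpy as np

values = np.linspace(-1.0, 1.0, 21)
rates = [s_rate(generate_with_fixed_unit(generator, 500, unit_index=10, value=float(v)))
         for v in values]

# Fit a line and inspect the residuals; small residuals suggest the
# relationship between z11 and P([s]) is indeed approximately linear.
slope, intercept = np.polyfit(values, rates, deg=1)
residuals = np.asarray(rates) - (slope * values + intercept)
print(f"P([s]) ~ {slope:+.3f} * z11 + {intercept:.3f}; "
      f"max |residual| = {np.abs(residuals).max():.3f}")
```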
Given the linear relationship, it became possible to hypothesize about outcomes beyond the training range, such as values more negative than -1. While the Generator was trained with z constrained to the interval (-1, 1), nothing enforces that limit during generation; the latent variables can take any value. This is where the fun begins. I went all in and started generating outputs with z11 set to -25. The results were striking: 96 of 100 generated outputs contained an [s].
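With the helpers above, the extrapolation step itself is just a few lines:

```python
# Push z11 far outside the training interval (-1, 1) and count [s] outputs.
extreme = generate_with_fixed_unit(generator, 100, unit_index=10, value=-25.0)
n_s = sum(contains_s(a) for a in extreme)
print(f"z11 = -25: {n_s}/100 outputs contain [s]")
```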
Latent space is a complex space with many interactions between individual variables that we do not understand. But setting a single unit to an extreme value overrides those low-level interactions and reveals the underlying feature that the unit encodes.
Even better, by interpolating from these extreme values back to 0, one can observe a causal relationship between the specific value of a single latent variable and the loudness of [s] in the output, up to the point where [s] disappears from the output altogether.
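A sketch of the interpolation, holding one latent sample fixed and sliding z11 from the extreme value back to 0; s_amplitude is a hypothetical measurement of [s] loudness (in practice, something like RMS energy in the high-frequency frication band of the [s] segment).

```python
import torch

def s_amplitude(audio):
    # Placeholder for a real loudness measure of the [s] portion; the
    # whole-waveform RMS below only keeps the sketch runnable.
    return audio.pow(2).mean().sqrt().item()

z_base = torch.empty(1, 100).uniform_(-1.0, 1.0)  # one fixed latent sample
for v in torch.linspace(-25.0, 0.0, steps=11):
    z = z_base.clone()
    z[0, 10] = v  # z11, zero-indexed as 10
    with torch.no_grad():
        audio = generator(z)
    print(f"z11 = {v.item():6.1f}  [s] amplitude = {s_amplitude(audio):.4f}")
```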
Together with Andrej Leban and Shane Gero, I applied the Causal Disentanglement with Extreme Values (CDEV) technique to try to understand what is meaningful in the sperm whale communication system. Drawing on principles from causal inference, we demonstrated that causal relationships exist between individual units in GANs and various meaningful properties of the sperm whale communication system.
In sum, the disentanglement with extreme values technique is a fairly simple method for causally uncovering what individual units in deep neural networks encode. I suspect it would transfer well to other architectures. So if you are working with deep learning and have been trying to find a causal relationship between individual units and something meaningful in your data, it might be worth applying this technique and seeing what happens.