How Do Vision Transformers Work? by Namuk Park & Songkuk Kim
Transformers in the sense now common (QKV-style multi-head self-attention, a.k.a. MSA) always struck me as too clever, at least for images. The computational complexity, the translation variance, and the hype were all too much to rest on mere great results. So I was pleased by ConvNeXt, which showed that, in one example anyway, a refined but purely convolutional neural network (convnet) could match the results of a vision transformer (vit). Apparently, when we saw contemporary vits benchmark better than old convnets, some, most, or even all of that improvement really came from a collection of individually minor developments in normalizations, activations, training techniques, and so on: a dozen things other than the transformers themselves.
Still, though. Transformers work well in language-land, or so I’m told, and even if they’re only as good as convolutions in vision-land, let’s ask why. In particular, the idea of transformers as a way to learnably duct information across distance is inspiring and worth exploring. Well, tough. It’s easy to find transformers explained but very hard to find an intuition for transformers developed. I felt like there was a gap in the literature until I found this paper:
We show that MSAs and Convs exhibit opposite behaviors. MSAs aggregate feature maps, but Convs diversify them. Moreover, […] the Fourier analysis of feature maps shows that MSAs reduce high-frequency signals, while Convs, conversely, amplif[y] high-frequency components. In other words, MSAs are low-pass filters, but Convs are high-pass filters. In addition, […] Convs are vulnerable to high-frequency noise but that MSAs are not. Therefore, MSAs and Convs are complementary.
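A toy version of the low-pass/high-pass distinction, assuming nothing about trained networks: the sketch below measures what share of a 1D signal's spectral power sits above half the Nyquist frequency (angular frequency above π/2), then applies two fixed filters. The circular [1, 2, 1]/4 average stands in for aggregation (its frequency response, cos²(ω/2), falls with frequency) and the circular [−1, 2, −1] difference stands in for a high-pass conv (its response, 2 − 2cos ω, rises). The function names and the cutoff are my own choices, not the paper's measurement, which analyzes feature maps of trained models.

```python
import numpy as np

def high_freq_share(x: np.ndarray) -> float:
    """Share of rfft power in bins above half Nyquist (angular frequency > pi/2)."""
    power = np.abs(np.fft.rfft(x)) ** 2
    cutoff = len(x) // 4  # bin k has angular frequency 2*pi*k/n, so k > n/4 means > pi/2
    return float(power[cutoff + 1:].sum() / power.sum())

def blur(x: np.ndarray) -> np.ndarray:
    """Circular [1, 2, 1]/4 average; response cos^2(omega/2) decreases: low-pass."""
    return 0.25 * np.roll(x, 1) + 0.5 * x + 0.25 * np.roll(x, -1)

def laplacian(x: np.ndarray) -> np.ndarray:
    """Circular [-1, 2, -1] difference; response 2 - 2cos(omega) increases: high-pass."""
    return 2.0 * x - np.roll(x, 1) - np.roll(x, -1)

rng = np.random.default_rng(0)
x = rng.standard_normal(256)  # white-noise test signal

for name, y in [("input", x), ("blur", blur(x)), ("laplacian", laplacian(x))]:
    print(f"{name:10s} high-frequency share: {high_freq_share(y):.3f}")
```

Because one gain only falls and the other only rises with frequency, the blur is guaranteed to lower the high-frequency share and the difference filter to raise it, for any signal with power on both sides of the cutoff. The two toy filters are also exactly complementary: blur(x) equals x − laplacian(x)/4.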
This knocked me off balance when I first read it, but on reflection it was just below the surface of things I already knew. For example, I’d been thinking of transformers as something like k-means–style clustering, and in an image context that’s approximately non-local means, a classical denoising algorithm that is perforce a low-pass filter. A small step in hindsight; conceptually, a giant leap.
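The non-local means connection can be made concrete. In the sketch below, a global 1D variant written for illustration (classical non-local means usually restricts the search to a window and works on 2D images), each sample attends to every other sample: the weights are a softmax over patch similarity and the values are the samples themselves. Every row of weights is nonnegative and sums to one, so each output is a convex combination of inputs, which is the aggregating behavior the paper attributes to MSAs. The function name, patch radius, bandwidth h, and step-edge test signal are arbitrary assumptions.

```python
import numpy as np

def non_local_means_1d(x: np.ndarray, radius: int = 2, h: float = 0.7) -> np.ndarray:
    """Global non-local means on a 1D signal, written as one attention step.

    Queries and keys are the (2*radius+1)-sample patches around each position,
    values are the samples, and weights are softmax(-||patch_i - patch_j||^2 / h^2).
    """
    padded = np.pad(x, radius, mode="edge")
    # patches[i] is the window of 2*radius+1 samples centred on x[i]
    patches = np.lib.stride_tricks.sliding_window_view(padded, 2 * radius + 1)
    sq_dist = ((patches[:, None, :] - patches[None, :, :]) ** 2).sum(axis=-1)
    logits = -sq_dist / h**2
    logits -= logits.max(axis=1, keepdims=True)  # softmax is shift-invariant; avoids overflow
    weights = np.exp(logits)
    weights /= weights.sum(axis=1, keepdims=True)  # each row is now a convex combination
    return weights @ x

rng = np.random.default_rng(0)
clean = np.r_[np.zeros(100), np.ones(100)]  # a single step edge
noisy = clean + 0.1 * rng.standard_normal(clean.size)
denoised = non_local_means_1d(noisy)
print("MSE noisy:   ", np.mean((noisy - clean) ** 2))
print("MSE denoised:", np.mean((denoised - clean) ** 2))
```

The link to QKV attention is closer than it looks: since −‖a − b‖² = 2a·b − ‖a‖² − ‖b‖², and the ‖a‖² term is the same for every key and cancels in the softmax, these weights are dot-product attention with queries and keys equal to the patches (scaled by 2/h²), plus a per-key bias of −‖b‖²/h². Patches on opposite sides of the edge look dissimilar and get almost no weight, which is what makes the blur edge-preserving.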
The paper leaves me with, in one hand, a fresh sense that even as a fan of convolutions I should think seriously about transformers; in the other, a curiosity about ways to get transformer advantages out of ingredients I like better. For example, how could we compare transformers with and against other low-pass filter setups, like U-nets’ resampling or scale space? What can we learn from decades of thinking about clustering and edge-preserving blurs? Could a transformer-inspired [1×1, scale decomposition, 1×1] block (for example) do something strictly less powerful than a QKV block yet still worth it in quality per flop?
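Here is one hypothetical reading of that [1×1, scale decomposition, 1×1] block, only to make the question concrete. The assumptions are mine: the decomposition is a difference-of-blurs stack (a Laplacian-pyramid-style split without downsampling, built from a repeated edge-padded [1, 2, 1]/4 blur), the 1×1 convolutions are per-pixel matrix multiplies over channels, and there is no nonlinearity. Every function and parameter name here is made up for illustration.

```python
import numpy as np

def binomial_blur(x: np.ndarray) -> np.ndarray:
    """Separable [1, 2, 1]/4 blur over the two spatial axes of an (H, W, C) array."""
    for axis in (0, 1):
        n = x.shape[axis]
        pad = [(0, 0)] * x.ndim
        pad[axis] = (1, 1)
        p = np.pad(x, pad, mode="edge")  # edge padding keeps constant images constant
        x = (0.25 * np.take(p, np.arange(0, n), axis=axis)
             + 0.5 * np.take(p, np.arange(1, n + 1), axis=axis)
             + 0.25 * np.take(p, np.arange(2, n + 2), axis=axis))
    return x

def scale_bands(x: np.ndarray, levels: int) -> list:
    """Hypothetical difference-of-blurs decomposition; the bands sum back to x."""
    bands, current = [], x
    for _ in range(levels):
        smoother = binomial_blur(current)
        bands.append(current - smoother)  # detail removed by this blur
        current = smoother
    bands.append(current)  # coarsest low-pass residual
    return bands

def hypothetical_scale_block(x, w_in, w_out, levels=3):
    """Hypothetical [1x1, scale decomposition, 1x1] block.

    x: (H, W, C_in); w_in: (C_in, C_mid); w_out: ((levels + 1) * C_mid, C_out).
    """
    mid = x @ w_in                                      # first 1x1 conv: per-pixel channel matmul
    stacked = np.concatenate(scale_bands(mid, levels), axis=-1)
    return stacked @ w_out                              # second 1x1 conv mixes channels and scales

rng = np.random.default_rng(0)
x = rng.standard_normal((32, 32, 8))
w_in = rng.standard_normal((8, 4)) / np.sqrt(8)
w_out = rng.standard_normal((4 * 4, 8)) / np.sqrt(16)  # 4 bands x 4 channels in, 8 out
y = hypothetical_scale_block(x, w_in, w_out, levels=3)
print(y.shape)  # (32, 32, 8)
```

Unlike a QKV block, the spatial mixing here uses fixed weights rather than weights computed from content, and its cost grows linearly with the number of pixels, where global self-attention grows quadratically. Whether that trade is worth it in quality per flop is the open question.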
I don’t know, but thinking about it has helped me think about other things. And so, although I am a mild skeptic of transformers, this paper on transformers is one of the most productive things I’ve read about ML.