2405.05219

Conv-Basis: A New Paradigm for Efficient Attention Inference and Gradient Computation in Transformers

Authors: Jiuxiang Gu, Yingyu Liang, Heshan Liu, Zhenmei Shi, Zhao Song, Junze Yin

Large Language Models (LLMs) have profoundly changed the world. Their self-attention mechanism is the key to the success of transformers in LLMs. However, the quadratic computational cost $O(n^2)$ in the input sequence length $n$ is the notorious obstacle to further improvement and scalability in longer contexts. In this work, we leverage the convolution-like structure of attention matrices to develop an efficient approximation method for attention computation using convolution matrices. We propose a conv basis system, "similar" to the rank basis, and show that any lower triangular (attention) matrix can always be decomposed as a sum of $k$ structured convolution matrices in this basis system. We then design an algorithm to quickly decompose the attention matrix into $k$ convolution matrices. Thanks to Fast Fourier Transforms (FFT), the attention inference can be computed in $O(knd \log n)$ time, where $d$ is the hidden dimension. In practice, we have $d \ll n$, e.g., $d = 3{,}072$ and $n = 1{,}000{,}000$ for Gemma. Thus, when $kd = n^{o(1)}$, our algorithm achieves almost linear time, i.e., $n^{1+o(1)}$. Furthermore, the attention training forward and backward gradient can be computed in $n^{1+o(1)}$ time as well. Our approach avoids explicitly computing the $n \times n$ attention matrix, which may largely alleviate the quadratic computational complexity. Furthermore, our algorithm works on any input matrices. This work provides a new paradigm for accelerating attention computation in transformers to enable their application to longer contexts.
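The $O(knd \log n)$ bound rests on one primitive: multiplying a lower-triangular convolution (Toeplitz) matrix by an $n \times d$ matrix with FFT instead of forming the $n \times n$ matrix. The following is a minimal NumPy sketch of that primitive only. It does not reproduce the paper's decomposition algorithm or the softmax normalization. The function names `conv_matrix`, `conv_apply_fft`, and `conv_sum_apply` are hypothetical, and the choice to place each length-`m` convolution block in the bottom-right corner of the matrix is an assumption made for illustration, not something stated in the abstract.

```python
import numpy as np

def conv_matrix(a):
    """Dense n x n lower-triangular Toeplitz matrix: entry (i, j) = a[i - j] if i >= j, else 0.
    Used here only as an O(n^2) reference for checking the FFT path."""
    n = len(a)
    idx = np.arange(n)[:, None] - np.arange(n)[None, :]
    return np.where(idx >= 0, a[np.clip(idx, 0, n - 1)], 0.0)

def conv_apply_fft(a, V):
    """Compute conv_matrix(a) @ V in O(n d log n) without building the matrix."""
    n = V.shape[0]
    size = 2 * n  # zero-pad so the circular convolution equals the linear one
    Fa = np.fft.rfft(a, n=size)
    FV = np.fft.rfft(V, n=size, axis=0)
    out = np.fft.irfft(Fa[:, None] * FV, n=size, axis=0)
    return out[:n]  # the first n rows are the lower-triangular product

def conv_sum_apply(basis, V):
    """Apply a sum of k convolution blocks to V in O(k n d log n).
    ASSUMPTION (illustrative): each pair (a, m) acts as conv_matrix(a[:m])
    placed in the bottom-right m x m block of an otherwise zero n x n matrix."""
    n = V.shape[0]
    out = np.zeros(V.shape, dtype=float)
    for a, m in basis:
        out[n - m:] += conv_apply_fft(a[:m], V[n - m:])
    return out
```

With `k` pairs in `basis`, the loop performs `k` FFT-based products, each costing at most $O(nd \log n)$, which matches the shape of the bound quoted above under the stated assumption about block placement.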

Subjects: Machine Learning , Artificial Intelligence

Publish: 2024-05-08 17:11:38 UTC