Mapping of attention mechanisms to a generalized Potts model

Transformers are neural networks that revolutionized natural language processing and machine learning. They process sequences of inputs, like words, using a mechanism called self-attention, which is trained via masked language modeling (MLM). In MLM, a word is randomly masked in an input sequence, and the network is trained to predict the missing word. Despite the practical success of transformers, it remains unclear what type of data distribution self-attention can learn efficiently. Here, we show analytically that if one decouples the treatment of word positions and embeddings, a single layer of self-attention learns the conditionals of a generalized Potts model with interactions between sites and Potts colors. Moreover, we show that training this neural network is exactly equivalent to solving the inverse Potts problem by the so-called pseudo-likelihood method, well known in statistical physics. Using this mapping, we compute the generalization error of self-attention in a model scenario analytically using the replica method.


Introduction.
Transformers [1] are a powerful type of neural network that have achieved state-of-the-art results in natural language processing (NLP) [2][3][4][5][6], image classification [7], and even protein structure prediction [8]. While standard neural networks can be thought of as functions of a single input, transformers act on sets of "tokens", like words in a sentence. The key to the success of transformers is a technique called masked language modeling (MLM), where transformers are trained to predict missing words in a sentence [2][3][4][5][6], cf. fig. 1a. This technique has the advantage that it can leverage large amounts of raw text (or images, or protein sequences) without any annotation. By learning the conditional distribution of having a word in a specific position of the sentence, given the other words, transformers ostensibly learn the relationships between words in a robust way.
The basic building block of transformers is the self-attention (SA) mechanism [9, 10], which transforms a sequence of tokens x_j into another sequence h_j. We illustrate self-attention on a masked language modeling task in fig. 1. The sentence is first transformed into a set of representations x_j = e_j + p_j, where e_j is a vector representing the jth word and the vector p_j encodes its position. SA then computes a linear transformation of the representations to yield the values v_j. The kth output vector h_k is then a linear combination of the values v_j weighted by an attention matrix A, whose elements A_kj quantify the relative importance of the jth input token for the kth output vector, for example based on their semantic similarity. The functions that compute the values v_j and the attention matrix A both have trainable parameters; see eq. (2) for a precise definition. The flexibility of self-attention comes from the attention weights A_kj, which are not fixed, but computed given the context, i.e. the surrounding tokens.
The practical success of transformers raises several fundamental questions: what are the statistical structures that self-attention learns with MLM? More precisely, since the MLM objective is to learn the conditional probability distribution of words given a set of surrounding words, which family of conditional probabilities can self-attention learn? And how many samples are required to achieve good performance? Here, we take a step towards answering these questions by exploiting tools from the statistical physics of learning [11][12][13][14].
FIG. 1. Masked language modeling (MLM) with a single layer of self-attention. The goal of MLM is to predict the masked word in a given sentence. Self-attention first maps words into representations e_j + p_j, where e_j are embedding vectors representing words, and p_j encode their positions. For a given masked word, the associated attention vector h_k is computed as a linear combination of the values v_j = V(e_j + p_j) of all other tokens, weighted by the attention weights A_kj. In vanilla self-attention, values and attention weights depend on embeddings and positional vectors, while in factored attention, attention weights depend only on positions, and values only on the embeddings. By identifying the attention weights A with the interaction matrix J of a Potts model, eq. (1), the value matrix V with the color similarity matrix U, and the embedding vectors with the one-hot spins, we get a learning model identical to a Potts model.

arXiv:2304.07235v4 [cond-mat.dis-nn] 4 Apr 2024

The first challenge is to design a data model that mimics the structure of real sentences. While classical works modelled inputs as vectors of i.i.d. random variables, recent work has introduced more sophisticated data models for neural networks [15][16][17][18][19][20][21], which allowed the study of unsupervised learning [22, 23]. To analyze the self-supervised learning of MLM, we model sequences of words as a system of spins, interacting via a generalized Potts Hamiltonian [24, 25] with couplings between colors (= words) and positions. We sample a synthetic data set from the Potts model using Monte Carlo, and we perform masked language modeling by training a transformer to predict masked spins in spin sequences. While an off-the-shelf transformer requires several layers of self-attention to learn this simple probability distribution, we show analytically that a single layer of factored self-attention, where we separate the treatment of positions and inputs, can reconstruct the couplings of the Potts model exactly
in the limit of a large training set. In particular, we derive an exact mapping between the output of the self-attention mechanism and the conditional distribution of a Potts spin given the others. We finally use this mapping to compute the generalization loss of a single layer of self-attention analytically using the replica method.
A generalized Potts model to sample sequences. We model sentences as sequences of spins s = (s_1, . . . , s_L), with s_i ∈ R^C taking values from a vocabulary of C colors, which we encode as one-hot vectors. Each color can be thought of as a word in natural text, an amino acid in a protein, etc. In a standard Potts Hamiltonian, only spins of the same color interact with each other via an interaction matrix J. This is an unrealistic model for real data: it treats all colors as orthogonal, even though words and amino acids have varying degrees of similarity. We therefore generalize the Potts Hamiltonian to

H(s) = − Σ_{i<j} J_ij s_i^T U s_j,    (1)

where J ∈ R^{L×L} governs the interactions between spins at different positions, and U ∈ R^{C×C} encodes the similarities between colors (we denote matrices by capital letters and vectors in boldface). Without loss of generality, we set J_ii = 0 and sample sequences from the Boltzmann distribution P(s) ∝ exp[−βH(s)]. We recover the standard Potts model by choosing U as the identity matrix.
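To make eq. (1) concrete, here is a minimal numpy sketch that evaluates the generalized Potts energy of a one-hot sequence. The small instance (L = 4, C = 3) and the random couplings are purely illustrative assumptions, not the parameters used in the paper.

```python
import numpy as np

def potts_energy(s_idx, J, U):
    """Energy of eq. (1) for a sequence of color indices s_idx (length L).

    H(s) = - sum_{i<j} J_ij s_i^T U s_j; with one-hot spins,
    s_i^T U s_j = U[s_idx[i], s_idx[j]].
    """
    L = len(s_idx)
    H = 0.0
    for i in range(L):
        for j in range(i + 1, L):
            H -= J[i, j] * U[s_idx[i], s_idx[j]]
    return H

# Hypothetical small instance: L = 4 sites, C = 3 colors.
rng = np.random.default_rng(0)
L, C = 4, 3
J = rng.standard_normal((L, L))
J = (J + J.T) / 2
np.fill_diagonal(J, 0.0)
U = np.eye(C)  # U = identity recovers the standard Potts model
s = np.array([0, 1, 1, 2])
energy = potts_energy(s, J, U)
```

With U = I, the double sum reduces to −Σ_{i<j} J_ij δ(s_i, s_j), the standard Potts energy: in the toy sequence above only sites 1 and 2 share a color, so the energy is just −J[1, 2].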
Masked language modeling with transformers. Given the generative model (1), the MLM objective amounts to predicting the ith spin given the sequence s_\i where that spin is "masked", i.e. s_i = t, the masking token. To apply self-attention to a sequence s_\i, we first compute the values v_j = V(E s_j + a p_j), where the embedding matrix E ∈ R^{d×C} maps Potts colors into d-dimensional representation vectors, and V ∈ R^{d×d} is a weight matrix; both E and V are trainable parameters. The scalar parameter a controls the relative importance of the embedding and positional encoding vectors. The output vector h_i corresponding to the masked token is a linear combination of the values, weighted by an exponential attention function [1]:

h_i = Σ_{j≠i} A_ij v_j,   A_ij = exp(x_i^T Q^T K x_j) / Σ_{k≠i} exp(x_i^T Q^T K x_k),    (2)

with x_j = E s_j + a p_j. Crucially, the ith spin s_i in this expression has to be replaced by the masking token t, since it is the masked input. The matrices Q, K ∈ R^{d×d} are also trainable parameters of the model. In the following we take the embedding dimension equal to the number of colors, d = C, in order to be able to map the output vector h_i into a probability distribution p_i over the colors through the softmax non-linearity [26].
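The forward pass just described can be sketched as follows. All dimensions, weight matrices, and the toy sequence are hypothetical, and we use the convention that the masked position attends to every other token but not to itself.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def self_attention_masked(s_onehot, i, E, V, Q, K, p, a=1.0):
    """One self-attention output for masked position i, in the spirit of eq. (2).

    s_onehot: (L, C) one-hot spins; position i is replaced by the masking token t = 0.
    E: (d, C) embeddings, p: (L, d) positional encodings, V, Q, K: (d, d).
    """
    L, C = s_onehot.shape
    x = s_onehot @ E.T + a * p           # representations x_j = E s_j + a p_j
    x[i] = a * p[i]                      # masked token t = 0: only the position survives
    v = x @ V.T                          # values v_j = V x_j
    scores = np.array([x[i] @ Q.T @ K @ x[j] for j in range(L)])
    scores[i] = -np.inf                  # the masked token does not attend to itself
    A_i = softmax(scores)                # attention weights A_ij, normalized over j != i
    h_i = A_i @ v                        # h_i = sum_j A_ij v_j
    return softmax(h_i)                  # probability over the C colors (since d = C)

# Hypothetical toy sizes.
rng = np.random.default_rng(1)
L, C = 5, 4
d = C
Emb = rng.standard_normal((d, C))
V = rng.standard_normal((d, d))
Q = rng.standard_normal((d, d))
K = rng.standard_normal((d, d))
p = rng.standard_normal((L, d))
s = np.eye(C)[rng.integers(0, C, size=L)]
probs = self_attention_masked(s, 2, Emb, V, Q, K, p)
```

Because d = C, the output h_i passes through a softmax and can be read directly as a distribution over the vocabulary, which is what the cross-entropy MLM loss compares to the missing spin.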
Training a vanilla transformer on the generalized Potts model. For our first experiment, we emulate the setting of protein structure prediction, so we choose a vocabulary of size C = 20 and sample a symmetric interaction matrix with entries J_ij ∈ {0, 1}, which we show in fig. 2b. We draw the entries of the symmetric color-similarity matrix U i.i.d. from the standard Gaussian distribution. Given these parameters, we use Gibbs sampling to generate a data set with M = 3000 sequences of length L = 20. We tune the inverse temperature β to ensure an average Hamming distance of 0.3 between sampled sequences, typical for protein families [27].
We then train off-the-shelf transformers consisting of one and three layers on this data set by minimising the cross-entropy loss between the output distribution and the missing spin using stochastic gradient descent (see supplementary material for the numerical details) on the loss L(s) = −(1/L) Σ_{i=1}^L Σ_{α=1}^C s_{iα} log p_{iα}(s), for a sequence s. In fig. 2, we show the test loss

ϵ_g = E_{s∼P} L(s)    (3)

during training, where E_{s∼P} denotes an average over the generative model (1). A transformer with a single layer of self-attention does not attain the optimal generalization error (black dashed line). By plotting the attention matrix of the single layer, we see that the transformer recovers the original interaction matrix to some degree, albeit not perfectly. Training transformers with three layers on the same data set improves the accuracy at the cost of losing interpretability: there is no straightforward way to collapse several layers of non-linear transformations of the input sequence into a single attention map; we show the average of the final two attention layers in fig. 2b.

Factored self-attention learns the generalized Potts model. We now consider a variant of self-attention in which the treatment of positions and values is decoupled. We set a = 0 in eq. (2), set the masking token t = 0, choose one-hot encodings for the positions, and fix the embedding matrix at E = I_C, so that

h_i = Σ_{j≠i} A_ij V s_j,   where A_ij = exp((Q^T K)_ij) / Σ_{k≠i} exp((Q^T K)_ik).    (4)

This modified self-attention has exactly the same form as the conditional distribution of the generalized Potts model if one sets U = V and βJ = A, which is

p(s_i | s_\i) = softmax( Σ_{j≠i} βJ_ij U s_j ).    (5)

This equivalence between factored self-attention and the Potts model is our first main result; we now discuss its ramifications.
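The equivalence just stated can be checked numerically. The sketch below, with illustrative random J and symmetric U, sets the attention weights directly to A = βJ and the value matrix to V = U, and verifies that the factored-attention output reproduces the exact Potts conditional.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def potts_conditional(s_onehot, i, J, U, beta):
    """Exact conditional P(s_i = c | s_\\i) of the generalized Potts model."""
    field = beta * sum(J[i, j] * (U @ s_onehot[j])
                       for j in range(len(s_onehot)) if j != i)
    return softmax(field)

def factored_attention(s_onehot, i, A, V):
    """softmax(h_i) with h_i = sum_{j != i} A_ij V s_j (factored attention, a = 0)."""
    h = sum(A[i, j] * (V @ s_onehot[j])
            for j in range(len(s_onehot)) if j != i)
    return softmax(h)

# Illustrative instance.
rng = np.random.default_rng(2)
L, C, beta = 6, 5, 0.7
J = rng.standard_normal((L, L))
J = (J + J.T) / 2
np.fill_diagonal(J, 0.0)
U = rng.standard_normal((C, C))
U = (U + U.T) / 2
s = np.eye(C)[rng.integers(0, C, size=L)]

# Identification behind the mapping: A = beta * J, V = U.
p_potts = potts_conditional(s, 3, J, U, beta)
p_attn = factored_attention(s, 3, beta * J, U)
assert np.allclose(p_potts, p_attn)
```

The assertion holds by construction: with A = βJ and V = U, the pre-softmax field of the attention layer is term by term the effective field of the masked Potts spin.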
Decoupling positions and colors leads to a significant improvement in the performance of a single layer, which reaches the optimal generalization error and converges faster, cf. fig. 2a. Factored self-attention recovers the interaction matrix J perfectly, cf. fig. 2b for the attention map and fig. 2c for the reconstruction error of the interaction matrix. In fig. 2d, we show that completely decoupling the treatment of positions and colors (a = 0) performs better than any intermediate solution with a > 0.
Factored attention layers, and thus input-independent attention weights, have already been used as a building block for deep transformers, outperforming standard attention in different applications: Bhattacharya et al. [28] used them to analyze protein sequences and found that a single layer of factored self-attention performed as well as a deep transformer, and significantly better than a single layer of vanilla self-attention, without explicitly explaining this observation. Moreover, using factored attention is key to obtaining state-of-the-art results in approximating ground states of many-body quantum systems [29][30][31].
Intriguingly, the form of the loss for masked language modeling with factored self-attention as described above exactly matches the loss of the pseudo-likelihood method, which has been used for solving the inverse Ising problem [32][33][34][35][36]. The pseudo-likelihood method is statistically consistent [37][38][39], i.e. its parameter estimates converge to the true parameters as the number of samples goes to infinity. A direct consequence of the mapping in eq. (4) is thus that MLM with factored self-attention enjoys the same asymptotic optimality.
The sample complexity of self-attention. A key quantity in machine learning problems is the sample complexity, namely how many samples are required to achieve a small generalization loss ϵ_g with a given model. The mapping introduced in this work allows us to address this question precisely for a single layer of self-attention by means of the replica method from statistical physics. The main difficulty in the calculation lies in handling the non-trivial data distribution (1). This difficulty can be mitigated thanks to recent advances in statistical physics, which allow us to extend the replica method of disordered systems to structured data [18, 20, 40]. To perform the replica calculation, we first relax the discrete nature of Potts spins by rewriting the generalized Potts Hamiltonian (1) in terms of the spin magnetizations m = ⟨s⟩_{P(s)}, following mean-field theory. The associated Boltzmann measure then turns into a multivariate Gaussian distribution, whose covariance matrix is the negative inverse of the interaction matrix, i.e. Σ = −J^{−1} [41, 42]. We then draw sequences {m^μ}_{μ=1}^M of length L from the multivariate Gaussian with zero mean and covariance matrix

Σ = (νI + Ω)^{−1},

where Ω is a symmetric full-rank random matrix sampled from the Gaussian orthogonal ensemble, while the diagonal matrix νI centers the spectrum of the precision matrix at ν. To ensure that Σ is positive definite, we set ν > 2 in accordance with the semicircle law [43]. By fixing the location i of the masked spin across all input sequences, solving the MLM task is equivalent to inferring the ith row J_i of the interaction matrix.
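A possible numpy rendition of this construction follows; the identification of the precision matrix as νI + Ω (and hence J = −(νI + Ω)) is our reading of the text, and the sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
L, nu, M = 50, 3.0, 200

# GOE matrix normalized so that the semicircle spectrum lies in [-2, 2] for large L.
G = rng.standard_normal((L, L))
Omega = (G + G.T) / np.sqrt(2 * L)

# Precision matrix nu*I + Omega: spectrum of Omega shifted to be centered at nu,
# positive definite for nu > 2 by the semicircle law.
precision = nu * np.eye(L) + Omega
assert np.linalg.eigvalsh(precision).min() > 0

# Covariance of the Gaussian relaxation, Sigma = -J^{-1}.
Sigma = np.linalg.inv(precision)

# Draw M Gaussian "sequences" of magnetizations of length L.
m = rng.multivariate_normal(np.zeros(L), Sigma, size=M)
```

Each row of m plays the role of one relaxed sequence m^μ; masking a fixed coordinate i of every row then turns the MLM task into a linear inference problem for J_i.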
To accomplish this task, we train a single layer of factored self-attention by empirical risk minimisation of a square loss with ℓ2-regularisation,

Â_i = argmin_{A_i} Σ_{μ=1}^M ( m_i^μ − Σ_{j≠i} A_ij m_j^μ )^2 + λ ‖A_i‖_2^2.    (6)

Our goal is to characterise the generalization loss ϵ_g (3) of factored attention with the parameters obtained from minimising the loss eq. (6). In the high-dimensional limit, where the number of samples and the sequence length M, L tend to infinity while their ratio α ≡ M/L ∼ O(1), we can express ϵ_g using replica theory as a function of four scalar quantities,

ϵ_g = ϵ_g(ν, ρ, q⋆, r⋆).    (7)

Here, ν is the center of the spectrum of the precision matrix, ρ = tr Σ_\i / (ν^2 L) is a function of the covariance matrix Σ_\i, from which we have removed the ith row and column, while q⋆ and r⋆ are the so-called overlap parameters. They correspond to practically measurable quantities over different realisations of the training set, involving the estimator Â_i of the ith row of the interaction matrix,

r⋆ = Â_i^T Σ_\i A_i / (ν^2 L),   q⋆ = Â_i^T Σ_\i Â_i / (ν^2 L).

The parameter r⋆ can thus be interpreted as an overlap between the estimated attention Â_i and the ground-truth value A_i, while q⋆ is the self-overlap of the estimated attention, both mediated by the reduced covariance matrix Σ_\i. As we show in the supplementary material, the values of these order parameters for a given training set of size α can be obtained by solving a scalar extremization problem, which yields the typical value (over the data) of the free-energy density associated with a Gibbs measure at inverse temperature β whose Hamiltonian corresponds to the loss function in eq. (6). Note that the optimisation only involves the scalar parameters q, r, δq and their conjugates q̂, r̂ and δq̂, with δq = q_s − q, δq̂ = q̂ + q̂_s, and q_s being the self-overlap among replicas. The so-called entropic potential Ψ_s is a function of the input covariance Σ_\i, while the energetic potential Ψ_e only depends on the specific choice of the loss function; both are given explicitly in the SM for the optimization problem in eq. (6). Using this approach, we estimated analytically the generalization loss in Eq.
(7) as a function of the rescaled number of samples α = M/L (see supplementary material for details). The result is shown in the left panel of fig. 3. The test loss increases in the small-data regime, peaking at α = 1. This value corresponds to the interpolation threshold, the largest number of samples that the neural network can fit perfectly, which here occurs at M = L. Below this threshold, the model overfits its training data; beyond it, the generalization error decreases monotonically with the training set size; for large α, we find ϵ_g ∼ α^{−1/2}. A similar peak in the generalization loss has been observed in supervised learning [44], and it is connected to the well-known "double descent" curve observed in deep neural networks in the presence of label noise [45]. There, the peak is a consequence of overfitting induced by the noise in the labels, and it appears after an initial decay of the test loss at small α. In the self-supervised learning regime explored in this work, we find instead that the peak appears naturally as a consequence of the intrinsic stochasticity of the inputs. Indeed, in masked language modeling, the labels are part of the input itself, so the noise affecting the labels is highly correlated with that affecting the input. If the noise in the input is too high, the model starts overfitting immediately and the initial descent is not observed; the absence of the initial descent can therefore be ascribed to the high level of noise in the input.
We verify the predictions of the replica theory by plotting the generalization loss of a single layer of factored self-attention trained on the generalized Potts model in the setting of fig. 2 at small regularisation (right panel of fig. 3). We see the same qualitative behaviour as predicted by replica theory, even though in this case none of the assumptions required for the replica analysis apply (the mean-field limit, a full-rank J matrix, and a U matrix fixed to the identity). In particular, the test loss increases when adding more data at small α. The only difference between the plots is the location of the peak. For the square loss that we analyzed with replicas, as already noted, the peak is at the interpolation threshold M = L. For the simulations with the logistic loss, the peak appears at the linear separability threshold, the largest number of points a linear classifier can classify correctly, which corresponds to a value of α that can be larger than one [11, 18].
Concluding perspectives. In this work, we have characterised the probability distributions that a single layer of self-attention can learn when trained on a masked language modeling task, considered as a simple prototype of self-supervised learning. In particular, we have shown analytically and numerically that with a single factored-attention layer, it is possible to exactly reconstruct the couplings of a generalized Potts model with two-body interactions between both sites and colors. More precisely, we showed that training factored self-attention on the MLM objective is equivalent to solving the inverse Potts problem using the pseudo-likelihood method [32][33][34][35][36], and that it therefore yields consistent estimators of the parameters. These findings make factored attention a powerful, theoretically grounded building block for deep transformers. Our replica analysis of self-attention enabled us to compute the generalization loss of the model exactly and yielded a non-trivial generalization behaviour.
Learning higher-order interactions will require additional layers; a detailed study of how this can be achieved is an interesting direction for future research. It will also be interesting to study the learning dynamics of self-attention using methods from statistical physics [46][47][48][49], both on MLM and on supervised tasks [50][51][52][53]. In short, our work clarifies the limits of standard self-attention trained on data where two-body interactions dominate and highlights the potential of factored attention as a component of transformer models.

SUPPLEMENTAL MATERIAL

Appendix A: Numerical details
The numerical simulations were performed using JAX [54]. Both the factored attention layer and the vanilla transformer architecture were optimised using SGD with a mini-batch size of 100 and cosine annealing as the learning-rate decay schedule, both standard choices in the literature. The initial learning rate was adjusted to the specific simulation, choosing it between 0.1 and 0.01. The vanilla transformer code was taken from Ref. [55], with no modifications made. In particular, as already pointed out in the main text, each element s_i of a sequence s = (s_1, . . . , s_L) is first transformed into a token x_i = e_i + p_i, with e_i being the embedding of s_i and p_i its positional encoding. The tokenized sequences are then fed to a layer made of two distinct sub-layers. The first sub-layer consists of single-head attention, while the second contains a two-layer fully connected neural network. The inputs of both sub-layers are connected to their outputs through skip connections, and layer normalization is then applied.
Finally, there is an output layer consisting of a linear transformation from the d- to the C-dimensional space, in order to obtain a probability distribution over the colors through the softmax non-linearity. For a graphic visualization of the transformer encoder architecture, we refer to the original paper of Vaswani et al. [1]. Below is the list of transformer hyperparameters used for the simulations of fig. 2: embedding dimension 20, number of heads 1, number of layers 1–3, dropout probability 0.0, number of classes 20.
The dataset was generated using Gibbs sampling, starting from a random sequence of L = 20 sites and C = 20 Potts colors and cyclically resampling the spins by exploiting the knowledge of the exact conditional probabilities, eq. (5). In order to decorrelate the samples, 10000 Gibbs sweeps were performed between each pair of saved configurations.
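A minimal sketch of this sampling procedure follows; the sizes are deliberately small and illustrative (the actual simulations used L = C = 20 and 10000 decorrelation sweeps).

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def gibbs_sample(J, U, beta, n_sweeps, rng):
    """Cyclic Gibbs sampling of the generalized Potts model from its exact conditionals."""
    L, C = J.shape[0], U.shape[0]
    s = np.eye(C)[rng.integers(0, C, size=L)]          # random initial one-hot sequence
    for _ in range(n_sweeps):
        for i in range(L):
            # Effective field on spin i given all the others.
            field = beta * sum(J[i, j] * (U @ s[j]) for j in range(L) if j != i)
            c = rng.choice(C, p=softmax(field))        # resample s_i from P(s_i | s_\i)
            s[i] = np.eye(C)[c]
    return s

# Illustrative instance: sparse symmetric 0/1 couplings, standard Potts colors.
rng = np.random.default_rng(4)
L, C, beta = 8, 5, 1.0
J = np.triu((rng.random((L, L)) < 0.3).astype(float), 1)
J = J + J.T
U = np.eye(C)
sample = gibbs_sample(J, U, beta, n_sweeps=50, rng=rng)
```

Saving one configuration every many sweeps, as described above, yields approximately independent draws from the Boltzmann distribution of eq. (1).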
The simulations in the left panel of fig. 3 were performed by sampling the input data points from a multivariate Gaussian distribution, and the masked token from the same distribution conditioned on the other elements of the sequence. The optimization problem in eq. (6) is then solved in closed form via the Moore–Penrose inverse, as in Ref. [18].
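A sketch of this procedure, combining the conditional-Gaussian sampling of the masked token with the closed-form ridge solution of eq. (6); the sizes, the regularisation strength, and the νI + Ω convention for the precision matrix are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)
L, M, nu, lam = 30, 60, 3.0, 1e-3

# Assumed input covariance: inverse of the GOE-shifted precision matrix.
G = rng.standard_normal((L, L))
Omega = (G + G.T) / np.sqrt(2 * L)
Sigma = np.linalg.inv(nu * np.eye(L) + Omega)

i = 0                                      # masked position
idx = np.arange(L) != i
S_rr = Sigma[np.ix_(idx, idx)]             # covariance of the visible tokens
S_ri = Sigma[idx, i]

# Sample the visible tokens, then the masked token from the conditional Gaussian.
X = rng.multivariate_normal(np.zeros(L - 1), S_rr, size=M)
mu_cond = X @ np.linalg.solve(S_rr, S_ri)                     # E[m_i | m_\i]
var_cond = Sigma[i, i] - S_ri @ np.linalg.solve(S_rr, S_ri)   # Var[m_i | m_\i]
y = mu_cond + np.sqrt(var_cond) * rng.standard_normal(M)

# Closed-form solution of the regularized square loss, eq. (6).
A_hat = np.linalg.solve(X.T @ X + lam * np.eye(L - 1), X.T @ y)
```

In the ridgeless limit λ → 0 this reduces to the Moore–Penrose pseudo-inverse solution used in the figure.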
Appendix B: Replica approach to the generalization loss

Statistical physics considers learning as a dynamical, exploratory process across the space of learnable parameters. At equilibrium, these parameters are assumed to follow a Boltzmann–Gibbs distribution, in which the role of the Hamiltonian is played by the loss function,

P_β(A_i | D) = e^{−β L(A_i; D)} P_A(A_i) / Z_β,    (B1)

with β the inverse temperature, D the training set, and P_A a prior measure over the parameters. In the zero-temperature limit (i.e. β → ∞), the Boltzmann–Gibbs distribution concentrates around the minima of the loss function, which are precisely the solutions of the optimization problem in eq. (6),

Â_i = argmin_{A_i} L(A_i; D).    (B2)

Up to this point, reframing a machine learning problem in terms of statistical physics does not seem very advantageous, since sampling from a high-dimensional Boltzmann–Gibbs distribution is known to be impracticable. This is where replica theory comes into play. In particular, it states that, in the high-dimensional limit (i.e. M, L → ∞ with α ≡ M/L ∼ O(1)), the free energy of a learning system concentrates around its typical value over the input data distribution,

f_β = − lim_{L→∞} (1/(βL)) E_D [ln Z_β].    (B3)

As we will see in the next section, this expectation can be tackled by means of the replica trick. From this quantity, all the high-dimensional metrics of interest can be computed as functions of simple scalar quantities. This is for instance the case for the generalization loss in eq. (7).
In particular, the overlap parameters r⋆ and q⋆ correspond to practically measurable quantities over different realisations of the training set, involving the estimator Â_i of the ith row of the interaction matrix, r⋆ = Â_i^T Σ_\i A_i / (ν^2 L) and q⋆ = Â_i^T Σ_\i Â_i / (ν^2 L). In the next section we outline the main steps of the replica trick leading to the generalization-loss formula in eq. (7).

Replica Calculation
As anticipated in the previous section, the replica trick allows us to compute the typical value of the free-energy density in eq. (B3) by expressing it solely as a function of the replicated partition function Z_β^n, obtained by constructing n > 0 different, independent copies of the same learning system.

a. Average over the training set. As a first step, the replica calculation focuses on the expectation of the replicated partition function over the training set, with P_G and P_A being respectively the Gibbs and the Gaussian measure associated with the ith row of the attention matrix, as in eq. (B1). Indeed, as already pointed out in the main manuscript, the interaction matrix is drawn from the Gaussian orthogonal ensemble, so its rows correspond to Gaussian random vectors. At this point, we can notice an important aspect of MLM tasks: here, the labels are not provided by a teacher vector as in standard teacher-student settings. On the contrary, each masked token is directly sampled from the input distribution by conditioning on all the other elements composing the sequence. Note that the noise in the labels thus arises as a consequence of the noise already affecting the input data points, meaning that its intensity cannot be chosen independently of the intrinsic stochasticity of the input. With these considerations, we explicitly express the outer expectation in eq. (B6). To compute the expectation over the input sequences with a masked element at position i, we first define the pre-activations of each replica, λ_a^μ = Σ_{j≠i} A_ij^a m_j^μ, then express these definitions in terms of Dirac deltas and their corresponding integral representation, and plug the resulting factors of one into eq. (B8). The expectation over the masked input sequences is then a simple multivariate Gaussian integral, where, as already pointed out in the main text, Σ_\i corresponds to the covariance matrix of the
masked input sequences, that is, the input covariance matrix without the contribution of the row and column associated with the ith masked token. By replacing the solution of the expectation in eq. (B11), the replicas are seen to interact through a set of overlap parameters. Once again, to proceed further in the calculation, we insert their definitions by means of Dirac deltas and their integral representation. By substituting the overlap definitions in eq. (B13), plugging in the corresponding factors of one, and performing the change of variables iρ̂ → −ρ̂, ir̂_a → r̂_a and iq̂_ab → q̂_ab, we can rewrite the averaged replicated partition function in terms of saddle-point integrals over the overlap parameters,

E_D [Z_β^n] ∝ ∫ dρ dρ̂ Π_a (dr_a dr̂_a / 2π) Π_{a≤b} (dq_ab dq̂_ab / 2π) e^{L Ψ^(n)},    (B16)

where the action Ψ^(n) is a non-trivial function of the overlap parameters (B17), in which Ψ_s and Ψ_e are the so-called entropic and energetic potentials; in the specific case of single-layer factored attention they are given in eq. (B18). Note that we have dropped the dependency of Ψ_e on the index µ, since all µ-dependent terms decouple with respect to µ.
c. Replica symmetric assumption. To proceed further in the calculation, we need to assume a specific replica structure. Since all replicas have been introduced independently of each other, with no specific differences among them, it is natural to assume that they all play the same role and that, therefore, the overlap parameters should not depend on the specific replica index. In particular, under the replica-symmetric ansatz, we assume r_a = r and q_ab = q + (q_s − q) δ_ab, and similarly for the conjugate parameters. By plugging the replica-symmetric assumption into eqs. (B17)–(B18) and applying a Hubbard–Stratonovich transformation with ξ ∼ N(0, 1), we obtain the expression for the replica-symmetric action, where we have defined δq̂ = q̂_s + q̂ and δq = q_s − q; the replica-symmetric potentials Ψ_s and Ψ_e are given in eqs. (B21)–(B22).

d. Zero-replica limit. By taking the limit n → 0 in eqs. (B21)–(B22) and solving the integrals with respect to the ẑ and ĥ variables, we obtain the action in the zero-replica limit (B23), with the corresponding entropic and energetic potentials given in eq. (B24). Note that, as in standard teacher-student settings [18], in order to avoid divergent terms in this limit, the overlap ρ and its conjugate ρ̂ need to be constrained to E_{J_i}[J_i^T Σ_\i J_i]/ν^2 and 0, respectively.

e. Typical free-energy density. Having determined the expression for the replicated partition function, we can compute the typical free-energy density (B25). In the high-dimensional limit, we can solve the integrals over the overlap parameters by saddle point, thus getting

f_β = extr_{q,r,δq,q̂,r̂,δq̂} [ ν r r̂ + ½ (δq + q)(δq̂ − q̂) + ½ q q̂ + lim_{n→0} (1/n)(Ψ_s + α Ψ_e) ],    (B26)

where the values of the overlap parameters extremizing the typical free-energy density are determined by solving the corresponding system of coupled saddle-point equations (B27). Up to this point, we have performed the replica calculation in full generality, without specifying either the interaction matrix or the loss
function. In the next section, we will evaluate the typical free-energy density for the specific MLM task of eq. (6), under the simplifying assumptions of the section The sample complexity of self-attention of the main text.

f. Zero-temperature limit and Gaussian priors. As already pointed out in the main text, the interaction matrix J is sampled from the GOE; it is then natural to assume a Gaussian prior on the ith row of the attention matrix,

P_A(A_i) ∝ exp( −βλ ‖A_i‖_2^2 ),

with β the inverse temperature and λ the ℓ2-regularization strength. Moreover, the optimization problem in eq. (6) minimizes a square loss for the corresponding MLM task, so that the Gibbs measure of eq. (B1) associated with this task is

P_G(A_i | D) ∝ exp( −β Σ_{μ=1}^M ( m_i^μ − Σ_{j≠i} A_ij m_j^μ )^2 ).

By plugging these two specific forms of the prior and the Gibbs measure into eq. (B6) and taking the zero-temperature limit as exemplified in [18, 40], we get the following expression for the typical free-energy density in the zero-temperature limit:

f = lim_{β→∞} extr_{q,r,δq,q̂,r̂,δq̂} [ ν r r̂ − ½ (δq q̂ − q δq̂) + lim_{n→0} (1/n)(Ψ_s + α Ψ_e) ].

FIG. 2. A single layer of factored self-attention learns the generalized Potts model efficiently. (a) Test loss (3) for factored self-attention and for vanilla transformers with one and three layers during training with stochastic gradient descent. The optimal generalization loss is shown as a black dashed line. (b) Interaction matrix J of the generative Potts model (1) compared to the attention maps learned by transformers with vanilla and factored self-attention. For the three-layer transformer, the attention map was obtained by averaging the maps of the last two layers. (c) Reconstruction error of the interaction, (J − A)^2, as a function of the number of epochs for all considered architectures. (d) Test loss as a function of the perturbation level a. Decoupling the treatment of positions and colors by decreasing a decreases the test loss. Parameters: sequence length L = 20, vocabulary size C = 20, embedding dimension d = 20, M = 3000 data points.

FIG. 3. The interpolation peak of factored attention in theory and practice. (Left) A replica analysis predicts the test loss exactly: test loss of a single layer of factored self-attention as a function of the number of samples per input dimension, as computed using replica theory (solid line). The blue points represent the outcome of numerical minimisation of the square loss (6), averaged over 30 realisations, and show perfect agreement with the theory. Error bars are smaller than the point size. (Right) Same plot for a single layer of factored self-attention in the setting of fig. 2 (L = C = 20), showing the same qualitative behaviour. The simulations are averaged over n = 30 different realisations.