注意力究竟是什么?

向量的点积和矩阵表示

x\mathbf{x}y\mathbf{y} 表示维度为 nn 的向量,即:

x=[x1,x2,...,xn]y=[y1,y2,...,yn]\begin{aligned} \mathbf{x} &= [x_1, x_2, ..., x_n] \\ \mathbf{y} &= [y_1, y_2, ..., y_n] \end{aligned}

x\mathbf{x}y\mathbf{y} 的点积运算(\cdot)可以表示这两个向量之间的相似度,向量点积越大,表明两个向量越相似。

xy=i=1nxiyi=x1y1+x2y2+...+xnyn\mathbf{x} \cdot \mathbf{y} = \sum_{i=1}^{n}{x_iy_i} = x_1y_1 + x_2y_2 +...+ x_ny_n

如果用向量 x\mathbf{x} 表示一个词的词向量,对于 mm 个词向量而言,其任意两个词向量之间的相似度可以表示为:

xixj=k=1nxikyjk=xi1yj1+xi2yj2+...+xinyjni,j=1,2,,m\mathbf{x_i} \cdot \mathbf{x_j} = \sum_{k=1}^{n}{x_{ik}y_{jk}} = x_{i1}y_{j1} + x_{i2}y_{j2} + ... + x_{in}y_{jn} \quad \forall i,j = 1,2,\cdots,m

如果用 m×nm \times n 的矩阵 X\mathbf{X} 来表示 mm 个词向量,

X=[x1x2xm]=[x11x12x1nx21x22x2nxm1xm2xmn]\mathbf{X} = \begin{bmatrix} \mathbf{x_1} \\ \mathbf{x_2} \\ \vdots \\ \mathbf{x_m} \end{bmatrix} = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1n} \\ x_{21} & x_{22} & \cdots & x_{2n} \\ \vdots & \vdots & \vdots & \vdots \\ x_{m1} & x_{m2} & \cdots & x_{mn} \end{bmatrix}

S=XXT\mathbf{S} = \mathbf{X} \mathbf{X}^Tm×mm \times m 的矩阵,并且 sij=xixjs_{ij} = \mathbf{x_i} \cdot \mathbf{x_j} 表示第 ii 个词向量和第 jj 个词向量之间的相似度。

XXT=[x1x1x1x2x1xmx2x1x2x2x2xmxmx1xmx2xmxm]\mathbf{X}\mathbf{X}^T = \begin{bmatrix} \mathbf{x_1} \cdot \mathbf{x_1} & \mathbf{x_1} \cdot \mathbf{x_2} & \cdots & \mathbf{x_1} \cdot \mathbf{x_m} \\ \mathbf{x_2} \cdot \mathbf{x_1} & \mathbf{x_2} \cdot \mathbf{x_2} & \cdots & \mathbf{x_2} \cdot \mathbf{x_m} \\ \vdots & \vdots & \vdots & \vdots \\ \mathbf{x_m} \cdot \mathbf{x_1} & \mathbf{x_m} \cdot \mathbf{x_2} & \cdots & \mathbf{x_m} \cdot \mathbf{x_m} \end{bmatrix}

使用 softmax() 函数对矩阵 S\mathbf{S} 进行归一化处理得到矩阵 W=softmax(XXT)\mathbf{W} = \text{softmax}\left(\mathbf{X}\mathbf{X}^T\right),使得 W\mathbf{W} 中每一行的所有元素之和都为 1,于是每一行的各个元素就可以看作是一个权重。

wi1,wi2,,wimw_{i1},w_{i2},\cdots,w_{im} 分别表示第 ii 个词向量和第 1,2,,m1,2,\cdots,m 个词向量之间的相似度权重,用该权重分别乘以对应的词向量,于是我们得到了第 ii 个词的新的表达形式 zi\mathbf{z_i},某个词向量的权重越大,表示相似度越高。

zi=k=1mwikxk=wi1x1+wi2x2++wimxm\mathbf{z_i} = \sum_{k=1}^{m}{w_{ik}\mathbf{x_k}} = w_{i1}\mathbf{x_1} + w_{i2}\mathbf{x_2} + \cdots + w_{im}\mathbf{x_m}

对于 mm 个词向量都执行如上的操作可以得到:

Z=[z1z2zm]=[w11x1+w12x2++w1mxmw21x1+w22x2++w2mxmwm1x1+wm2x2++wmmxm]\mathbf{Z} = \begin{bmatrix} \mathbf{z_1} \\ \mathbf{z_2} \\ \vdots \\ \mathbf{z_m} \end{bmatrix} = \begin{bmatrix} w_{11}\mathbf{x_1} + w_{12}\mathbf{x_2} + \cdots + w_{1m}\mathbf{x_m} \\ w_{21}\mathbf{x_1} + w_{22}\mathbf{x_2} + \cdots + w_{2m}\mathbf{x_m} \\ \vdots \\ w_{m1}\mathbf{x_1} + w_{m2}\mathbf{x_2} + \cdots + w_{mm}\mathbf{x_m} \end{bmatrix}

实际上,可以证明:

Z=WX=softmax(XXT)X\mathbf{Z} = \mathbf{W}\mathbf{X} = \text{softmax}\left(\mathbf{X}\mathbf{X}^T\right)\mathbf{X}

于是,对于原始的词向量矩阵 X\mathbf{X} 而言,经过一些列的矩阵乘法运算,我们得到了根据相关性权重的加权词向量表示。也就是说原始的词向量 xi\mathbf{x_i} 可以表示为所有词向量的加权表示 zi\mathbf{z_i}。所以,可以用 Attention 来解释一句话中不同词之间的相互关系。Transform 模型中的 Attention 其实就是矩阵 Z\mathbf{Z},所以 Transform 可以理解输入序列中的不同的部分,并分析输入序列中不同词之间的关系,进而捕获到上下文信息。

了解如上介绍的 softmax(XXT)X\text{softmax}\left(\mathbf{X}\mathbf{X}^T\right)\mathbf{X} 背后的逻辑对于理解 Transform 中的各个矩阵的含义至关重要,因此我们花了很大的篇幅来对其进行分析。

接下来,我们用一个具体的例子来展示如上的过程。

举个例子🌰

I am good 这句话为例,我们用词向量 x1\mathbf{x_1} 表示 Ix2\mathbf{x_2} 表示 amx3\mathbf{x_3} 表示 good,对应的词向量分别为:

x1=[1,3,2]x2=[1,1,3]x3=[1,2,1]\begin{aligned} \mathbf{x_1} &= [1, 3, 2] \\ \mathbf{x_2} &= [1, 1, 3] \\ \mathbf{x_3} &= [1, 2, 1] \end{aligned}

所以,我们有矩阵 X\mathbf{X}

X=[132113121]\mathbf{X} = \begin{bmatrix} 1 & 3 & 2 \\ 1 & 1 & 3 \\ 1 & 2 & 1 \end{bmatrix}

于是,W\mathbf{W} 矩阵的计算过程如下:

权重矩阵 W\mathbf{W} 中某一行分别与 X\mathbf{X} 的一列相乘,如前所述,该操作相当于对 X\mathbf{X} 中各行向量(不同的词向量)加权求和,得到的结果是每个词向量 xi\mathbf{x_i} 经过加权求和之后的新表示 zi\mathbf{z_i},而权重矩阵则是经过相似度和归一化计算而得到的。

zI=0.97xI+0.02xam+0.01xgoodzam=0.27xI+0.73xam+0.00xgoodzgood=0.90xI+0.05xam+0.05xgood\begin{aligned} \mathbf{z_{I}} &= 0.97 \cdot \mathbf{x_{I}} + 0.02 \cdot \mathbf{x_{am}} + 0.01 \cdot \mathbf{x_{good}} \\ \mathbf{z_{am}} &= 0.27 \cdot \mathbf{x_{I}} + 0.73 \cdot \mathbf{x_{am}} + 0.00 \cdot \mathbf{x_{good}} \\ \mathbf{z_{good}} &= 0.90 \cdot \mathbf{x_{I}} + 0.05 \cdot \mathbf{x_{am}} + 0.05 \cdot \mathbf{x_{good}} \end{aligned}

R 实现过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
softmax <- function(mat) {
exp_mat <- exp(mat) # 对矩阵中的每个元素取指数
row_sums <- rowSums(exp_mat) # 计算每行的总和
softmax_mat <- sweep(exp_mat, 1, row_sums, FUN = "/") # 对每行进行归一化
return(softmax_mat)
}

X <- matrix(c(1, 3, 2, 1, 1, 3, 1, 2, 1), nrow = 3, byrow = TRUE)
W <- softmax(X %*% t(X))
Z <- W %*% X
print(Z)

#=================================================================

[,1] [,2] [,3]
[1,] 1 2.957691 2.011295
[2,] 1 1.540148 2.722573
[3,] 1 2.864164 2.000000

Transformer 中的 Attention

论文 Attention Is All You Need 的 3.2 节对 Attention 的描述如下[1]

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors.

The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

论文中也给出了 Attention 计算公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}

其中,Q\mathbf{Q}K\mathbf{K}V\mathbf{V} 都是矩阵,分别代表 QueryKeyValuedkd_kK\mathbf{K} 的行向量维度。 QueryKeyValue 都是为了计算 注意力 而引入的抽象的概念,它们都是对原始的输入 X\mathbf{X} 的线性变换。

Q=XWQK=XWKV=XWV\begin{aligned} \mathbf{Q} &= \mathbf{X}\mathbf{W}^Q \\ \mathbf{K} &= \mathbf{X}\mathbf{W}^K \\ \mathbf{V} &= \mathbf{X}\mathbf{W}^V \end{aligned}

因为 X\mathbf{X}m×nm \times n 的矩阵,所以如果令 WQ\mathbf{W}^QWK\mathbf{W}^KWV\mathbf{W}^V 均是 n×nn \times n 的单位矩阵 I\mathbf{I},那么 Q\mathbf{Q}K\mathbf{K}V\mathbf{V} 经过线性变换(I\mathbf{I})后仍然是 X\mathbf{X},此时我们可以用 X\mathbf{X} 替换 Q\mathbf{Q}K\mathbf{K}V\mathbf{V},那么公式就变成了:

Attention(Q,K,V)=softmax(XXTdk)X\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\text{softmax}\left(\frac{\mathbf{X}\mathbf{X}^T}{\sqrt{d_k}}\right)\mathbf{X}

这也就是为什么我们之前面说:Transform 模型中的 Attention 其实就是对原始输入词向量的加权求和而得到的新的表示,在新的表示中,Transform 可以理解输入序列中的不同的部分,并分析输入序列中不同词之间的关系,进而捕获到上下文信息。

实际上,为了增强增强模型的拟合能力,我们并不会采用单位矩阵 I\mathbf{I} 对矩阵 X\mathbf{X} 做线性变换,而是分别采用 WQ\mathbf{W}^QWK\mathbf{W}^KWV\mathbf{W}^V 这三个可以通过大量语料训练而学习到的参数矩阵(参数矩阵可以是任何维度,但行向量个数必须和 X\mathbf{X} 的行向量维度一致)。所以,根据 Attention(Q,K,V)\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V}),Transform 可以理解一句话中不同词之间的相互关系。

Q,K,V\mathbf{Q},\mathbf{K},\mathbf{V} 和原始输入之间的关系如下图所示:

参考文献


  1. Attention Is All You Need ↩︎

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2020-2024 wangwei
  • 本站访问人数: | 本站浏览次数:

请我喝杯咖啡吧~

微信