Skip to content

Latest commit

 

History

History
99 lines (53 loc) · 4.34 KB

阿里RE2-将残差连接和文本匹配模型融合.md

File metadata and controls

99 lines (53 loc) · 4.34 KB

RE2 这个名称来源于该网络三个重要部分的合体:Residual vectors;Embedding vectors;Encoded vectors;

掌握这个论文,最重要的一个细节点就是了解如何将增强残差连接融入到模型之中

1.架构图

先来看架构图,如下:

这个架构图很精简,所以不太容易理解。

大体上区分可以分为三层。第一层就是输入层,第二个就是中间处理层,第三个就是输出层。

中间处理层我们可以称之为block,就是画虚线的部分,可以被循环为n次,但是需要注意的是每个block不是共享的,参数是不同的,是独立的,这点需要注意。

2.增强残差连接

其实这个论文比较有意思的点就是增强残差连接这里。架构图在这里其实很精简,容易看糊涂,要理解还是要看代码和公式。

2.1 第一个残差

首先假设我们的句子长度为$l$,然后对于第n个block(就是第n个虚线框的部分)。

它的输入和输出分别是:$x^{(n)}=(x_{1}^{(n)},x_{2}^{(n)},...,x_{l}^{(n)})$ 和$o^{(n)}=(o_{1}^{(n)},o_{2}^{(n)},...,o_{l}^{(n)})$;

首先对一第一个block,也就是$x^{(1)}$,它的输入是embedding层,注意这里仅仅是embedding层;

对于第二个block,也就是$x^{(2)}$,它的输入是embedding层(就是初始的embedding层)和第一个block的输出$o^{(1)}$拼接在一起;

紧接着对于n大于2的情况下,也就是对于第三个,第四个等等的block,它的输入形式是这样的;

理解的重点在这里:在每个block的输入,大体可以分为两个部分,第一部分就是初始的embedding层,这个永远不变,第二个部分就是此时block之前的两层的blocks的输出和;这两个部分进行拼接。

这是第一个体现残差的部分。

2.2第二个残差

第二个残差的部分在block内部:

alignment层之前的输入就有三个部分:第一部分就是embedding,第二部分就是前两层的输出,第三部分就是encoder的输出。

这点结合着图就很好理解了。

3.Alignment Layer

attention这里其实操作比较常规,和ESIM很类似,大家可以去看之前这个文章。

公式大概如下:

这里有一个细节点需要注意,在源码中计算softmax之前,也是做了类似TRM中的缩放,也就是参数,放个代码:

#核心代码
def __init__(self, args, __):
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(args.hidden_size)))

def _attention(self, a, b):
        return torch.matmul(a, b.transpose(1, 2)) * self.temperature

4.Fusion Layer

融合层,就是对attentino之前和之后的特征进行一个融合,具体如下:

三种融合方式分别是直接拼接,算了对位减法然后拼接,算了对位乘法然后拼接。最后是对三个融合结果进行拼接。

有一个很有意思的点,作者说到减法强调了两句话的不同,而乘法强调了两句话相同的地方。

5.Prediction Layer

Pooling层之后两个句子分别得到向量表达:$v_{1}$和$v_{2}$

三个表达方式,各取所需就可以:

6. 总结

简单总结一下,这个论文最主要就是掌握残差连接。

残差体现在模型两个地方,一个是block外,一个是block内;

对于block,需要了解的是,每一个block的输入是有两部分拼接而成,一个是最初始的embeddding,一个是之前两层的输出和。

对于block内,需要注意的是Alignment之前,有三个部分的输入一个是最初始的embeddding,一个是之前两层的输出和,还有一个是encoder的输出。