Transformer and BERT#

In 2018, The New York Times quoted Oren Etzioni, chief executive of the Allen Institute for Artificial Intelligence,
saying that while machines still cannot express the ordinary common sense of a human, BERT marks a moment of explosive progress. This post looks at the Transformer, the encoder-decoder architecture built on the attention mechanism that BERT is based on.

We also look at The Annotated Transformer, which the BERT paper itself cites as the basis of its implementation.

Attention Mechanism#

์–ดํ…์…˜์€ ํŠน์ • ์‹œํ€€์Šค๋ฅผ ์ถœ๋ ฅํ•˜๊ธฐ ์œ„ํ•ด ์ž…๋ ฅ ์‹œํ€€์Šค์˜ ์–ด๋– ํ•œ ๋ถ€๋ถ„์„ ๊ฐ•์กฐํ•ด์•ผ ๋ ๋Š”์ง€ ํ•™์Šต์„ ํ•  ์ˆ˜ ์žˆ๋Š” ๋งค์ปค๋‹ˆ์ฆ˜์ด๋‹ค.

์–ดํƒ ์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์€ ๊ฐ„๋‹จํžˆ ๋งํ•ด์„œ ํŠน์ • ๋‹จ์–ด๋ฅผ ๊ฐ•์กฐํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์ž…๋ ฅ ์‹œํ€€์Šค ์ค‘์—์„œ ํŠน์ • ๋‹จ์–ด์™€ ๋‹ค๋ฅธ ๋‹จ์–ด๊ฐ€ ์‹œํ€€์Šค์—์„œ ์ถœํ˜„์‹œ ๊ฐ•์กฐ๋˜๋Š” ๊ฒƒ์ด๋ฉฐ, ๊ทธ๋Ÿฌํ•œ ๊ฐ•์กฐ ์ •๋ณด๊ฐ€ ์ž…๋ ฅ ์‹œํ€€์Šค์— ์ ์šฉ๋˜์–ด์„œ ๋””์ฝ”๋”์— ์ž…๋ ฅ๋˜๊ณ , ๋งค ๋””์ฝ”๋” ์‹œํ€€์Šค๋งˆ๋‹ค ์ด๋Ÿฌํ•œ ๊ณ„์‚ฐ์ด ์ง„ํ–‰๋˜๋ฉฐ ์ˆ˜๋งŽ์€ ๋ฌธ์žฅ์ด ํ•™์Šต๋˜๋ฉด์„œ ์ธ์ฝ”๋” ๋””์ฝ”๋”์— ์ž…๋ ฅ๋˜๋Š” ๋‹จ์–ด๋“ค์˜ ์ƒํ˜ธ๊ฐ„์˜ ์ปจํ…์ŠคํŠธ๊ฐ€ ํ•™์Šต์ด ๋˜๋Š” ๊ตฌ์กฐ์ด๋‹ค.

$$score(s_t, h_i) = s_t^T h_i \\ e^t = [s_t^T h_1, \ldots, s_t^T h_N] \\ \alpha^t = softmax(e^t) \\ c_t = \sum_{i=1}^{N} \alpha_i^t h_i$$

์› ๋…ผ๋ฌธ์—์„œ๋Š” t๋ฅผ ํ˜„์žฌ์‹œ์ ์ด๋ผ๊ณ  ํ•  ๋•Œ, ์ธ์ฝ”๋” ์ถœ๋ ฅ๋ฒกํ„ฐ(s)์™€ ์€๋‹‰ ์ƒํƒœ ๋ฒกํ„ฐ(h)๋ฅผ ๋‚ด์ ํ•œ ํ›„์— ์†Œํ”„ํŠธ๋งฅ์Šค(softmax)๋ฅผ ํ•œ๋‹ค๋ฉด ์ด๋ฅผ ์–ดํ…์…˜ ๋ถ„ํฌ(attention distribution), ๊ฐ๊ฐ์˜ ๊ฐ’์„ ์–ดํ…์…˜ ๊ฐ€์ค‘์น˜(attention weight)๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ฐ€์ค‘์น˜๋ฅผ ๋ชจ๋‘ ๋”ํ•œ๋‹ค๋ฉด ์ตœ์ข… ์ถœ๋ ฅ ์–ดํ…์…˜ ๊ฐ’(attention value)์ด์ž ๋ฌธ๋งฅ ๋ฒกํ„ฐ(context vector)๋ผ๊ณ  ์ •์˜ํ•œ๋‹ค. ๊ทธ ํ›„ ์‹ค์ œ ์˜ˆ์ธก์„ ์œ„ํ•ด ์–ดํ…์…˜ ๋ฒกํ„ฐ์™€ ์ธ์ฝ”๋” ์ถœ๋ ฅ๋ฒกํ„ฐ๋ฅผ ๊ฒฐํ•ฉ(concatenate)์‹œ์ผœ ์˜ˆ์ธกํ•œ๋‹ค.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

Transformer#

ํŠธ๋žœ์Šคํฌ๋จธ๋Š” ์ „๋ฐ˜์ ์ธ ์‹œํ€€์Šค ์ „๋‹ฌ ๋ชจ๋ธ์€ ์ธ์ฝ”๋”์™€ ๋””์ฝ”๋”๋ฅผ ํฌํ•จํ•˜๋Š” ๋ณต์žกํ•œ ์ˆœํ™˜(recurrent) ๋˜๋Š” ํ•ฉ์„ฑ๊ณฑ(convolution)์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•œ๋‹ค. ์ตœ๊ณ ์˜ ์„ฑ๋Šฅ์˜ ๋ชจ๋ธ๋“ค์€ ๋˜ํ•œ ์–ดํ…์…˜(attention) ์„ ํ†ตํ•ด ์ธ์ฝ”๋”์™€ ๋””์ฝ”๋”๋ฅผ ์—ฐ๊ฒฐํ•˜๋Š”๋ฐ, ์ƒˆ๋กœ์šด ๋‹จ์ˆœํ•œ ๋„คํŠธ์›Œํฌ ์•„ํ‚คํ…์ฒ˜์ธ ํŠธ๋žœ์Šคํฌ๋จธ๋ฅผ ์˜ค๋กœ์ง€ ์–ดํ…์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์— ๊ธฐ์ดˆํ•˜๊ณ , recurrent์™€ convolution๋ฅผ ์‚ฌ์šฉํ•˜์ง€ ์•Š๋Š”๋‹ค. ๋…ผ๋ฌธ์—์„œ๋Š” 2๊ฐ€์ง€ ๋ฒˆ์—ญ ์ž‘์—… ์‹คํ—˜์—์„œ ๋ชจ๋ธ์€ ๋ณ‘๋ ฌ์ ์ด๊ณ  ๋™์‹œ์— ํ’ˆ์งˆ์ด ์šฐ์ˆ˜ํ•˜๋‹ค๋Š” ๊ฒƒ์„ ๋ณด์—ฌ์คฌ๋‹ค.

Positional Encoding#

ํ•ด๋‹น ๋ชจ๋ธ์—์„œ๋Š” ์ˆœํ™˜(recurrence)์ด๋‚˜ ํ•ฉ์„ฑ๊ณฑ(convolution)์„ ์ „ํ˜€ ์‚ฌ์šฉํ•˜์ง€ ์•Š์•˜๊ธฐ ๋•Œ๋ฌธ์—, ๋ฐ˜๋“œ์‹œ ์œ„์น˜ ์ •๋ณด๋ฅผ ๋„ฃ์–ด์ค˜์•ผ ํ•œ๋‹ค. ๋”ฐ๋ผ์„œ positional encoding์„ ์‚ฌ์šฉํ•ด์„œ ์ž…๋ ฅ ์ž„๋ฒ ๋”ฉ์— ์œ„์น˜ ์ •๋ณด๋ฅผ ๋„ฃ์–ด์ค€๋‹ค. ๊ฐ ์œ„์น˜์— ๋Œ€ํ•ด์„œ ์ž„๋ฒ ๋”ฉ๊ณผ ๋™์ผํ•œ ์ฐจ์›์„ ๊ฐ€์ง€๋„๋ก ์ธ์ฝ”๋”ฉ์„ ํ•ด์ค€ ๋’ค ๊ทธ ๊ฐ’์„ ์ž„๋ฒ ๋”ฉ๊ฐ’๊ณผ ๋”ํ•ด์„œ ์‚ฌ์šฉํ•œ๋‹ค.

positional encoding์—๋Š” ์—ฌ๋Ÿฌ ๋ฐฉ๋ฒ•์ด ์žˆ์ง€๋งŒ ์—ฌ๊ธฐ์„œ๋Š” ์ž…๋ ฅ ๋ฌธ์žฅ๊ธธ์ด์— ๋Œ€ํ•œ ์ œ์•ฝ์‚ฌํ•ญ์„ ์ค„์ด๋ฆฌ ์œ„ํ•ด sin, cos ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์„œ ์ •ํ˜•ํŒŒ๋กœ ๊ตฌํ˜„ํ•ด์„œ ์‚ฌ์šฉํ–ˆ๋‹ค. ๊ฐ ์œ„์น˜ pos์™€ dimension i์— ๋Œ€ํ•œ positional encoding๊ฐ’์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ตฌํ•œ๋‹ค.

$$PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$$

class PositionalEncoding(nn.Module):
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        return self.dropout(x)
plt.figure(figsize=(15, 5))
pe = PositionalEncoding(20, 0)
y = pe.forward(Variable(torch.zeros(1, 100, 20)))
plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())
plt.legend(["dim %d"%p for p in [4,5,6,7]])
None
[Figure: sinusoidal positional encoding values for dimensions 4-7 over positions 0-100]

Positional Encoding in BERT#

์ด์— ๋ฐ˜ํ•ด bert์—์„œ๋Š” ์ž…๋ ฅ token embedding๊ณผ Position Embeddings, Segment Embeddings, ์ถ”๊ฐ€ํ•ด ๊ฐ๊ฐ์˜ ์ž„๋ฒ ๋”ฉ, ์ฆ‰ 3๊ฐœ์˜ ์ž„๋ฒ ๋”ฉ์„ ํ•ฉ์‚ฐํ•˜์—ฌ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. ์—ฌ๊ธฐ์„œ Segment๋Š” Bert์˜ ํŠน์„ฑ์ƒ 2๊ฐœ์˜ ๋ฌธ์žฅ์„ ์‚ฌ์šฉํ•˜๋Š”๋ฐ ๊ทธ ๋ฌธ์žฅ์„ ๊ตฌ๋ถ„ํ•˜๋Š” ๋ฒกํ„ฐ์ด๋‹ค.

Encoder & Decoder#

์ถœ๋ ฅ์€ ๋ชจ๋‘ 512์ฐจ์›

Encoder#

์ธ์ฝ”๋”๋Š” ๋™์ผํ•œ ๊ณ„์ธต(layer)๊ฐ€ N๊ฐœ ๋ฐ˜๋ณต๋˜๋Š” ํ˜•ํƒœ์ธ๋ฐ, ์› ๋…ผ๋ฌธ์—์„œ๋Š” 6๋ฒˆ ๋ฐ˜๋ณตํ–ˆ๋‹ค. Encoder๋Š” ๊ณ„์ธต์€ ๋‘๊ฐœ์˜ ํ•˜์œ„ ๊ณ„์ธต(sub-layer)๋กœ ๊ตฌ์„ฑ๋œ๋‹ค. ์ฒซ ํ•˜์œ„ ๊ณ„์ธต์€ ๋ฉ€ํ‹ฐํ—ค๋“œ(multi-head) ์ž๊ฐ€ ์–ดํ…์…˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜(self-attention mechanism)์ด๊ณ  ๋‘๋ฒˆ์งธ๋Š” ๊ฐ„๋‹จํ•˜๊ฒŒ ์ ๋ณ„์ˆ˜๋ ด(point-wise)ํ•˜๋Š” ์™„์ „์—ฐ๊ฒฐ์ธต(fc-layer)์ด. ๊ทธ๋ฆฌ๊ณ  ๋ชจ๋ธ ์ „์ฒด์ ์œผ๋กœ ๊ฐ ํ•˜์œ„ ๊ณ„์ธต์— RC(residual connection)๊ฐ€ ์ „๋‹ฌ๋˜๊ณ , ์ด๋Š” ์—ญ์ „ํŒŒ๊ฐ€ ๊ณ„์‚ฐ๋˜์–ด ๊ฒฝ์‚ฌ ํ•˜๊ฐ•์ด ๋  ๋•Œ ์›๋ณธ ๊ฐ’์„ ๋”ํ•œํ›„์— ์˜ค์ฐจ(Loss)๊ฐ€ ๊ณ„์‚ฐ๋œ๋‹ค. ๊ทธ ํ›„ ๊ณ„์ธต ๊ฐ’์„ ๋ ˆ์ด์–ด ์ •๊ทœํ™”(Layer Normalization)ํ•œ๋‹ค. ์ฆ‰ ๊ฐ ํ•˜์œ„ ๊ณ„์ธต์€ ๊ฒฐ๊ณผ์— ๋Œ€ํ•ด ์ž”์ฐจ ๊ฐ’์„ ๋”ํ•˜๊ณ  ๊ทธ ๊ฐ’์„ ๋ ˆ์ด์–ด ์ •๊ทœํ™” ํ•œ ๊ฐ’์ด ์ถœ๋ ฅ์œผ๋กœ ๋‚˜์˜ค๊ฒŒ ๋œ๋‹ค.

Decoder#

๋””์ฝ”๋”๋„ ์ธ์ฝ”๋”์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ๋™์ผํ•œ ๊ณ„์ธต์ด N๊ฐœ ๋ฐ˜๋ณต๋˜๋Š” ํ˜•ํƒœ์ด๊ณ  6๋ฒˆ ๋ฐ˜๋ณตํ•œ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ๋ฐ˜๋ณต๋˜๋Š” ๊ณ„์ธต์€ ์ธ์ฝ”๋”์™€๋Š” ๋‹ค๋ฅด๊ฒŒ 3๊ฐœ์˜ ํ•˜์œ„ ๊ณ„์ธต์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋Š”๋ฐ, 2๊ฐœ๋Š” ๊ธฐ์กด์˜ ์ธ์ฝ”๋”์˜ ํ•˜์œ„ ๊ณ„์ธต๊ณผ ๋™์ผํ•˜๊ณ  ๋‚˜๋จธ์ง€ ํ•˜๋‚˜๋Š” ์ธ์ฝ”๋”์˜ ์ถœ๋ ฅ์— ๋Œ€ํ•ด ๋ฉ€ํ‹ฐํ—ค๋“œ ์–ดํ…์…˜์„ ๊ณ„์‚ฐํ•˜๋Š” ํ•˜์œ„ ๊ณ„์ธต์ด ์ถ”๊ฐ€๋ฌ๊ณ  RC์™€ ์ •๊ทœํ™”๊ฐ€ ์ด๋ฃจ์–ด์ง„๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ž๊ฐ€ ์–ดํ…์…˜์€ ์ธ์ฝ”๋”์™€๋Š” ์•ฝ๊ฐ„ ๋‹ค๋ฅด๊ฒŒ ๋งˆ์Šคํ‚น(masking)์„ ์ถ”๊ฐ€ํ–ˆ๋Š”๋ฐ, ์ž๊ฐ€ ์–ดํ…์…˜์„ ํ•  ๋•Œ ํ˜„์žฌ ์œ„์น˜๋ณด๋‹ค ๋’ค์— ์žˆ๋Š” ๋‹จ์–ด๋Š” ๋ณ€ํ•˜์ง€ ๋ชปํ•˜๋„๋ก ๋งˆ์Šคํ‚น์„ ์ถ”๊ฐ€ํ•ด์คฌ๋‹ค. ๋‹ค๋ฅธ์œ„์น˜์˜ ๋‹จ์–ด๋Š” auto-regressiveํ•œ ํŠน์„ฑ์„ ์ด์šฉํ•ด ์•Œ๊ณ  ์žˆ๋Š” ์ •๋ณด๋กœ๋งŒ ๊ณ„์‚ฐํ•œ๋‹ค.

class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
 
    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)    

RC (Residual Connection)#

๊ฐ ์ธ์ฝ”๋”์™€ ๋””์ฝ”๋”๋Š” residual connection

residual connection์„ ์ˆ˜์‹์œผ๋กœ ๋‚˜ํƒ€๋‚ธ๋‹ค๋ฉด $y_l=h(x_l)+F(x_l,W_l )$
์—ฌ๊ธฐ์„œ $f(y_l)$๋Š” ${x}_{l+1}$์˜ ํ•ญ๋“ฑํ•จ์ˆ˜๊ณ  $h(x_l )$๋Š” $x_l$ ๋กœ ๋งตํ•‘๋œ๋‹ค.

์ด ๋•Œ, $x_{(l+1)}$ โ‰ก $y_l$ ๋ผ๊ณ ํ•œ๋‹ค๋ฉด, $x_{(l+1)}=x_l+F(x_l,W_l )$ ์ด๊ณ 
์žฌ๊ท€์ ์œผ๋กœ $(x_{(l+2)}=x_{(l+1)}+F(x_{(l+1)},W_{(l+1)}) =x_l+ F(x_l, W_l)+F(x_{(l+1)},W_{(l+1)}), etc.).$

$$x_L = x_l + \sum\limits_{i=l}^{L-1} F(x_i, W_i)$$ Differentiating this expression gives $\frac{\partial \varepsilon}{\partial x_l} = \frac{\partial \varepsilon}{\partial x_L} \frac{\partial x_L}{\partial x_l} = \frac{\partial \varepsilon}{\partial x_L} \left(1 + \frac{\partial}{\partial x_l} \sum\limits_{i=l}^{L-1} F(x_i, W_i)\right)$

Here $\frac{\partial \varepsilon}{\partial x_L}$ is propagated directly to every layer, and because it is rare for the summed gradient of $F$ to cancel the $1$ in the parentheses, the gradient $\frac{\partial \varepsilon}{\partial x_l}$ hardly ever vanishes even when the weights are very small.

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

Scaled Dot-Product Attention#

ํ•ด๋‹น ์–ดํ…์…˜์˜ ์ž…๋ ฅ์€ 3๊ฐ€์ง€์ด๋‹ค. D๊ฐœ ์ฐจ์›์„ ๊ฐ€์ง€๋Š” queries(Q)์™€ keys(K), values(V)๋กœ ๊ตฌ์„ฑ๋œ๋‹ค. ๋จผ์ € Q๋Š” ์ฃผ๋กœ ๋””์ฝ”๋”์˜ ์€๋‹‰ ์ƒํƒœ ๋ฒกํ„ฐ, K๋Š” ์ธ์ฝ”๋”์˜ ์€๋‹‰ ์ƒํƒœ ๋ฒกํ„ฐ, V๋Š” K์— ์ •๋ ฌ ๋ชจ๋ธ(alignment model)๋กœ ๊ณ„์‚ฐ๋œ ์–ดํ…์…˜ ๊ฐ€์ค‘์น˜์ด๋‹ค.

  • Query: the query is the representation of the current word, used to score it against all the other words (via their keys). We only care about the query of the token currently being processed.

  • Key: key vectors are like labels for all the words in the segment. They are what we match against when searching for relevant words.

  • Value: value vectors are the actual word representations. Once we have scored how relevant each word is, the values are what we add up, weighted by those scores, to represent the current word.

To use an analogy: if the query is a sticky note with the topic you are researching, and the keys are the tags on the folders inside a filing cabinet, then matching the sticky note against a tag gives you the folder's contents, which are the values. Rather than looking up a single value, attention ends up with a blended combination of values from every folder: each folder's score is the dot product of the query vector with that folder's key vector.
In other words, a query is dot-multiplied with every key, each score is scaled by dividing by $\sqrt{d_k}$ (the square root of the key dimension), a softmax is applied, and the result is finally multiplied by the values. $$Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Multi-Head Attention#

Queries, keys, and values are each put through a different learned linear projection, and this is done h times; that is, the same Q, K, V are multiplied by h different weight matrices W. Attention is then computed on each projection and the results are concatenated.

$$MultiHead(Q,K,V) = Concat(head_1, \ldots, head_h)W^O \\ \text{where } head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$$

์–ดํ…์…˜ ๋ ˆ์ด์–ด๊ฐ€ h๊ฐœ ์”ฉ์œผ๋กœ ๋‚˜๋ˆ ์ง์— ๋”ฐ๋ผ ๋ชจ๋ธ์€ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ํ‘œํ˜„ ๊ณต๊ฐ„(representation subspaces)๋“ค์„ ๊ฐ€์ง€๊ฒŒ ํ•ด์ฃผ๋Š”๋ฐ, Query, key, Value weight ํ–‰๋ ฌ๋“ค์€ ํ•™์Šต์ด ๋œ ํ›„ ๊ฐ๊ฐ์˜ ์ž…๋ ฅ๋ฒกํ„ฐ๋“ค์—๊ฒŒ ๊ณฑํ•ด์ ธ ๋ฒกํ„ฐ๋“ค์„๋‹จ์–ด์˜ ์ •๋ณด์— ๋งž์ถ”์–ด ํˆฌ์˜์‹œํ‚ค๊ฒŒ ๋œ๋‹ค.

Position-wise Feed-Forward Networks#

์–ดํ…์…˜ ํ•˜์œ„ ๊ณ„์ธต์—์„œ fully connected feed-forward network๋กœ ์ง„ํ–‰ํ•˜๋Š” ๊ณผ์ •์ด๊ณ  ๋‘๊ฐœ์˜ ์„ ํ˜• ํšŒ๊ท€์œผ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ๋‹ค.

$$FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

๋‘ ๋ ˆ์ด์–ด ์‚ฌ์ด์— Trasformer๋Š” ReLU ํ•จ์ˆ˜๋ฅผ Bert๋Š” erf off GELU๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค.

def gelu(x):
    # GELU(x) = x * cdf(x), where cdf is the standard normal CDF computed with erf
    cdf = 0.5 * (1.0 + torch.erf(x / np.sqrt(2.0)))
    return x * cdf
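Putting the two linear layers together, the position-wise FFN from The Annotated Transformer looks like this (with ReLU; a BERT-style variant would call the gelu above instead):

class PositionwiseFeedForward(nn.Module):
    "Implements FFN(x) = max(0, xW1 + b1)W2 + b2."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))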

Why Self-Attention#

์› ๋…ผ๋ฌธ์—์„œ๋Š” ์ด ๋ชจ๋ธ์—์„œ ์ˆœํ™˜๋‚˜ ํ•ฉ์„ฑ๊ณฑ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š๊ณ  ์ž๊ฐ€ ์–ดํƒ ์…˜(self-attention)๋งŒ์„ ์‚ฌ์šฉํ•œ ์ด์œ ์— ๋Œ€ํ•ด์„œ ์•Œ์•„๋ณด๋ฉด, 3๊ฐ€์ง€ ์ด์œ ๋กœ ์ž๊ฐ€ ์–ดํƒ ์…˜์„ ์„ ํƒํ–ˆ๋‹ค

  1. ๋ ˆ์ด์–ด๋‹น ์ „์ฒด ์—ฐ์‚ฐ๋Ÿ‰์ด ์ค„์–ด๋“ ๋‹ค(์‹œ๊ฐ„๋ณต์žก๋„).

  2. More of the computation can be parallelized.

  3. It shortens the paths between words that are far apart, helping with long-range (long-term) dependencies.

๊ทธ๋ฆฌ๊ณ  ์œ„์˜ 3๊ฐ€์ง€ ์™ธ์— ๋˜ ๋‹ค๋ฅธ ์ด์œ ๋Š” ์–ดํƒ ์…˜์„ ์‚ฌ์šฉํ•˜๋ฉด ๋ชจ๋ธ ์ž์ฒด์˜ ๋™์ž‘์„ ํ•ด์„ํ•˜๊ธฐ ์‰ฌ์›Œ์ง„๋‹ค(interpretable). ์–ดํƒ ์…˜ ํ•˜๋‚˜์˜ ๋™์ž‘ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ multi-head์˜ ๋™์ž‘ ๋˜ํ•œ ์–ด๋–ป๊ฒŒ ๋™์ž‘ํ•˜๋Š”์ง€ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๋‹ค๋Š” ์žฅ์ ์ด ์žˆ๋‹ค.

Reference#

โ€ขTransformer
โ€ขBert
โ€ขAttention
โ€ขResidual Connection
โ€ขLayer Normalization
โ€ขLabel Smoothing

โ€ขThe Illustrated Transformer
โ€ขThe Illustrated GPT-2 (Visualizing Transformer Language Models) โ€ขhttps://pozalabs.github.io/transformer/
โ€ขhttp://freesearch.pe.kr/archives/4876#easy-footnote-bottom-2-4876
โ€ขhttps://wikidocs.net/22893
โ€ขhttp://docs.likejazz.com/bert/

BERT extension#

Bidirectional#

์ €์ž๋Š” unidirectional์€ token-level์—์„œ ๋‹จ์ ์ด ๋œ๋‹ค๊ณ  ํ•œ๋‹ค. GPT ๊ฐ™์€ unidirectional ํ•œ ๋ชจ๋ธ, ์ฆ‰ ๋ชจ๋“  ํ† ํฐ์€ ์ด์ „ ํ† ํฐ๋งŒ ์ฐธ๊ณ ํ•  ์ˆ˜ ์žˆ๋Š” auto-regressive ๋ชจ๋ธ๊ณผ๋Š” ๋‹ฌ๋ฆฌ BERT์˜ ๊ฒฝ์šฐ ์ด์ „๊ณผ ์ดํ›„์˜ ์ •๋ณด๋ฅผ ๋ชจ๋‘ ํ™œ์šฉํ•œ๋‹ค. ๊ณ„์‚ฐ๋Ÿ‰๋„ ๋ฌผ๋ก  2๋ฐฐ.

Masked Language Model#

์€ ๋ฌธ์žฅ์˜ ๋‹ค์Œ ๋‹จ์–ด๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ ๋ฌธ์žฅ๋‚ด ๋žœ๋คํ•œ ๋‹จ์–ด๋ฅผ ๋งˆ์Šคํ‚นํ•˜๊ณ  ์ด๋ฅผ ์˜ˆ์ธกํ•˜๋„๋ก ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ Word2Vec์˜ CBOW ๋ชจ๋ธ๊ณผ ์œ ์‚ฌํ•˜๋‹ค. ํ•˜์ง€๋งŒ MLM์€ Context ํ† ํฐ์„ Center ํ† ํฐ์ด ๋˜๋„๋ก ํ•™์Šตํ•˜๊ณ  Weights๋ฅผ ๋ฒกํ„ฐ๋กœ ๊ฐ–๋Š” CBOW์™€ ๋‹ฌ๋ฆฌ, ๋งˆ์Šคํ‚น๋œ ํ† ํฐ์„ ๋งž์ถ”๋„๋ก ํ•™์Šตํ•œ ๊ฒฐ๊ณผ๋ฅผ ์ง์ ‘ ๋ฒกํ„ฐ๋กœ ๊ฐ–๊ธฐ ๋•Œ๋ฌธ์— ๋ณด๋‹ค ์ง๊ด€์ ์ธ ๋ฐฉ์‹์œผ๋กœ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ๋งˆ์Šคํ‚น์€ ์ „์ฒด ๋‹จ์–ด์˜ 15% ์ •๋„๋งŒ ์ง„ํ–‰ํ•˜๋ฉฐ, ๊ทธ ์ค‘์—์„œ๋„ ๋ชจ๋“  ํ† ํฐ์„ ๋งˆ์Šคํ‚น ํ•˜๋Š”๊ฒŒ ์•„๋‹ˆ๋ผ 80% ์ •๋„๋งŒ ๋กœ ์ฒ˜๋ฆฌํ•˜๊ณ  10%๋Š” ๋žœ๋คํ•œ ๋‹จ์–ด, ๋‚˜๋จธ์ง€ 10%๋Š” ์ •์ƒ์ ์ธ ๋‹จ์–ด๋ฅผ ๊ทธ๋Œ€๋กœ ๋‘”๋‹ค.

Next Sentence Prediction (NSP)#

Bert๋Š” 2๋ฌธ์žฅ์„ ์ž…๋ ฅ์œผ๋กœ ๋„ฃ์„ ๋•Œ, [cls] ๋ฌธ์žฅ [sep] ๋ฌธ์žฅ ํ˜•์‹์„ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. ์›์ €์ž no NSP๋ผ๋ฉด Accuracy๊ฐ€ 0.5์ •๋„ ๋–จ์–ด์ง„๋‹ค๊ณ  ํ•˜๋Š”๋ฐ, ๋ฐ˜๋ก ์„ ํŽผ์น˜๋Š” ๋…ผ๋ฌธ๋„ ๋งŽ์œผ๋ฏ€๋กœ ๋„˜์–ด๊ฐ„๋‹ค.

Effect of Model Size#

Bert Base์˜ ๋ชจ๋ธ์€ L=12, H=768, A=12 110M parameter
Large ๋ชจ๋ธ์€ L=24, H=1024, A,16 340M parameter

For a bi-LSTM such as context2vec, increasing the hidden vector dimension from 200 to 600 was meaningful, but pushing it further to 1000 did not help.
The Transformer, by contrast, makes training such large-capacity models feasible.