rueki

Self-Attention, Transformer 본문

MLOPS/full stack deep learning review

Self-Attention, Transformer

륵기 2021. 9. 29. 21:30
728x90
반응형

Self - Attention

 

input : x1,...,xt

output : y1,...,yt

 

input과 output 모두 k 차원의 vector로 구성되어 있으며, y를 구하기 위해서 적용할 self-attention 개념은 모든 input vector에 대해 weighted average 연산을 하는 것이다.

 

 

yi=jwijxj

여기서 인덱스 j는 전체 sequence에 대하여 어우르는 인덱스 값이며, i는 현재 step에 대한 인덱스이다.

wij는 파라메터의 개념은 아니고, 일반적인 뉴럴 넷이다. 그러나 이것은 xixj로 부터 도출된 결과로도 볼 수가 있다.

 

w를 구하기 위한 dot product 연산은 아래와 같다.

xi 는 현재 step에 대한 input vector, yi는 같은 step의 output vector이다. 현재 step의 input vector와 다른 step의 input vector를 dot product해서 나온 값이 현재 input에 대한 다른 step vector와의 가중치로 사용이 되게 된다.

결국 연산에 사용되는 vector들이 input 값들 내에서 전부 구성이 되어있기 때문에 self-attention이라는 이름이 붙지않은 것일까 하는 생각이다.

 

결과적으로는, 서로 다른 weighted sum 결과를 가지게 되는 것이다. 이 값은 -inf ~ inf사이의 값을 가지기 때문에, 여기서 softmax를 취해서 0에서 1 사이의 확률값으로 변환시켜준다. Self-Attention의 전체적인 process는 아래의 그림과 같다.

 

여기서 사용된 벡터간의 연산은 유일하게 서로의 feature 혹은 information을 공유하는 operation이며, Transformer에서

그 외의 연산들은 input vector에만 각각 적용이 된다고 한다(다른 벡터들과의 상호작용 X).

 

import torch
import torch.nn.functional as F

x = torch.randn(16, 64, 256) #batch, seq len, dimension of sequence
raw_weights = torch.bmm(x, x.transpose(1,2)) # 16, 64, 64
weights = F.softmax(raw_weights, dim = 2)
y = torch.bmm(weights, x) #(64 x 64) * (64 x 356)
y.shape # (16, 64, 256)

 

 

- Query, Key, Value

연산 트릭으로 input vector xi를 세 개의 attention operation으로 나눌 수 있다.

* output yi를 위한 가중치 설정을 위해 모든 다른 vector와 비교

* output yj를 위한 가중치 설정을 위해 모든 다른 vector와 비교

* 설정된 가중치로 각 출력 vector를 계산하는 weigthed sum의 일부로 사용되는 부분

 

현재의 step index i를 2, 그 외의 인덱스 중 하나 j를 3으로 설정했다고 가정했을 때,

 

1. Query는 Query weight matrix와 input xi를 곱한 형태

2. Key는 Key weight matrix와 input xi를 곱한 형태

3. Query와 Key를 dot product해서 하나의 w값을 구해서 softmax를 취해준다.

4. output y_{i}는 위에서 구한 w값과 value vj를 weighte sum 한 것이다.

 

연산을 요약하면, 현재의 input과 다른 인덱스의 input과 첫번째로 연산해서 w를 구해서 이 값과 다른 인덱스의 input과 weigthed sum을 통하여 현재의 output을 도출해낸다.

 

주의해야할 사항은 input length가 길어지면, 스케일이 매우 커지고 결과적으로 graident가 거의 없어지고 학습이 잘 되지 않기 때문에, scailing이 필요한데 scailing은 아래와 같다.

 

Multi - head  Attention

Self-attention에서는 모든 정보들에 대해서 sum 연산 하면서 진행이 되었지만, 여기서 더 발전해서 몇 개의 메커니즘을 결합해서 사용한다면 더 효율적이다. 그래서 위에서 살펴보았던 Query, key, value weight matrix를 R개의 head로 구성할 것인데 이는 Wrq,Wrk,Wrv로 구성한다. 명칭은 attention-heads로 일컫는다.

 

그래서 inputxi를 넣게 되면, 각 head 마다 output yri이 나오게 되는데, 이 값들을 concat해서 linear transform을 함에따라 다시 처음의 k 차원으로 줄어들게 된다.

이러한 구조의 특징이 병렬 처리가 가능해지는 것인데 input의 차원이 256인 경우에 head를 8개로 나누게 되면, 하나의 head당 32 dimensions만 담당하게 되면 된다.

 

class SelfAttention(nn.Module):
	def __init__(self, k, heads = 8):
    	super().__init__()
        self.k = k
        self.heads = heads
        
        self.tokeys = nn.Linear(k, k * heads)
        self.toqueries = nn.Linear(k, k*heads)
        self.tovalues =nn.Linear(k , k*heads)
        self.unifyheads = nn.Linear(heads * k, k)
        
	def forward(self, x):
    	b, t, k = x.size()
        h = self.heads
        
        # (b,t,h,k) -> (b*h, t, k)
        queries = self.toqueries(x).view(b, t, h, k)
        keys = self.tokeys(x).view(b, t, h, k)
        values = self.tovalues(x).view(b, t, h, k)
        
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        values = values.transpose(1, 2).contiguous().view(b * h, t, k)
        
        queries = queries / (k ** (1/4))
    	keys    = keys / (k ** (1/4))
        
        dot = torch.bmm(queries, keys.transpose(1,2)) #(b*h, t, t)
        dot = F.softmax(dot, dim = 2)
        out = torch.bmm(dot, values).view(b, h, t, k)
        out = out.transpose(1, 2).contiguous().view(b, t, h * k)
        return self.unifyheads(out)

 

● Transformer

Self-attention을 이용해서 만든 것이 하나의 transformer block구조이다. self-attention, layer norm, mlp로 구성되어 있으며, residual connection까지 활용하였다. 여기서 layer norm을 사용한 이유는 임베딩 차원에 해당되는 부분씩만 norm을 적용하기 위하여 사용했다고 한다.

class TransformerBlock(nn.Module):
	def __init__(self, k, heads):
    	super().__init__()
        self.attention = SelfAttention(k, heads = heads)
        self.norm1(k)
        self.norm2(k)
        
        self.feedforward = nn.Sequential(
        	nn.Linear(k, 4*k),
            nn.ReLU(),
            nn.Linear(4*k, k)
        )
        
	def forward(self, x):
    	att = self.attention(x)
        x = self.norm1(att + x)
        ff = self.feedforward(x)
        return self.norm2(ff + x)

 

- position embeddings, positional encoding

input data에서 각각의 값에 대해 위치에 대한 정보를 임베딩 시키는데 이것을 하는 이유는 트랜스포머가 결국에는 input을 rnn처럼 sequential하게 받는 게 아니기 때문에, 위치 정보를 제공해줘야 하기 때문이다.

그래서 각 입력값의 임베딩 벡터에 위치정보들을 더하여 모델의 input으로 넣게 된다.

 

위의 그림에서 보면 워드 임베딩 벡터 값과 포지션 임베딩 한 벡터 값을 합쳐서 모델 input으로 들어가게 되는 구조이다.

position 정보는 어떻게 나타내는지 궁금해질수가 있는데, 각 input의 step 마다 하나의 유일한 encoding값을 가져야 하며, 서로 다른 길이의 데이터에 있어서 두 time-step 간 거리는 일정해야 한다.

주기가 10000의 2i/d * 2ㅠ 의 함수를 사용하여, 0 ~ d차원까지 값을 짝수는 sin, 홀수는 cos 값을 따르게 하며, 각각 input에 대한 위치 정보값을 갖게끔 한다고 한다.

 

class Transformer(nn.Module):
	def __init__(self, k, heads, depth, seq_len, num_tokens, num_classes):
    	super().__init__()
        self.num_tokens = num_tokens
        self.token_emb = nn.Embedding(num_tokens, k)
        self.pos_emb = nn.Embedding(seq_len, k)
        
        tblocks = []
        for i in range(depth):
        	tblocks.append(TransformerBlock(k=k, heads = heads))
        self.tblocks = nn.Sequential(*tblocks)
        
        self.toprobs = nn.Linear(k, num_classes)
        
    def forward(self, x):
    	tokens = self.token_emb(x)
        b, t, k = tokens.size()
        positions = torch.arange(t)
        positions = self.pos_emb(positions)[None, :, :].expand(b, t, k)
        x = tokens + positions
        x = self.tblocks(x)
        x = self.toprobs(x.mean(dim = 1))
        return F.log_softmax(x, dim = 1)

- masking

Transformer를 거쳐서 output을 내보낼 때, sequential 하게 내보내느데, 여기서 masking 연산을 적용한다.

적용하는 이유는 position i보다 이후에 있는 position에 attention연산을 적용치 않기 위해서다.

그래서 masking 값은 -inf를 적용하여, 이 값들이 attention 마지막 연산 때 value 값과 곱해지게 되는데 -inf와 곱해진 값은 0에 가까운 값을 갖게 되어 영향을 미치지 않게 된다.

 

728x90
반응형

'MLOPS > full stack deep learning review' 카테고리의 다른 글

Week 8. Data Management  (0) 2021.10.28
Week 7. Troubleshooting Deep Neural Networks  (0) 2021.10.21
week6. Infrastructure & Tooling  (0) 2021.10.11
Week2. CNN  (0) 2021.09.25
Week 3. RNN(Recurrent neural network)  (0) 2021.09.22
Comments