본문 바로가기
AI/파이토치(Pytorch)

[NLP][파이토치] seq2seq - Attention(어텐션)

by Hyen4110 2021. 8. 30.

김기현 강사님의 '자연어처리 딥러닝 캠프(파이토치편)' 책을 공부하면서 정리한 글입니다.

http://www.kyobobook.co.kr/product/detailViewKor.laf?ejkGb=KOR&mallGb=KOR&barcode=9791162241974&orderClick=LAG&Kc= 

 

김기현의 자연어 처리 딥러닝 캠프: 파이토치 편 - 교보문고

딥러닝 기반의 자연어 처리 기초부터 심화까지 | 저자의 현장 경험과 인사이트를 녹여낸 본격적인 활용 가이드 이 책은 저자가 현장에서 실제로 시스템을 구축하며 얻은 경험과 그로부터 얻은

www.kyobobook.co.kr

 

* 수식으로 확인하는 Attention

 > 지난 글 2021.05.12 - [자연어처리(NLP)] - [NLP] Attention Mechanism(어텐션)

 

class Attention(nn.Module):

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)
torch.nn.Linear(in_features,out_features,bias=True,device=None,dtype=None)
: 선형회귀모델을 구현하는 함수로 입력 차원과 출력차원을 인수로 입력합니다. 
: 만들어진 선형회귀 모델의 가중치와 편향은, weight()과 bias()로 구할수있으며, parameters()로 동시에 구할수도 있습니다


- Attention에서는 Query를 잘 변환(linear-transform)하는 방법을 배우는 과정이라는것 기억! 

torch.nn.Softmax(dim=None)
 - input 텐서에 있는 숫자를 소프트맥수 함수를 통과시켜서, 0과 1사이에 있고, 그 합이 1인 output 텐서를 반환합니다.


  (※' dim'은 소프트맥스를 구할 차원의 방향을 의미합니다)

 

 

    def forward(self, h_src, h_t_tgt, mask=None):
        # |h_src| = (batch_size, length, hidden_size)
        # |h_t_tgt| = (batch_size, 1, hidden_size)
        # |mask| = (batch_size, length) -> source senetence의 pad의 위치

        query = self.linear(h_t_tgt)
        # |query| = (batch_size, 1, hidden_size)

        weight = torch.bmm(query, h_src.transpose(1, 2))
        # |weight| = (batch_size, 1, length)
        if mask is not None:
            # Set each weight as -inf, if the mask value equals to 1.
            # Since the softmax operation makes -inf to 0,
            # masked weights would be set to 0 after softmax operation.
            # Thus, if the sample is shorter than other samples in mini-batch,
            # the weight for empty time-step would be set to 0.
            weight.masked_fill_(mask.unsqueeze(1), -float('inf'))
        weight = self.softmax(weight)

        context_vector = torch.bmm(weight, h_src)
        # |context_vector| = (batch_size, 1, hidden_size)

        return context_vector

 

 

댓글