Source code for crslab.model.utils.modules.attention

# -*- coding: utf-8 -*-
# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttentionBatch(nn.Module):
    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionBatch, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h):
        # h: (N, dim)
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)
        attention = F.softmax(e, dim=0)  # (N)
        return torch.matmul(attention, h)  # (dim)
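
A minimal usage sketch for SelfAttentionBatch, which pools a set of N embeddings into a single dim-sized vector via learned attention weights. The shapes and variable names below are illustrative assumptions, not values taken from CRSLab.

    # Hypothetical example: pool 10 entity embeddings of size 128 into one vector.
    attn_pool = SelfAttentionBatch(dim=128, da=64)
    entity_emb = torch.randn(10, 128)   # (N, dim)
    pooled = attn_pool(entity_emb)      # (dim,), attention-weighted sum over the N rows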
class SelfAttentionSeq(nn.Module):
    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionSeq, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h, mask=None, return_logits=False):
        """For padding tokens, the corresponding mask entry is True.

        If a row of the mask is all True (the sequence is entirely padding),
        no masking is applied to that row so the softmax stays well-defined.
        """
        # h: (batch, seq_len, dim), mask: (batch, seq_len)
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b)  # (batch, seq_len, 1)
        if mask is not None:
            full_mask = -1e30 * mask.float()
            # batch_mask is 0 for rows that are entirely padding, so no mask is applied to them
            batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1)
            mask = full_mask * batch_mask
            e += mask.unsqueeze(-1)
        attention = F.softmax(e, dim=1)  # (batch, seq_len, 1)
        # (batch, dim)
        if return_logits:
            # note: the second return value is the post-softmax attention, shape (batch, seq_len)
            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1)
        else:
            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1)
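
A minimal usage sketch for SelfAttentionSeq, which attends over the sequence dimension of a padded batch; the mask is True at padding positions. Shapes and values below are illustrative assumptions, not taken from CRSLab.

    # Hypothetical example: a batch of two length-5 sequences with trailing padding.
    seq_attn = SelfAttentionSeq(dim=128, da=64)
    h = torch.randn(2, 5, 128)  # (batch, seq_len, dim)
    mask = torch.tensor([[False, False, False, True, True],
                         [False, False, True, True, True]])  # True marks padding
    context = seq_attn(h, mask=mask)  # (batch, dim)
    context, weights = seq_attn(h, mask=mask, return_logits=True)  # weights: (batch, seq_len)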