# -*- coding: utf-8 -*-
# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttentionBatch(nn.Module):
    """Additive self-attention that pools a set of vectors (N, dim) into a single vector (dim,)."""

    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionBatch, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        # Projection matrices of the additive attention score: tanh(h @ a) @ b
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h):
        # h: (N, dim)
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)  # (N,)
        attention = F.softmax(e, dim=0)  # (N,)
        return torch.matmul(attention, h)  # (dim,)
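
# A minimal usage sketch (illustrative only; the sizes below are assumptions,
# not values used elsewhere in the library): pooling N node embeddings of
# size dim into a single graph-level vector of size dim.
#
#   pool = SelfAttentionBatch(dim=128, da=64)
#   node_embeddings = torch.randn(10, 128)    # (N, dim)
#   graph_embedding = pool(node_embeddings)   # (dim,)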
class SelfAttentionSeq(nn.Module):
    """Additive self-attention that pools a padded sequence (batch, seq_len, dim) into (batch, dim)."""

    def __init__(self, dim, da, alpha=0.2, dropout=0.5):
        super(SelfAttentionSeq, self).__init__()
        self.dim = dim
        self.da = da
        self.alpha = alpha
        self.dropout = dropout
        # Projection matrices of the additive attention score: tanh(h @ a) @ b
        self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)), requires_grad=True)
        self.b = nn.Parameter(torch.zeros(size=(self.da, 1)), requires_grad=True)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        nn.init.xavier_uniform_(self.b.data, gain=1.414)

    def forward(self, h, mask=None, return_logits=False):
        """
        Pool a sequence of hidden states with additive self-attention.

        For padding tokens, the corresponding mask entry is True (1); such
        positions receive a large negative score so that their attention
        weight is numerically zero after the softmax. Rows whose mask is
        all True (i.e. mask == [1, 1, 1, ...]) are left unmasked.
        """
        # h: (batch, seq_len, dim), mask: (batch, seq_len)
        e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b)  # (batch, seq_len, 1)
        if mask is not None:
            # Large negative score for every padding position.
            full_mask = -1e30 * mask.float()
            # batch_mask is 0 for rows that are entirely padding, so those rows
            # keep their raw scores instead of a softmax over all -1e30 values.
            batch_mask = torch.sum((mask == False), -1).bool().float().unsqueeze(-1)  # (batch, 1)
            mask = full_mask * batch_mask  # (batch, seq_len)
            e += mask.unsqueeze(-1)
        attention = F.softmax(e, dim=1)  # (batch, seq_len, 1)
        if return_logits:
            # (batch, dim), (batch, seq_len)
            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1), attention.squeeze(-1)
        else:
            # (batch, dim)
            return torch.matmul(torch.transpose(attention, 1, 2), h).squeeze(1)
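

if __name__ == '__main__':
    # Minimal smoke test (illustrative only; the tensor sizes and mask below
    # are assumptions, not values used elsewhere in the library).
    batch, seq_len, dim, da = 2, 4, 8, 6
    pooler = SelfAttentionSeq(dim, da)
    hidden = torch.randn(batch, seq_len, dim)
    # True marks padding positions; the second row is entirely padding.
    pad_mask = torch.tensor([[False, False, True, True],
                             [True, True, True, True]])
    pooled, weights = pooler(hidden, mask=pad_mask, return_logits=True)
    print(pooled.shape)   # torch.Size([2, 8])
    print(weights.shape)  # torch.Size([2, 4])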