Source code for crslab.model.crs.kgsf.modules
import numpy as np
import torch
from torch import nn

from crslab.model.utils.modules.transformer import MultiHeadAttention, TransformerFFN, \
    _create_selfattn_mask, _normalize, create_position_codes


class GateLayer(nn.Module):
    """Fuse two same-sized representations with a learned sigmoid gate."""

    def __init__(self, input_dim):
        super(GateLayer, self).__init__()
        self._norm_layer1 = nn.Linear(input_dim * 2, input_dim)
        self._norm_layer2 = nn.Linear(input_dim, 1)

    def forward(self, input1, input2):
        norm_input = self._norm_layer1(torch.cat([input1, input2], dim=-1))
        gate = torch.sigmoid(self._norm_layer2(norm_input))  # (bs, 1)
        gated_emb = gate * input1 + (1 - gate) * input2  # (bs, dim)
        return gated_emb
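
# A minimal usage sketch (illustrative only, not part of the original module):
# fusing two 128-dim batch representations with GateLayer. The dimensions and
# inputs below are assumptions for demonstration.
#
#   gate = GateLayer(input_dim=128)
#   emb1 = torch.randn(4, 128)  # e.g. one view of a user representation
#   emb2 = torch.randn(4, 128)  # e.g. another view of the same size
#   fused = gate(emb1, emb2)    # (4, 128): per-example convex mix of the two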


class TransformerDecoderLayerKG(nn.Module):
    """Transformer decoder layer with two extra cross-attention blocks over
    the KG (word) and DB (entity) encoder outputs."""

    def __init__(
            self,
            n_heads,
            embedding_size,
            ffn_size,
            attention_dropout=0.0,
            relu_dropout=0.0,
            dropout=0.0,
    ):
        super().__init__()
        self.dim = embedding_size
        self.ffn_dim = ffn_size
        self.dropout = nn.Dropout(p=dropout)

        # masked self-attention over the decoded prefix
        self.self_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm1 = nn.LayerNorm(embedding_size)

        # cross-attention over the dialogue context encoder output
        self.encoder_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm2 = nn.LayerNorm(embedding_size)

        # cross-attention over the DB (entity) encoder output
        self.encoder_db_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm2_db = nn.LayerNorm(embedding_size)

        # cross-attention over the KG (word) encoder output
        self.encoder_kg_attention = MultiHeadAttention(
            n_heads, embedding_size, dropout=attention_dropout
        )
        self.norm2_kg = nn.LayerNorm(embedding_size)

        self.ffn = TransformerFFN(embedding_size, ffn_size, relu_dropout=relu_dropout)
        self.norm3 = nn.LayerNorm(embedding_size)

    def forward(self, x, encoder_output, encoder_mask, kg_encoder_output, kg_encoder_mask,
                db_encoder_output, db_encoder_mask):
        decoder_mask = _create_selfattn_mask(x)

        # first, masked self-attention: don't peek into the future!
        residual = x
        x = self.self_attention(query=x, mask=decoder_mask)
        x = self.dropout(x)  # --dropout
        x = x + residual
        x = _normalize(x, self.norm1)

        # cross-attention over the DB (entity) encoder output
        residual = x
        x = self.encoder_db_attention(
            query=x,
            key=db_encoder_output,
            value=db_encoder_output,
            mask=db_encoder_mask
        )
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm2_db)

        # cross-attention over the KG (word) encoder output
        residual = x
        x = self.encoder_kg_attention(
            query=x,
            key=kg_encoder_output,
            value=kg_encoder_output,
            mask=kg_encoder_mask
        )
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm2_kg)

        # cross-attention over the dialogue context encoder output
        residual = x
        x = self.encoder_attention(
            query=x,
            key=encoder_output,
            value=encoder_output,
            mask=encoder_mask
        )
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm2)

        # finally the FFN
        residual = x
        x = self.ffn(x)
        x = self.dropout(x)  # --dropout
        x = residual + x
        x = _normalize(x, self.norm3)
        return x
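
# A hedged sketch of running one decoder layer on dummy tensors (all shapes
# and the mask convention below are assumptions; masks mark valid positions):
#
#   layer = TransformerDecoderLayerKG(n_heads=2, embedding_size=64, ffn_size=128)
#   x = torch.randn(2, 5, 64)                    # decoded prefix states
#   ctx = torch.randn(2, 7, 64)                  # dialogue encoder output
#   ctx_mask = torch.ones(2, 7, dtype=torch.bool)
#   kg = torch.randn(2, 3, 64); kg_mask = torch.ones(2, 3, dtype=torch.bool)
#   db = torch.randn(2, 4, 64); db_mask = torch.ones(2, 4, dtype=torch.bool)
#   out = layer(x, ctx, ctx_mask, kg, kg_mask, db, db_mask)  # (2, 5, 64)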


class TransformerDecoderKG(nn.Module):
    """
    Transformer decoder with additional KG and DB cross-attention in each layer.

    :param int n_heads: the number of multihead attention heads.
    :param int n_layers: number of transformer layers.
    :param int embedding_size: the embedding size. Must be a multiple of n_heads.
    :param int ffn_size: the size of the hidden layer in the FFN.
    :param embedding: an embedding matrix for the bottom layer of the transformer.
        If none, one is created for this encoder.
    :param float dropout: Dropout used around embeddings and before layer
        normalizations. This is used in Vaswani 2017 and works well on
        large datasets.
    :param float attention_dropout: Dropout performed after the multihead attention
        softmax. This is not used in Vaswani 2017.
    :param float relu_dropout: Dropout used after the ReLU in the FFN. Not used
        in Vaswani 2017, but used in Tensor2Tensor.
    :param int padding_idx: Reserved padding index in the embeddings matrix.
    :param bool learn_positional_embeddings: If off, sinusoidal embeddings are
        used. If on, position embeddings are learned from scratch.
    :param bool embeddings_scale: Scale embeddings relative to their dimensionality.
        Found useful in fairseq.
    :param int n_positions: Size of the position embeddings matrix.
    """

    def __init__(
            self,
            n_heads,
            n_layers,
            embedding_size,
            ffn_size,
            vocabulary_size,
            embedding,
            dropout=0.0,
            attention_dropout=0.0,
            relu_dropout=0.0,
            embeddings_scale=True,
            learn_positional_embeddings=False,
            padding_idx=None,
            n_positions=1024,
    ):
        super().__init__()
        self.embedding_size = embedding_size
        self.ffn_size = ffn_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.dim = embedding_size
        self.embeddings_scale = embeddings_scale
        self.dropout = nn.Dropout(dropout)  # --dropout
        self.out_dim = embedding_size
        assert embedding_size % n_heads == 0, \
            'Transformer embedding size must be a multiple of n_heads'

        self.embeddings = embedding

        # create the positional embeddings
        self.position_embeddings = nn.Embedding(n_positions, embedding_size)
        if not learn_positional_embeddings:
            create_position_codes(
                n_positions, embedding_size, out=self.position_embeddings.weight
            )
        else:
            nn.init.normal_(self.position_embeddings.weight, 0, embedding_size ** -0.5)

        # build the model
        self.layers = nn.ModuleList()
        for _ in range(self.n_layers):
            self.layers.append(TransformerDecoderLayerKG(
                n_heads, embedding_size, ffn_size,
                attention_dropout=attention_dropout,
                relu_dropout=relu_dropout,
                dropout=dropout,
            ))

    def forward(self, input, encoder_state, kg_encoder_output, kg_encoder_mask,
                db_encoder_output, db_encoder_mask, incr_state=None):
        encoder_output, encoder_mask = encoder_state

        seq_len = input.size(1)
        positions = input.new(seq_len).long()  # (seq_len)
        positions = torch.arange(seq_len, out=positions).unsqueeze(0)  # (1, seq_len)
        tensor = self.embeddings(input)  # (bs, seq_len, embed_dim)
        if self.embeddings_scale:
            tensor = tensor * np.sqrt(self.dim)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        tensor = self.dropout(tensor)  # --dropout

        for layer in self.layers:
            tensor = layer(tensor, encoder_output, encoder_mask, kg_encoder_output,
                           kg_encoder_mask, db_encoder_output, db_encoder_mask)

        return tensor, None
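
# A hedged end-to-end sketch (vocabulary size, shapes, and padding index are
# illustrative assumptions; encoder tensors as in the layer sketch above; the
# weight-tied projection is one common choice, not necessarily KGSF's own):
#
#   emb = nn.Embedding(1000, 64, padding_idx=0)
#   decoder = TransformerDecoderKG(
#       n_heads=2, n_layers=2, embedding_size=64, ffn_size=128,
#       vocabulary_size=1000, embedding=emb,
#   )
#   tokens = torch.randint(1, 1000, (2, 5))    # decoder input token ids
#   out, _ = decoder(tokens, (ctx, ctx_mask), kg, kg_mask, db, db_mask)
#   logits = out @ emb.weight.t()              # (2, 5, 1000) vocabulary scores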