# Source code for crslab.model.crs.redial.modules

# @Time   : 2020/12/4
# @Author : Chenzhan Shang
# @Email  : czshang@outlook.com

# UPDATE:
# @Time   : 2020/12/16
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from crslab.model.utils.functions import sort_for_packed_sequence


class HRNN(nn.Module):
    """Hierarchical RNN encoder: a bidirectional GRU encodes each utterance, and a
    second GRU encodes the sequence of utterance representations into a single
    dialog state."""

    def __init__(self, utterance_encoder_hidden_size, dialog_encoder_hidden_size, dialog_encoder_num_layers,
                 pad_token_idx, embedding=None, use_dropout=False, dropout=0.3):
        super(HRNN, self).__init__()
        self.pad_token_idx = pad_token_idx
        # embedding
        self.embedding_size = embedding.weight.shape[1]
        self.embedding = embedding
        # utterance encoder
        self.utterance_encoder_hidden_size = utterance_encoder_hidden_size
        self.utterance_encoder = nn.GRU(
            input_size=self.embedding_size,
            hidden_size=self.utterance_encoder_hidden_size,
            batch_first=True,
            bidirectional=True
        )
        # conversation encoder
        self.dialog_encoder = nn.GRU(
            input_size=(2 * self.utterance_encoder_hidden_size),
            hidden_size=dialog_encoder_hidden_size,
            num_layers=dialog_encoder_num_layers,
            batch_first=True
        )
        # dropout
        self.use_dropout = use_dropout
        if self.use_dropout:
            self.dropout = nn.Dropout(p=dropout)
    def get_utterance_encoding(self, context, utterance_lengths):
        """
        :param context: (batch_size, max_conversation_length, max_utterance_length)
        :param utterance_lengths: (batch_size, max_conversation_length)
        :return utterance_encoding: (batch_size, max_conversation_length, 2 * utterance_encoder_hidden_size)
        """
        batch_size, max_conv_length = context.shape[:2]
        utterance_lengths = utterance_lengths.reshape(-1)  # (bs * conv_len)
        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(utterance_lengths)

        # reshape and reorder
        sorted_utterances = context.view(batch_size * max_conv_length, -1).index_select(0, sorted_idx)

        # consider valid sequences only (length > 0)
        num_positive_lengths = torch.sum(utterance_lengths > 0)
        sorted_utterances = sorted_utterances[:num_positive_lengths]
        sorted_lengths = sorted_lengths[:num_positive_lengths]

        embedded = self.embedding(sorted_utterances)
        if self.use_dropout:
            embedded = self.dropout(embedded)

        packed_utterances = pack_padded_sequence(embedded, sorted_lengths, batch_first=True)
        _, utterance_encoding = self.utterance_encoder(packed_utterances)

        # concat the hidden states of the last layer (two directions of the GRU)
        utterance_encoding = torch.cat((utterance_encoding[-1], utterance_encoding[-2]), 1)
        if self.use_dropout:
            utterance_encoding = self.dropout(utterance_encoding)

        # complete the missing sequences (of length 0)
        if num_positive_lengths < batch_size * max_conv_length:
            pad_tensor = utterance_encoding.new_full(
                (batch_size * max_conv_length - num_positive_lengths, 2 * self.utterance_encoder_hidden_size),
                self.pad_token_idx)
            utterance_encoding = torch.cat((utterance_encoding, pad_tensor), 0)

        # retrieve original utterance order and reshape to separate contexts
        utterance_encoding = utterance_encoding.index_select(0, rev_idx)
        utterance_encoding = utterance_encoding.view(batch_size, max_conv_length,
                                                     2 * self.utterance_encoder_hidden_size)
        return utterance_encoding
    def forward(self, context, utterance_lengths, dialog_lengths):
        """
        :param context: (batch_size, max_context_length, max_utterance_length)
        :param utterance_lengths: (batch_size, max_context_length)
        :param dialog_lengths: (batch_size)
        :return context_state: (batch_size, context_encoder_hidden_size)
        """
        utterance_encoding = self.get_utterance_encoding(context, utterance_lengths)  # (bs, conv_len, 2 * utt_dim)
        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(dialog_lengths)

        # reorder in decreasing sequence length
        sorted_representations = utterance_encoding.index_select(0, sorted_idx)
        packed_sequences = pack_padded_sequence(sorted_representations, sorted_lengths, batch_first=True)

        _, context_state = self.dialog_encoder(packed_sequences)
        context_state = context_state.index_select(1, rev_idx)
        if self.use_dropout:
            context_state = self.dropout(context_state)
        return context_state[-1]
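
# The helper below is a minimal usage sketch of HRNN and is not part of the
# original CRSLab module; its name, hyper-parameter values, and dummy tensors
# are illustrative assumptions only. It shows the expected input shapes
# (token ids plus per-utterance and per-dialog lengths) and the resulting
# dialog-state shape.
def _hrnn_usage_example():
    vocab_size, embed_dim, pad_idx = 100, 32, 0
    embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
    hrnn = HRNN(
        utterance_encoder_hidden_size=16,
        dialog_encoder_hidden_size=16,
        dialog_encoder_num_layers=1,
        pad_token_idx=pad_idx,
        embedding=embedding,
        use_dropout=True,
        dropout=0.3,
    )
    batch_size, max_conv_len, max_utt_len = 2, 3, 5
    # dummy token ids and (all-full) lengths; real batches would carry padding
    context = torch.randint(1, vocab_size, (batch_size, max_conv_len, max_utt_len))
    utterance_lengths = torch.full((batch_size, max_conv_len), max_utt_len, dtype=torch.long)
    dialog_lengths = torch.full((batch_size,), max_conv_len, dtype=torch.long)
    context_state = hrnn(context, utterance_lengths, dialog_lengths)
    # context_state: (batch_size, dialog_encoder_hidden_size)
    return context_state.shape
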
class SwitchingDecoder(nn.Module):
    """GRU decoder initialized with the dialog state; a learned switch gates, per
    time step, between generating a vocabulary token and deferring to the
    recommender (the extra output slot)."""

    def __init__(self, hidden_size, context_size, num_layers, vocab_size, embedding, pad_token_idx):
        super(SwitchingDecoder, self).__init__()
        self.pad_token_idx = pad_token_idx
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.num_layers = num_layers
        if context_size != hidden_size:
            raise ValueError("The context size {} must match the hidden size {} in DecoderGRU".format(
                context_size, hidden_size))
        self.embedding = embedding
        embedding_dim = embedding.weight.shape[1]
        self.decoder = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.out = nn.Linear(hidden_size, vocab_size)
        self.switch = nn.Linear(hidden_size + context_size, 1)
    def forward(self, request, request_lengths, context_state):
        """
        :param request: (batch_size, max_utterance_length)
        :param request_lengths: (batch_size)
        :param context_state: (batch_size, context_encoder_hidden_size)
        :return log_probabilities: (batch_size, max_utterance_length, vocab_size + 1)
        """
        batch_size, max_utterance_length = request.shape

        # sort for pack
        sorted_lengths, sorted_idx, rev_idx = sort_for_packed_sequence(request_lengths)
        sorted_request = request.index_select(0, sorted_idx)
        embedded_request = self.embedding(sorted_request)  # (batch_size, max_utterance_length, embed_dim)
        packed_request = pack_padded_sequence(embedded_request, sorted_lengths, batch_first=True)

        sorted_context_state = context_state.index_select(0, sorted_idx)
        h_0 = sorted_context_state.unsqueeze(0).expand(
            self.num_layers, batch_size, self.hidden_size
        ).contiguous()  # requires context_size == hidden_size

        sorted_vocab_state, _ = self.decoder(packed_request, h_0)
        sorted_vocab_state, _ = pad_packed_sequence(sorted_vocab_state, batch_first=True)
        # (batch_size, max_request_length, decoder_hidden_size)

        _, max_request_length, decoder_hidden_size = sorted_vocab_state.shape
        pad_tensor = sorted_vocab_state.new_full(
            (batch_size, max_utterance_length - max_request_length, decoder_hidden_size),
            self.pad_token_idx)
        sorted_vocab_state = torch.cat((sorted_vocab_state, pad_tensor), dim=1)
        # (batch_size, max_utterance_length, decoder_hidden_size)
        sorted_language_output = self.out(sorted_vocab_state)  # (batch_size, max_utterance_length, vocab_size)

        # expand context to each time step
        expanded_sorted_context_state = sorted_context_state.unsqueeze(1).expand(
            batch_size, max_utterance_length, self.context_size
        ).contiguous()  # (batch_size, max_utterance_length, context_size)

        # compute switch
        switch_input = torch.cat((expanded_sorted_context_state, sorted_vocab_state), dim=2)
        # (batch_size, max_utterance_length, context_size + decoder_hidden_size)
        switch = self.switch(switch_input)  # (batch_size, max_utterance_length, 1)

        sorted_output = torch.cat((
            F.logsigmoid(switch) + F.log_softmax(sorted_language_output, dim=2),
            F.logsigmoid(-switch)  # for item
        ), dim=2)

        output = sorted_output.index_select(0, rev_idx)  # (batch_size, max_utterance_length, vocab_size + 1)
        return output
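
# The helper below is a minimal usage sketch of SwitchingDecoder and is not part
# of the original CRSLab module; its name, sizes, and dummy tensors are
# illustrative assumptions only. It shows that context_size must equal
# hidden_size and that the output has vocab_size + 1 slots, where the last slot
# is the log-probability of switching to an item recommendation.
def _switching_decoder_usage_example():
    vocab_size, embed_dim, hidden_size, pad_idx = 100, 32, 16, 0
    embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
    decoder = SwitchingDecoder(
        hidden_size=hidden_size,
        context_size=hidden_size,  # must equal hidden_size
        num_layers=1,
        vocab_size=vocab_size,
        embedding=embedding,
        pad_token_idx=pad_idx,
    )
    batch_size, max_utt_len = 2, 5
    # dummy request tokens, their lengths, and a dummy dialog state (e.g. from HRNN)
    request = torch.randint(1, vocab_size, (batch_size, max_utt_len))
    request_lengths = torch.full((batch_size,), max_utt_len, dtype=torch.long)
    context_state = torch.randn(batch_size, hidden_size)
    log_probs = decoder(request, request_lengths, context_state)
    # log_probs: (batch_size, max_utterance_length, vocab_size + 1)
    return log_probs.shape
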