# Source code for crslab.model.crs.kgsf.kgsf

# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2020/12/29, 2021/1/4
# @Author : Kun Zhou, Xiaolei Wang, Yuanhang Zhou
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com, sdzyh002@gmail.com

r"""
KGSF
====
References:
    Zhou, Kun, et al. `"Improving Conversational Recommender Systems via Knowledge Graph based Semantic Fusion."`_ in KDD 2020.

.. _`"Improving Conversational Recommender Systems via Knowledge Graph based Semantic Fusion."`:
   https://dl.acm.org/doi/abs/10.1145/3394486.3403143

"""

import os

import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from torch_geometric.nn import GCNConv, RGCNConv

from crslab.config import MODEL_PATH
from crslab.model.base import BaseModel
from crslab.model.utils.functions import edge_to_pyg_format
from crslab.model.utils.modules.attention import SelfAttentionSeq
from crslab.model.utils.modules.transformer import TransformerEncoder
from .modules import GateLayer, TransformerDecoderKG
from .resources import resources


class KGSFModel(BaseModel):
    """KGSF conversational recommender: KG-based semantic fusion model.

    Attributes:
        vocab_size: An integer indicating the vocabulary size.
        pad_token_idx: An integer indicating the id of padding token.
        start_token_idx: An integer indicating the id of start token.
        end_token_idx: An integer indicating the id of end token.
        token_emb_dim: An integer indicating the dimension of token embedding layer.
        pretrained_embedding: The pretrained token embedding matrix taken from
            ``side_data['embedding']`` (converted with ``torch.as_tensor``), or
            ``None`` when no pretrained embedding is supplied.
        n_word: An integer indicating the number of words.
        n_entity: An integer indicating the number of entities.
        n_relation: An integer indicating the number of relations in the entity kg.
        pad_word_idx: An integer indicating the id of word padding.
        pad_entity_idx: An integer indicating the id of entity padding.
        num_bases: An integer indicating the number of bases.
        kg_emb_dim: An integer indicating the dimension of kg embedding.
        n_heads: An integer indicating the number of heads.
        n_layers: An integer indicating the number of layer.
        ffn_size: An integer indicating the size of ffn hidden.
        dropout: A float indicating the dropout rate.
        attention_dropout: A float indicating the dropout rate of attention layer.
        relu_dropout: A float indicating the dropout rate of relu layer.
        learn_positional_embeddings: A boolean indicating if we learn the positional embedding.
        embeddings_scale: A boolean indicating if we use the embeddings scale.
        reduction: A boolean indicating if we use the reduction.
        n_positions: An integer indicating the number of position.
        response_truncate: An integer indicating the longest length for response
            generation (defaults to 20 when absent from ``opt``).

    """

    def __init__(self, opt, device, vocab, side_data):
        """

        Args:
            opt (dict): A dictionary record the hyper parameters.
            device (torch.device): A variable indicating which device to place the data and model.
            vocab (dict): A dictionary record the vocabulary information.
            side_data (dict): A dictionary record the side data.

        """
        self.device = device
        self.gpu = opt.get("gpu", [-1])
        # vocab
        self.vocab_size = vocab['vocab_size']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.token_emb_dim = opt['token_emb_dim']
        self.pretrained_embedding = side_data.get('embedding', None)
        # kg
        self.n_word = vocab['n_word']
        self.n_entity = vocab['n_entity']
        self.pad_word_idx = vocab['pad_word']
        self.pad_entity_idx = vocab['pad_entity']
        entity_kg = side_data['entity_kg']
        self.n_relation = entity_kg['n_relation']
        entity_edges = entity_kg['edge']
        self.entity_edge_idx, self.entity_edge_type = edge_to_pyg_format(entity_edges, 'RGCN')
        self.entity_edge_idx = self.entity_edge_idx.to(device)
        self.entity_edge_type = self.entity_edge_type.to(device)
        word_edges = side_data['word_kg']['edge']
        self.word_edges = edge_to_pyg_format(word_edges, 'GCN').to(device)
        self.num_bases = opt['num_bases']
        self.kg_emb_dim = opt['kg_emb_dim']
        # transformer
        self.n_heads = opt['n_heads']
        self.n_layers = opt['n_layers']
        self.ffn_size = opt['ffn_size']
        self.dropout = opt['dropout']
        self.attention_dropout = opt['attention_dropout']
        self.relu_dropout = opt['relu_dropout']
        self.learn_positional_embeddings = opt['learn_positional_embeddings']
        self.embeddings_scale = opt['embeddings_scale']
        self.reduction = opt['reduction']
        self.n_positions = opt['n_positions']
        self.response_truncate = opt.get('response_truncate', 20)
        # copy mask
        dataset = opt['dataset']
        dpath = os.path.join(MODEL_PATH, "kgsf", dataset)
        resource = resources[dataset]
        # All plain attributes above are set before the base initialiser runs.
        super(KGSFModel, self).__init__(opt, device, dpath, resource)

    def build_model(self):
        """Create every sub-module: embeddings, KG, infomax, rec and conv layers."""
        self._init_embeddings()
        self._build_kg_layer()
        self._build_infomax_layer()
        self._build_recommendation_layer()
        self._build_conversation_layer()

    def _init_embeddings(self):
        """Create the token embedding and the word-KG embedding tables."""
        if self.pretrained_embedding is not None:
            self.token_embedding = nn.Embedding.from_pretrained(
                torch.as_tensor(self.pretrained_embedding, dtype=torch.float),
                freeze=False, padding_idx=self.pad_token_idx)
        else:
            self.token_embedding = nn.Embedding(self.vocab_size, self.token_emb_dim, self.pad_token_idx)
            # NOTE(review): std is derived from kg_emb_dim, not token_emb_dim;
            # kept as-is to preserve training behaviour - confirm it is intended.
            nn.init.normal_(self.token_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
            nn.init.constant_(self.token_embedding.weight[self.pad_token_idx], 0)

        self.word_kg_embedding = nn.Embedding(self.n_word, self.kg_emb_dim, self.pad_word_idx)
        nn.init.normal_(self.word_kg_embedding.weight, mean=0, std=self.kg_emb_dim ** -0.5)
        nn.init.constant_(self.word_kg_embedding.weight[self.pad_word_idx], 0)

        logger.debug('[Finish init embeddings]')

    def _build_kg_layer(self):
        """Create the entity (RGCN) and word (GCN) graph encoders plus the gate."""
        # db encoder
        self.entity_encoder = RGCNConv(self.n_entity, self.kg_emb_dim, self.n_relation, self.num_bases)
        self.entity_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)

        # concept encoder
        self.word_encoder = GCNConv(self.kg_emb_dim, self.kg_emb_dim)
        self.word_self_attn = SelfAttentionSeq(self.kg_emb_dim, self.kg_emb_dim)

        # gate mechanism
        self.gate_layer = GateLayer(self.kg_emb_dim)

        logger.debug('[Finish build kg layer]')

    def _build_infomax_layer(self):
        """Create the projection, bias holder and loss used by the infomax objective."""
        self.infomax_norm = nn.Linear(self.kg_emb_dim, self.kg_emb_dim)
        # Only the bias of this layer is read (see pretrain_infomax / recommend).
        self.infomax_bias = nn.Linear(self.kg_emb_dim, self.n_entity)
        self.infomax_loss = nn.MSELoss(reduction='sum')

        logger.debug('[Finish build infomax layer]')

    def _build_recommendation_layer(self):
        """Create the bias holder and loss used by the recommendation objective."""
        # Only the bias of this layer is read (see recommend).
        self.rec_bias = nn.Linear(self.kg_emb_dim, self.n_entity)
        self.rec_loss = nn.CrossEntropyLoss()

        logger.debug('[Finish build rec layer]')

    def _build_conversation_layer(self):
        """Create the transformer encoder/decoder and the copy mechanism layers."""
        self.register_buffer('START', torch.tensor([self.start_token_idx], dtype=torch.long))
        self.conv_encoder = TransformerEncoder(
            n_heads=self.n_heads,
            n_layers=self.n_layers,
            embedding_size=self.token_emb_dim,
            ffn_size=self.ffn_size,
            vocabulary_size=self.vocab_size,
            embedding=self.token_embedding,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            relu_dropout=self.relu_dropout,
            padding_idx=self.pad_token_idx,
            learn_positional_embeddings=self.learn_positional_embeddings,
            embeddings_scale=self.embeddings_scale,
            reduction=self.reduction,
            n_positions=self.n_positions,
        )

        self.conv_entity_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)
        self.conv_entity_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)
        self.conv_word_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)
        self.conv_word_attn_norm = nn.Linear(self.kg_emb_dim, self.ffn_size)

        self.copy_norm = nn.Linear(self.ffn_size * 3, self.token_emb_dim)
        self.copy_output = nn.Linear(self.token_emb_dim, self.vocab_size)
        # Boolean mask over the vocabulary, loaded from the model resource directory.
        self.copy_mask = torch.as_tensor(np.load(os.path.join(self.dpath, "copy_mask.npy")).astype(bool),
                                         ).to(self.device)

        self.conv_decoder = TransformerDecoderKG(
            self.n_heads, self.n_layers, self.token_emb_dim, self.ffn_size, self.vocab_size,
            embedding=self.token_embedding,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            relu_dropout=self.relu_dropout,
            embeddings_scale=self.embeddings_scale,
            learn_positional_embeddings=self.learn_positional_embeddings,
            padding_idx=self.pad_token_idx,
            n_positions=self.n_positions
        )
        self.conv_loss = nn.CrossEntropyLoss(ignore_index=self.pad_token_idx)

        logger.debug('[Finish build conv layer]')

    def pretrain_infomax(self, batch):
        """Compute the infomax pretraining loss.

        Args:
            batch: A pair ``(words, entity_labels)`` where
                words: (batch_size, word_length)
                entity_labels: (batch_size, n_entity)

        Returns:
            The summed MSE between predicted and labelled entities divided by
            ``entity_labels.sum()``, or ``None`` when the labels sum to zero.

        """
        words, entity_labels = batch

        loss_mask = torch.sum(entity_labels)
        if loss_mask.item() == 0:
            return None

        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)
        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)

        word_representations = word_graph_representations[words]
        word_padding_mask = words.eq(self.pad_word_idx)  # (bs, seq_len)

        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)
        word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)
        info_predict = F.linear(word_info_rep, entity_graph_representations,
                                self.infomax_bias.bias)  # (bs, #entity)
        loss = self.infomax_loss(info_predict, entity_labels) / loss_mask
        return loss

    def recommend(self, batch, mode):
        """Score all entities for each dialogue context.

        Args:
            batch: A tuple ``(context_entities, context_words, entities, movie)`` where
                context_entities: (batch_size, entity_length)
                context_words: (batch_size, word_length)
                movie: (batch_size)
            mode: Not used by this method.

        Returns:
            A tuple ``(rec_loss, info_loss, rec_scores)``; ``info_loss`` is
            ``None`` when ``entities`` sums to zero.

        """
        context_entities, context_words, entities, movie = batch

        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)
        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)

        entity_padding_mask = context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)
        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, word_len)

        entity_representations = entity_graph_representations[context_entities]
        word_representations = word_graph_representations[context_words]

        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)
        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)

        user_rep = self.gate_layer(entity_attn_rep, word_attn_rep)
        rec_scores = F.linear(user_rep, entity_graph_representations, self.rec_bias.bias)  # (bs, #entity)

        rec_loss = self.rec_loss(rec_scores, movie)

        info_loss_mask = torch.sum(entities)
        if info_loss_mask.item() == 0:
            info_loss = None
        else:
            word_info_rep = self.infomax_norm(word_attn_rep)  # (bs, dim)
            info_predict = F.linear(word_info_rep, entity_graph_representations,
                                    self.infomax_bias.bias)  # (bs, #entity)
            info_loss = self.infomax_loss(info_predict, entities) / info_loss_mask

        return rec_loss, info_loss, rec_scores

    def freeze_parameters(self):
        """Disable gradients for the KG, gate, infomax and rec-bias modules."""
        freeze_models = [self.word_kg_embedding, self.entity_encoder, self.entity_self_attn, self.word_encoder,
                         self.word_self_attn, self.gate_layer, self.infomax_bias, self.infomax_norm, self.rec_bias]
        for model in freeze_models:
            for p in model.parameters():
                p.requires_grad = False

    def _starts(self, batch_size):
        """Return bsz start tokens."""
        return self.START.detach().expand(batch_size, 1)

    def _decode_forced_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,
                               word_reps, word_emb_attn, word_mask, response):
        """Teacher-forced decoding: feed the shifted gold response to the decoder.

        Returns:
            A tuple ``(sum_logits, preds)``: the combined copy + generation
            logits and their argmax over the vocabulary.

        """
        batch_size, seq_len = response.shape
        start = self._starts(batch_size)
        # Decoder input is the response shifted right by one start token.
        inputs = torch.cat((start, response[:, :-1]), dim=-1).long()

        dialog_latent, _ = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,
                                             entity_reps, entity_mask)  # (bs, seq_len, dim)
        entity_latent = entity_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)
        word_latent = word_emb_attn.unsqueeze(1).expand(-1, seq_len, -1)
        copy_latent = self.copy_norm(
            torch.cat((entity_latent, word_latent, dialog_latent), dim=-1))  # (bs, seq_len, dim)

        copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(
            0)  # (bs, seq_len, vocab_size)
        gen_logits = F.linear(dialog_latent, self.token_embedding.weight)  # (bs, seq_len, vocab_size)
        sum_logits = copy_logits + gen_logits
        preds = sum_logits.argmax(dim=-1)
        return sum_logits, preds

    def _decode_greedy_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,
                               word_reps, word_emb_attn, word_mask):
        """Greedy decoding for at most ``response_truncate`` steps.

        Stops early once every row of the generated sequence contains the end
        token.

        Returns:
            A tuple ``(logits, inputs)``: per-step logits concatenated along the
            sequence axis and the generated ids (including the start token).

        """
        batch_size = token_encoding[0].shape[0]
        inputs = self._starts(batch_size).long()
        incr_state = None
        logits = []
        for _ in range(self.response_truncate):
            dialog_latent, incr_state = self.conv_decoder(inputs, token_encoding, word_reps, word_mask,
                                                          entity_reps, entity_mask, incr_state)
            dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)
            db_latent = entity_emb_attn.unsqueeze(1)
            concept_latent = word_emb_attn.unsqueeze(1)
            copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1))

            copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0)
            gen_logits = F.linear(dialog_latent, self.token_embedding.weight)
            sum_logits = copy_logits + gen_logits
            preds = sum_logits.argmax(dim=-1).long()
            logits.append(sum_logits)
            inputs = torch.cat((inputs, preds), dim=1)

            finished = ((inputs == self.end_token_idx).sum(dim=-1) > 0).sum().item() == batch_size
            if finished:
                break
        logits = torch.cat(logits, dim=1)
        return logits, inputs

    def _decode_beam_search_with_kg(self, token_encoding, entity_reps, entity_emb_attn, entity_mask,
                                    word_reps, word_emb_attn, word_mask, beam=4):
        """Beam-search decoding for at most ``response_truncate`` steps.

        Each sample keeps a list of hypotheses ``[token_ids, logits, score]``
        sorted by score; the score is the product of per-step softmax
        probabilities.

        Returns:
            A tuple ``(logits, inputs)`` built from the best hypothesis of each
            sample.

        """
        batch_size = token_encoding[0].shape[0]
        inputs = self._starts(batch_size).long().reshape(1, batch_size, -1)
        incr_state = None

        # One independent hypothesis list per sample (no aliasing between samples).
        sequences = [[[[], [], 1.0]] for _ in range(batch_size)]
        for i in range(self.response_truncate):
            if i == 1:
                # From the second step on there are `beam` hypotheses per sample,
                # so the conditioning tensors are tiled along the batch axis.
                token_encoding = (token_encoding[0].repeat(beam, 1, 1),
                                  token_encoding[1].repeat(beam, 1, 1))
                entity_reps = entity_reps.repeat(beam, 1, 1)
                entity_emb_attn = entity_emb_attn.repeat(beam, 1)
                entity_mask = entity_mask.repeat(beam, 1)
                word_reps = word_reps.repeat(beam, 1, 1)
                word_emb_attn = word_emb_attn.repeat(beam, 1)
                word_mask = word_mask.repeat(beam, 1)

            # at beginning there is 1 candidate, when i!=0 there are `beam` candidates
            if i != 0:
                inputs = []
                for d in range(len(sequences[0])):
                    for j in range(batch_size):
                        text = sequences[j][d][0]
                        inputs.append(text)
                inputs = torch.stack(inputs).reshape(beam, batch_size, -1)  # (beam, batch_size, _)

            # NOTE(review): incr_state is passed through unchanged although
            # hypotheses are re-ranked every step - confirm the decoder does not
            # depend on row order of the incremental state.
            with torch.no_grad():
                dialog_latent, incr_state = self.conv_decoder(
                    inputs.reshape(len(sequences[0]) * batch_size, -1),
                    token_encoding, word_reps, word_mask,
                    entity_reps, entity_mask, incr_state
                )
                dialog_latent = dialog_latent[:, -1:, :]  # (bs, 1, dim)
                db_latent = entity_emb_attn.unsqueeze(1)
                concept_latent = word_emb_attn.unsqueeze(1)
                copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1))

                copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0)
                gen_logits = F.linear(dialog_latent, self.token_embedding.weight)
                sum_logits = copy_logits + gen_logits

            logits = sum_logits.reshape(len(sequences[0]), batch_size, 1, -1)
            # turn into probabilities, in case of negative numbers.
            # dim=-1 is required: without it softmax picks dim 1 for a 4-D input,
            # which would normalise across the batch instead of the vocabulary.
            probs, preds = F.softmax(logits, dim=-1).topk(beam, dim=-1)

            # (candidate, bs, 1, beam) during first loop, candidate=1, otherwise candidate=beam

            for j in range(batch_size):
                all_candidates = []
                for n in range(len(sequences[j])):
                    for k in range(beam):
                        prob = sequences[j][n][2]
                        logit = sequences[j][n][1]
                        # The initial hypothesis stores an empty list, later ones a tensor.
                        if isinstance(logit, list):
                            logit_tmp = logits[n][j][0].unsqueeze(0)
                        else:
                            logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0)
                        seq_tmp = torch.cat((inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1)))
                        candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]]
                        all_candidates.append(candidate)
                ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True)
                sequences[j] = ordered[:beam]

            # check if everyone has generated an end token: the best hypothesis
            # of every sample must contain it.
            all_finished = all(
                bool((hyps[0][0] == self.end_token_idx).any()) for hyps in sequences
            )
            if all_finished:
                break
        logits = torch.stack([seq[0][1] for seq in sequences])
        inputs = torch.stack([seq[0][0] for seq in sequences])
        return logits, inputs

    def converse(self, batch, mode):
        """Run the conversation module.

        Args:
            batch: A tuple ``(context_tokens, context_entities, context_words, response)``.
            mode: When not ``'test'`` the response is decoded with teacher
                forcing; otherwise greedy decoding is used.

        Returns:
            ``(loss, preds)`` when ``mode != 'test'``, otherwise ``preds`` only.

        """
        context_tokens, context_entities, context_words, response = batch

        entity_graph_representations = self.entity_encoder(None, self.entity_edge_idx, self.entity_edge_type)
        word_graph_representations = self.word_encoder(self.word_kg_embedding.weight, self.word_edges)

        entity_padding_mask = context_entities.eq(self.pad_entity_idx)  # (bs, entity_len)
        word_padding_mask = context_words.eq(self.pad_word_idx)  # (bs, seq_len)

        entity_representations = entity_graph_representations[context_entities]
        word_representations = word_graph_representations[context_words]

        entity_attn_rep = self.entity_self_attn(entity_representations, entity_padding_mask)
        word_attn_rep = self.word_self_attn(word_representations, word_padding_mask)

        # encoder-decoder
        tokens_encoding = self.conv_encoder(context_tokens)
        conv_entity_emb = self.conv_entity_attn_norm(entity_attn_rep)
        conv_word_emb = self.conv_word_attn_norm(word_attn_rep)
        conv_entity_reps = self.conv_entity_norm(entity_representations)
        conv_word_reps = self.conv_word_norm(word_representations)

        if mode != 'test':
            logits, preds = self._decode_forced_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,
                                                        entity_padding_mask,
                                                        conv_word_reps, conv_word_emb, word_padding_mask,
                                                        response)

            logits = logits.view(-1, logits.shape[-1])
            response = response.view(-1)
            loss = self.conv_loss(logits, response)
            return loss, preds
        else:
            logits, preds = self._decode_greedy_with_kg(tokens_encoding, conv_entity_reps, conv_entity_emb,
                                                        entity_padding_mask,
                                                        conv_word_reps, conv_word_emb, word_padding_mask)
            return preds

    def forward(self, batch, stage, mode):
        """Dispatch a batch to the objective selected by ``stage``.

        Args:
            batch: The batch expected by the selected stage.
            stage: One of ``"pretrain"``, ``"rec"`` or ``"conv"``.
            mode: Forwarded to ``recommend`` / ``converse``.

        Returns:
            Whatever the selected stage method returns.

        Raises:
            ValueError: If ``stage`` is not one of the supported values.

        """
        if len(self.gpu) >= 2:
            # forward function operates on different gpus, the weight of graph
            # network need to be copied to other gpu
            current = torch.cuda.current_device()
            self.entity_edge_idx = self.entity_edge_idx.cuda(current)
            self.entity_edge_type = self.entity_edge_type.cuda(current)
            self.word_edges = self.word_edges.cuda(current)
            # Move the mask already loaded in _build_conversation_layer instead
            # of re-reading copy_mask.npy from disk on every forward pass.
            self.copy_mask = self.copy_mask.cuda(current)
        if stage == "pretrain":
            loss = self.pretrain_infomax(batch)
        elif stage == "rec":
            loss = self.recommend(batch, mode)
        elif stage == "conv":
            loss = self.converse(batch, mode)
        else:
            raise ValueError(
                "Unknown stage {!r}; expected 'pretrain', 'rec' or 'conv'".format(stage))
        return loss