Source code for crslab.model.crs.redial.redial_rec

# @Time   : 2020/12/4
# @Author : Chenzhan Shang
# @Email  : czshang@outlook.com

# UPDATE
# @Time   : 2020/12/29, 2021/1/4
# @Author : Xiaolei Wang, Yuanhang Zhou
# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com

r"""
ReDial_Rec
==========
References:
    Li, Raymond, et al. `"Towards deep conversational recommendations."`_ in NeurIPS.

.. _`"Towards deep conversational recommendations."`:
   https://papers.nips.cc/paper/2018/hash/800de15c79c8d840f4e78d3af937d4d4-Abstract.html

"""

import torch.nn as nn

from crslab.model.base import BaseModel


class ReDialRecModel(BaseModel):
    """

    Attributes:
        n_entity: An integer indicating the number of entities.
        layer_sizes: A list of integers indicating the size of each autorec layer.
        pad_entity_idx: An integer indicating the id of the entity padding.

    """

    def __init__(self, opt, device, vocab, side_data):
        """

        Args:
            opt (dict): A dictionary recording the hyperparameters.
            device (torch.device): A variable indicating which device to place the data and model.
            vocab (dict): A dictionary recording the vocabulary information.
            side_data (dict): A dictionary recording the side data.

        """
        self.n_entity = vocab['n_entity']
        self.layer_sizes = opt['autorec_layer_sizes']
        self.pad_entity_idx = vocab['pad_entity']

        super(ReDialRecModel, self).__init__(opt, device)
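
    # Configuration consumed by this model: ``opt['autorec_layer_sizes']``,
    # ``opt['autorec_f']`` and ``opt['autorec_g']`` from the config, plus
    # ``vocab['n_entity']`` and ``vocab['pad_entity']`` from the vocabulary;
    # ``side_data`` is accepted in the constructor but not used in this class.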

    def build_model(self):
        # AutoRec
        if self.opt['autorec_f'] == 'identity':
            self.f = lambda x: x
        elif self.opt['autorec_f'] == 'sigmoid':
            self.f = nn.Sigmoid()
        elif self.opt['autorec_f'] == 'relu':
            self.f = nn.ReLU()
        else:
            raise ValueError("Got invalid function name for f : {}".format(self.opt['autorec_f']))

        if self.opt['autorec_g'] == 'identity':
            self.g = lambda x: x
        elif self.opt['autorec_g'] == 'sigmoid':
            self.g = nn.Sigmoid()
        elif self.opt['autorec_g'] == 'relu':
            self.g = nn.ReLU()
        else:
            raise ValueError("Got invalid function name for g : {}".format(self.opt['autorec_g']))

        self.encoder = nn.ModuleList(
            [nn.Linear(self.n_entity, self.layer_sizes[0]) if i == 0
             else nn.Linear(self.layer_sizes[i - 1], self.layer_sizes[i])
             for i in range(len(self.layer_sizes))])
        self.user_repr_dim = self.layer_sizes[-1]
        self.decoder = nn.Linear(self.user_repr_dim, self.n_entity)
        self.loss = nn.CrossEntropyLoss()
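
    # The modules built above form an AutoRec-style autoencoder over the entity
    # space: for a user's multi-hot entity vector r, the encoder computes
    #     h = f(W_L (... f(W_1 r + b_1) ...) + b_L)
    # and the decoder reconstructs a score for every entity,
    #     scores = g(W_dec h + b_dec),
    # which ``forward`` compares to the target item with cross-entropy.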

    def forward(self, batch, mode):
        """

        Args:
            batch: ::

                {
                    'context_entities': (batch_size, n_entity),
                    'item': (batch_size)
                }

            mode (str)

        """
        context_entities = batch['context_entities']
        for i, layer in enumerate(self.encoder):
            context_entities = self.f(layer(context_entities))

        scores = self.g(self.decoder(context_entities))
        loss = self.loss(scores, batch['item'])

        return loss, scores
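

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative, not part of the original module).
    # It assumes ``BaseModel.__init__(opt, device)`` stores the config and then
    # calls ``self.build_model()``, as crslab's ``BaseModel`` does; the config
    # values and tensor contents below are arbitrary examples.
    import torch

    n_entity = 1000
    opt = {
        'autorec_layer_sizes': [256, 128],  # hidden sizes of the encoder
        'autorec_f': 'relu',                # encoder activation
        'autorec_g': 'identity',            # decoder activation (raw logits)
    }
    vocab = {'n_entity': n_entity, 'pad_entity': 0}

    model = ReDialRecModel(opt, torch.device('cpu'), vocab, side_data={})

    # Two users, each represented by a multi-hot vector over all entities.
    context_entities = torch.zeros(2, n_entity)
    context_entities[0, [3, 17, 42]] = 1.0
    context_entities[1, [8]] = 1.0
    batch = {
        'context_entities': context_entities,
        'item': torch.tensor([42, 7]),  # ground-truth item ids
    }

    loss, scores = model.forward(batch, mode='train')  # scores: (2, n_entity)
    print(loss.item(), scores.shape)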