Source code for crslab.model.recommendation.popularity.popularity

# @Time   : 2020/12/16
# @Author : Yuanhang Zhou
# @Email  : sdzyh002@gmail.com

# UPDATE
# @Time   : 2020/12/29, 2021/1/4
# @Author : Xiaolei Wang, Yuanhang Zhou
# @email  : wxl1999@foxmail.com, sdzyh002@gmail.com

r"""
Popularity
==========
"""

from collections import defaultdict

import torch
from loguru import logger

from crslab.model.base import BaseModel


[docs]class PopularityModel(BaseModel): """ Attributes: item_size: A integer indicating the number of items. """ def __init__(self, opt, device, vocab, side_data): """ Args: opt (dict): A dictionary record the hyper parameters. device (torch.device): A variable indicating which device to place the data and model. vocab (dict): A dictionary record the vocabulary information. side_data (dict): A dictionary record the side data. """ self.item_size = vocab['n_entity'] super(PopularityModel, self).__init__(opt, device)
[docs] def build_model(self): self.item_frequency = defaultdict(int) logger.debug('[Finish build rec layer]')
[docs] def forward(self, batch, mode): context, mask, input_ids, target_pos, input_mask, sample_negs, y = batch if mode == 'train': for ids in input_ids: for id in ids: self.item_frequency[id.item()] += 1 bs = input_ids.shape[0] rec_score = [self.item_frequency.get(item_id, 0) for item_id in range(self.item_size)] rec_scores = torch.tensor([rec_score] * bs, dtype=torch.long) loss = torch.zeros(1, requires_grad=True) return loss, rec_scores