Source code for crslab.data.dataloader.kbrd

# @Time   : 2020/11/27
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

# UPDATE:
# @Time   : 2020/12/2
# @Author : Xiaolei Wang
# @Email  : wxl1999@foxmail.com

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt


class KBRDDataLoader(BaseDataLoader):
    """Dataloader for model KBRD.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.

        Each of them defaults to ``None`` when absent from the config.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'pad_entity'``

        the above values specify the id of needed special token.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.pad_entity_idx = vocab['pad_entity']
        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        # NOTE(review): entity_truncate is stored here but not read by any
        # method of this class -- confirm whether entity truncation is intended.
        self.entity_truncate = opt.get('entity_truncate', None)

    def rec_process_fn(self):
        """Expand the dataset into one recommendation sample per recommended item.

        Only conversations whose ``'role'`` is ``'Recommender'`` contribute
        samples; each entry of their ``'items'`` yields one sample.

        Returns:
            list of dict: samples with keys ``'context_entities'`` and ``'item'``.

        """
        return [
            {'context_entities': conv_dict['context_entities'], 'item': movie}
            for conv_dict in tqdm(self.dataset)
            if conv_dict['role'] == 'Recommender'
            for movie in conv_dict['items']
        ]

    def rec_batchify(self, batch):
        """Collate recommendation samples into a batch.

        Args:
            batch (list of dict): samples with keys ``'context_entities'`` and ``'item'``.

        Returns:
            dict: ``'context_entities'`` is the list of per-sample entity
            collections, left unpadded; ``'item'`` is a ``torch.long`` tensor
            holding the target item of every sample.

        """
        batch_context_entities = [conv_dict['context_entities'] for conv_dict in batch]
        batch_movies = [conv_dict['item'] for conv_dict in batch]

        return {
            "context_entities": batch_context_entities,
            "item": torch.tensor(batch_movies, dtype=torch.long)
        }

    def conv_process_fn(self, *args, **kwargs):
        """Select the data used for the conversation task.

        Delegates to ``retain_recommender_target`` (not defined in this class;
        presumably inherited from ``BaseDataLoader``). The arguments are
        accepted but ignored.

        """
        return self.retain_recommender_target()

    def conv_batchify(self, batch):
        """Collate conversation samples into a batch.

        Args:
            batch (list of dict): samples with keys ``'context_tokens'``,
                ``'context_entities'`` and ``'response'``.

        Returns:
            dict: ``'context_tokens'`` and ``'response'`` are produced by
            ``padded_tensor`` using the pad token id; ``'context_entities'``
            is the list of per-sample entity collections, left unpadded.

        """
        # Two positions of the response budget are reserved for the start and
        # end tokens added below. When no budget is configured the option is
        # None, so pass None through instead of evaluating ``None - 2``
        # (TypeError); context_truncate is already handed to ``truncate``
        # as-is in that case.
        if self.response_truncate is None:
            response_limit = None
        else:
            response_limit = self.response_truncate - 2

        batch_context_tokens = []
        batch_context_entities = []
        batch_response = []
        for conv_dict in batch:
            # truncate_tail=False: presumably keeps the most recent tokens of
            # the merged context -- confirm against ``truncate``.
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))
            batch_context_entities.append(conv_dict['context_entities'])
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'], response_limit),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return {
            "context_tokens": padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
            "context_entities": batch_context_entities,
            "response": padded_tensor(batch_response, self.pad_token_idx)
        }

    def policy_batchify(self, *args, **kwargs):
        """No-op: this dataloader builds no policy batch and returns ``None``."""
        pass