# Source code for crslab.data.dataloader.kgsf

# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2020/12/2
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, get_onehot, truncate, merge_utt


class KGSFDataLoader(BaseDataLoader):
    """Dataloader for model KGSF.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response (including
          the start and end tokens added by :meth:`conv_batchify`).
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.

        Each of these defaults to ``None`` when absent from the config and is
        handed to ``truncate`` unchanged. ``truncate`` is expected to treat
        ``None`` as "no limit".

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'pad_entity'``
        - ``'pad_word'``

        the above values specify the id of needed special token.

        - ``'n_entity'``: the number of entities in the entity KG of dataset.

    """

    def __init__(self, opt, dataset, vocab):
        """
        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful size, idx and map between token and idx.

        """
        super().__init__(opt, dataset)
        self.n_entity = vocab['n_entity']
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']
        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)
        self.word_truncate = opt.get('word_truncate', None)

    def get_pretrain_data(self, batch_size, shuffle=True):
        """Build batches for the KGSF pre-training stage.

        Delegates to ``self.get_data``. It uses :meth:`pretrain_batchify` as the
        batch function and ``self.retain_recommender_target`` as the
        dataset-processing function, both inherited from ``BaseDataLoader``.

        Args:
            batch_size (int): number of examples per batch.
            shuffle (bool): whether to shuffle the data. Defaults to True.

        Returns:
            Whatever ``BaseDataLoader.get_data`` returns (not visible here).

        """
        return self.get_data(self.pretrain_batchify, batch_size, shuffle, self.retain_recommender_target)

    def pretrain_batchify(self, batch):
        """Collate a list of conversation dicts for pre-training.

        Args:
            batch (list of dict): each dict must contain ``'context_entities'``
                and ``'context_words'``.

        Returns:
            tuple: ``(context_words, context_entities_onehot)``.
            ``context_words`` is the padded word ids, padded with ``pad_word``
            and ``pad_tail=False`` (presumably left-padded). The second element
            is ``get_onehot`` of the truncated entity lists over ``n_entity``
            entities.

        """
        batch_context_entities = []
        batch_context_words = []
        for conv_dict in batch:
            # truncate_tail=False: presumably keeps the most recent items.
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))

        return (padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity))

    def rec_process_fn(self):
        """Expand recommender turns into one example per recommended item.

        Every conversation dict whose ``'role'`` is ``'Recommender'`` is copied
        once per entry of its ``'items'`` list. Each copy gets that entry
        stored under the ``'item'`` key. Dicts with any other role are dropped.

        Returns:
            list of dict: the augmented dataset.

        """
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            if conv_dict['role'] != 'Recommender':
                continue
            for item in conv_dict['items']:
                # deepcopy so the per-item examples never share mutable state.
                augment_conv_dict = deepcopy(conv_dict)
                augment_conv_dict['item'] = item
                augment_dataset.append(augment_conv_dict)
        return augment_dataset

    def rec_batchify(self, batch):
        """Collate a list of conversation dicts for the recommendation task.

        Args:
            batch (list of dict): each dict must contain ``'context_entities'``,
                ``'context_words'`` and ``'item'`` (an integer id, as produced
                by :meth:`rec_process_fn`).

        Returns:
            tuple: ``(context_entities, context_words, context_entities_onehot, item)``.
            The first two are padded tensors (``pad_tail=False``).
            ``context_entities_onehot`` is ``get_onehot`` over ``n_entity``.
            ``item`` is a ``torch.long`` tensor of the target items.

        """
        batch_context_entities = []
        batch_context_words = []
        batch_item = []
        for conv_dict in batch:
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_item.append(conv_dict['item'])

        return (padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                get_onehot(batch_context_entities, self.n_entity),
                torch.tensor(batch_item, dtype=torch.long))

    def conv_process_fn(self, *args, **kwargs):
        """Prepare the dataset for the conversation task.

        Delegates to ``self.retain_recommender_target()``, inherited from
        ``BaseDataLoader`` (not visible here). Extra arguments are accepted
        for interface compatibility and ignored.

        """
        return self.retain_recommender_target()

    def conv_batchify(self, batch):
        """Collate a list of conversation dicts for the conversation task.

        Args:
            batch (list of dict): each dict must contain ``'context_tokens'``
                (a list of utterances, merged with ``merge_utt``),
                ``'context_entities'``, ``'context_words'`` and ``'response'``.

        Returns:
            tuple: ``(context_tokens, context_entities, context_words, response)``.
            These are padded tensors. The three context tensors use
            ``pad_tail=False``. ``response`` is wrapped with the start/end
            tokens and uses ``padded_tensor``'s default ``pad_tail``.

        """
        # The response gets a start and an end token *after* truncation, so
        # two slots of the budget are reserved for them. ``response_truncate``
        # is None when unset. ``None - 2`` used to raise TypeError, so None is
        # passed straight through, like the other truncate limits. The value
        # is clamped at 0 so a tiny budget cannot become a negative slice bound.
        if self.response_truncate is None:
            response_limit = None
        else:
            response_limit = max(self.response_truncate - 2, 0)

        batch_context_tokens = []
        batch_context_entities = []
        batch_context_words = []
        batch_response = []
        for conv_dict in batch:
            batch_context_tokens.append(
                truncate(merge_utt(conv_dict['context_tokens']), self.context_truncate, truncate_tail=False))
            batch_context_entities.append(
                truncate(conv_dict['context_entities'], self.entity_truncate, truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'], self.word_truncate, truncate_tail=False))
            batch_response.append(
                add_start_end_token_idx(truncate(conv_dict['response'], response_limit),
                                        start_token_idx=self.start_token_idx,
                                        end_token_idx=self.end_token_idx))

        return (padded_tensor(batch_context_tokens, self.pad_token_idx, pad_tail=False),
                padded_tensor(batch_context_entities, self.pad_entity_idx, pad_tail=False),
                padded_tensor(batch_context_words, self.pad_word_idx, pad_tail=False),
                padded_tensor(batch_response, self.pad_token_idx))

    def policy_batchify(self, *args, **kwargs):
        """No policy task for KGSF; accepts any arguments and returns None."""
        pass