Source code for crslab.data.dataloader.tgredial

# @Time   : 2020/12/9
# @Author : Yuanhang Zhou
# @Email  : sdzyh002@gmail.com

# UPDATE:
# @Time   : 2020/12/29, 2020/12/15
# @Author : Xiaolei Wang, Yuanhang Zhou
# @Email  : wxl1999@foxmail.com, sdzyh002@gmail.com

import random
from copy import deepcopy

import torch
from tqdm import tqdm

from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import add_start_end_token_idx, padded_tensor, truncate, merge_utt


class TGReDialDataLoader(BaseDataLoader):
    """Dataloader for model TGReDial.

    Notes:
        You can set the following parameters in config:

        - ``'context_truncate'``: the maximum length of context.
        - ``'response_truncate'``: the maximum length of response.
        - ``'entity_truncate'``: the maximum length of mentioned entities in context.
        - ``'word_truncate'``: the maximum length of mentioned words in context.
        - ``'item_truncate'``: the maximum length of mentioned items in context.

        The following values must be specified in ``vocab``:

        - ``'pad'``
        - ``'start'``
        - ``'end'``
        - ``'unk'``
        - ``'pad_entity'``
        - ``'pad_word'``

        The above values specify the ids of the needed special tokens.

        - ``'ind2tok'``: map from index to token.
        - ``'tok2ind'``: map from token to index.
        - ``'vocab_size'``: size of vocab.
        - ``'id2entity'``: map from index to entity.
        - ``'n_entity'``: number of entities in the entity KG of dataset.
        - ``'sent_split'`` (optional): token used to split sentences. Defaults to ``'end'``.
        - ``'word_split'`` (optional): token used to split words. Defaults to ``'end'``.
        - ``'pad_topic'`` (optional): token used to pad topics.
        - ``'ind2topic'`` (optional): map from index to topic.

    """

    def __init__(self, opt, dataset, vocab):
        """

        Args:
            opt (Config or dict): config for dataloader or the whole system.
            dataset: data for model.
            vocab (dict): all kinds of useful sizes, indexes and maps between tokens and indexes.

        """
        super().__init__(opt, dataset)
        self.n_entity = vocab['n_entity']
        self.item_size = self.n_entity
        self.pad_token_idx = vocab['pad']
        self.start_token_idx = vocab['start']
        self.end_token_idx = vocab['end']
        self.unk_token_idx = vocab['unk']
        self.conv_bos_id = vocab['start']
        self.cls_id = vocab['start']
        self.sep_id = vocab['end']
        if 'sent_split' in vocab:
            self.sent_split_idx = vocab['sent_split']
        else:
            self.sent_split_idx = vocab['end']
        if 'word_split' in vocab:
            self.word_split_idx = vocab['word_split']
        else:
            self.word_split_idx = vocab['end']

        self.pad_entity_idx = vocab['pad_entity']
        self.pad_word_idx = vocab['pad_word']
        if 'pad_topic' in vocab:
            self.pad_topic_idx = vocab['pad_topic']

        self.tok2ind = vocab['tok2ind']
        self.ind2tok = vocab['ind2tok']
        self.id2entity = vocab['id2entity']
        if 'ind2topic' in vocab:
            self.ind2topic = vocab['ind2topic']

        self.context_truncate = opt.get('context_truncate', None)
        self.response_truncate = opt.get('response_truncate', None)
        self.entity_truncate = opt.get('entity_truncate', None)
        self.word_truncate = opt.get('word_truncate', None)
        self.item_truncate = opt.get('item_truncate', None)
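
    # Example (toy values, for illustration only): ``vocab`` carries at least the
    # special token ids and maps listed in the class docstring, e.g.
    #   vocab = {'pad': 0, 'start': 1, 'end': 2, 'unk': 3,
    #            'pad_entity': 0, 'pad_word': 0, 'n_entity': 30000,
    #            'tok2ind': {...}, 'ind2tok': {...}, 'id2entity': {...}}
    # while ``opt`` supplies the truncation lengths, e.g.
    #   opt = {'context_truncate': 256, 'response_truncate': 30, 'item_truncate': 100}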

    def rec_process_fn(self, *args, **kwargs):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            for movie in conv_dict['items']:
                augment_conv_dict = deepcopy(conv_dict)
                augment_conv_dict['item'] = movie
                augment_dataset.append(augment_conv_dict)
        return augment_dataset
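
    # ``rec_process_fn`` duplicates each conversation once per ground-truth item:
    # a ``conv_dict`` whose ``'items'`` field is, say, ``[10, 42]`` yields two
    # recommendation samples, one with ``'item': 10`` and one with ``'item': 42``.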

    def _process_rec_context(self, context_tokens):
        compact_context = []
        for i, utterance in enumerate(context_tokens):
            if i != 0:
                utterance.insert(0, self.sent_split_idx)
            compact_context.append(utterance)
        compat_context = truncate(merge_utt(compact_context),
                                  self.context_truncate - 2,
                                  truncate_tail=False)
        compat_context = add_start_end_token_idx(compat_context,
                                                 self.start_token_idx,
                                                 self.end_token_idx)
        return compat_context

    def _neg_sample(self, item_set):
        item = random.randint(1, self.item_size)
        while item in item_set:
            item = random.randint(1, self.item_size)
        return item

    def _process_history(self, context_items, item_id=None):
        input_ids = truncate(context_items,
                             max_length=self.item_truncate,
                             truncate_tail=False)
        input_mask = [1] * len(input_ids)
        sample_negs = []
        seq_set = set(input_ids)
        for _ in input_ids:
            sample_negs.append(self._neg_sample(seq_set))

        if item_id is not None:
            target_pos = input_ids[1:] + [item_id]
            return input_ids, target_pos, input_mask, sample_negs
        else:
            return input_ids, input_mask, sample_negs
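
    # Example (toy values): for ``context_items = [5, 8, 13]`` and ``item_id = 21``,
    # ``_process_history`` returns ``input_ids = [5, 8, 13]``, the shifted next-item
    # targets ``target_pos = [8, 13, 21]``, ``input_mask = [1, 1, 1]``, and one
    # randomly sampled negative item (never one that appears in the history) per
    # position in ``sample_negs``.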

    def rec_batchify(self, batch):
        batch_context = []
        batch_movie_id = []
        batch_input_ids = []
        batch_target_pos = []
        batch_input_mask = []
        batch_sample_negs = []

        for conv_dict in batch:
            context = self._process_rec_context(conv_dict['context_tokens'])
            batch_context.append(context)

            item_id = conv_dict['item']
            batch_movie_id.append(item_id)

            if 'interaction_history' in conv_dict:
                context_items = conv_dict['interaction_history'] + conv_dict['context_items']
            else:
                context_items = conv_dict['context_items']

            input_ids, target_pos, input_mask, sample_negs = self._process_history(
                context_items, item_id)
            batch_input_ids.append(input_ids)
            batch_target_pos.append(target_pos)
            batch_input_mask.append(input_mask)
            batch_sample_negs.append(sample_negs)

        batch_context = padded_tensor(batch_context,
                                      self.pad_token_idx,
                                      max_len=self.context_truncate)
        batch_mask = (batch_context != self.pad_token_idx).long()

        return (batch_context, batch_mask,
                padded_tensor(batch_input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_target_pos,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(batch_sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                torch.tensor(batch_movie_id))
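
    # ``rec_batchify`` returns, in order: the padded dialogue contexts and their
    # mask, the padded interaction sequences (input ids, shifted next-item targets,
    # masks, sampled negatives), and the ground-truth item ids as a 1-D tensor.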

    def rec_interact(self, data):
        context = [self._process_rec_context(data['context_tokens'])]

        if 'interaction_history' in data:
            context_items = data['interaction_history'] + data['context_items']
        else:
            context_items = data['context_items']
        input_ids, input_mask, sample_negs = self._process_history(context_items)
        input_ids, input_mask, sample_negs = [input_ids], [input_mask], [sample_negs]

        context = padded_tensor(context,
                                self.pad_token_idx,
                                max_len=self.context_truncate)
        mask = (context != self.pad_token_idx).long()

        return (context, mask,
                padded_tensor(input_ids,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                None,
                padded_tensor(input_mask,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                padded_tensor(sample_negs,
                              pad_idx=self.pad_token_idx,
                              pad_tail=False,
                              max_len=self.item_truncate),
                None)
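
    # ``rec_interact`` is the single-sample counterpart of ``rec_batchify`` used in
    # interactive mode: the two ``None`` placeholders stand in for the shifted
    # targets and the ground-truth item ids, which are unknown at inference time.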

    def conv_batchify(self, batch):
        batch_context_tokens = []
        batch_enhanced_context_tokens = []
        batch_response = []
        batch_context_entities = []
        batch_context_words = []

        for conv_dict in batch:
            context_tokens = [utter + [self.conv_bos_id] for utter in conv_dict['context_tokens']]
            context_tokens[-1] = context_tokens[-1][:-1]
            batch_context_tokens.append(
                truncate(merge_utt(context_tokens),
                         max_length=self.context_truncate,
                         truncate_tail=False),
            )
            batch_response.append(
                add_start_end_token_idx(
                    truncate(conv_dict['response'], max_length=self.response_truncate - 2),
                    start_token_idx=self.start_token_idx,
                    end_token_idx=self.end_token_idx
                )
            )
            batch_context_entities.append(
                truncate(conv_dict['context_entities'],
                         self.entity_truncate,
                         truncate_tail=False))
            batch_context_words.append(
                truncate(conv_dict['context_words'],
                         self.word_truncate,
                         truncate_tail=False))

            enhanced_topic = []
            if 'target' in conv_dict:
                for target_policy in conv_dict['target']:
                    topic_variable = target_policy[1]
                    if isinstance(topic_variable, list):
                        for topic in topic_variable:
                            enhanced_topic.append(topic)
                enhanced_topic = [[
                    self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]
                ] for topic_id in enhanced_topic]
                enhanced_topic = merge_utt(enhanced_topic, self.word_split_idx, False, self.sent_split_idx)

            enhanced_movie = []
            if 'items' in conv_dict:
                for movie_id in conv_dict['items']:
                    enhanced_movie.append(movie_id)
                enhanced_movie = [
                    [self.tok2ind.get(token, self.unk_token_idx)
                     for token in self.id2entity[movie_id].split('(')[0]]
                    for movie_id in enhanced_movie]
                enhanced_movie = truncate(merge_utt(enhanced_movie, self.word_split_idx, self.sent_split_idx),
                                          self.item_truncate,
                                          truncate_tail=False)

            if len(enhanced_movie) != 0:
                enhanced_context_tokens = enhanced_movie + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_movie),
                    truncate_tail=False)
            elif len(enhanced_topic) != 0:
                enhanced_context_tokens = enhanced_topic + truncate(
                    batch_context_tokens[-1],
                    max_length=self.context_truncate - len(enhanced_topic),
                    truncate_tail=False)
            else:
                enhanced_context_tokens = batch_context_tokens[-1]
            batch_enhanced_context_tokens.append(enhanced_context_tokens)

        batch_context_tokens = padded_tensor(items=batch_context_tokens,
                                             pad_idx=self.pad_token_idx,
                                             max_len=self.context_truncate,
                                             pad_tail=False)
        batch_response = padded_tensor(batch_response,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.response_truncate,
                                       pad_tail=True)
        batch_input_ids = torch.cat((batch_context_tokens, batch_response), dim=1)
        batch_enhanced_context_tokens = padded_tensor(items=batch_enhanced_context_tokens,
                                                      pad_idx=self.pad_token_idx,
                                                      max_len=self.context_truncate,
                                                      pad_tail=False)
        batch_enhanced_input_ids = torch.cat((batch_enhanced_context_tokens, batch_response), dim=1)

        return (batch_enhanced_input_ids, batch_enhanced_context_tokens,
                batch_input_ids, batch_context_tokens,
                padded_tensor(batch_context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(batch_context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                batch_response)
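
    # The "enhanced" context built in ``conv_batchify`` prepends extra guidance to
    # the dialogue tokens: the names of the ground-truth items (looked up in
    # ``id2entity`` and cut at the first '(') when ``'items'`` is present, otherwise
    # the target topics from ``'target'``; if neither exists, the plain context is used.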

    def conv_interact(self, data):
        context_tokens = [utter + [self.conv_bos_id] for utter in data['context_tokens']]
        context_tokens[-1] = context_tokens[-1][:-1]
        context_tokens = [truncate(merge_utt(context_tokens),
                                   max_length=self.context_truncate,
                                   truncate_tail=False)]
        context_tokens = padded_tensor(items=context_tokens,
                                       pad_idx=self.pad_token_idx,
                                       max_len=self.context_truncate,
                                       pad_tail=False)
        context_entities = [truncate(data['context_entities'],
                                     self.entity_truncate,
                                     truncate_tail=False)]
        context_words = [truncate(data['context_words'],
                                  self.word_truncate,
                                  truncate_tail=False)]

        return (context_tokens, context_tokens, context_tokens, context_tokens,
                padded_tensor(context_entities,
                              self.pad_entity_idx,
                              pad_tail=False),
                padded_tensor(context_words,
                              self.pad_word_idx,
                              pad_tail=False),
                None)
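
    # ``conv_interact`` mirrors ``conv_batchify`` for a single sample: the same
    # padded context tensor fills the four id/context slots so the output layout
    # matches the training batches, and the response slot is ``None`` at inference time.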

    def policy_process_fn(self, *args, **kwargs):
        augment_dataset = []
        for conv_dict in tqdm(self.dataset):
            for target_policy in conv_dict['target']:
                topic_variable = target_policy[1]
                for topic in topic_variable:
                    augment_conv_dict = deepcopy(conv_dict)
                    augment_conv_dict['target_topic'] = topic
                    augment_dataset.append(augment_conv_dict)
        return augment_dataset
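
    # ``policy_process_fn`` mirrors ``rec_process_fn`` for the policy task: each
    # conversation yields one sample per topic in its ``'target'`` policies, stored
    # under ``'target_topic'``.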

    def policy_batchify(self, batch):
        batch_context = []
        batch_context_policy = []
        batch_user_profile = []
        batch_target = []

        for conv_dict in batch:
            final_topic = conv_dict['final']
            final_topic = [[
                self.tok2ind.get(token, self.unk_token_idx) for token in self.ind2topic[topic_id]
            ] for topic_id in final_topic[1]]
            final_topic = merge_utt(final_topic, self.word_split_idx, False, self.sep_id)

            context = conv_dict['context_tokens']
            context = merge_utt(context, self.sent_split_idx, False, self.sep_id)
            context += final_topic
            context = add_start_end_token_idx(
                truncate(context, max_length=self.context_truncate - 1, truncate_tail=False),
                start_token_idx=self.cls_id)
            batch_context.append(context)

            # [topic, topic, ..., topic]
            context_policy = []
            for policies_one_turn in conv_dict['context_policy']:
                if len(policies_one_turn) != 0:
                    for policy in policies_one_turn:
                        for topic_id in policy[1]:
                            if topic_id != self.pad_topic_idx:
                                policy = []
                                for token in self.ind2topic[topic_id]:
                                    policy.append(self.tok2ind.get(token, self.unk_token_idx))
                                context_policy.append(policy)
            context_policy = merge_utt(context_policy, self.word_split_idx, False)
            context_policy = add_start_end_token_idx(context_policy,
                                                     start_token_idx=self.cls_id,
                                                     end_token_idx=self.sep_id)
            context_policy += final_topic
            batch_context_policy.append(context_policy)

            batch_user_profile.extend(conv_dict['user_profile'])

            batch_target.append(conv_dict['target_topic'])

        batch_context = padded_tensor(batch_context,
                                      pad_idx=self.pad_token_idx,
                                      pad_tail=True,
                                      max_len=self.context_truncate)
        batch_context_mask = (batch_context != self.pad_token_idx).long()

        batch_context_policy = padded_tensor(batch_context_policy,
                                             pad_idx=self.pad_token_idx,
                                             pad_tail=True)
        batch_context_policy_mask = (batch_context_policy != 0).long()

        batch_user_profile = padded_tensor(batch_user_profile,
                                           pad_idx=self.pad_token_idx,
                                           pad_tail=True)
        batch_user_profile_mask = (batch_user_profile != 0).long()

        batch_target = torch.tensor(batch_target, dtype=torch.long)

        return (batch_context, batch_context_mask,
                batch_context_policy, batch_context_policy_mask,
                batch_user_profile, batch_user_profile_mask,
                batch_target)
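

# A minimal usage sketch (the ``opt``, ``dataset`` and ``vocab`` objects are assumed
# to come from CRSLab's TGReDial preprocessing; the batch size below is arbitrary):
#
#   dataloader = TGReDialDataLoader(opt, dataset, vocab)
#   rec_samples = dataloader.rec_process_fn()               # one sample per target item
#   rec_batch = dataloader.rec_batchify(rec_samples[:32])   # tensors ready for the model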