# @Time : 2020/11/22
# @Author : Chenzhan Shang
# @Email : czshang@outlook.com
# UPDATE:
# @Time : 2020/12/16
# @Author : Xiaolei Wang
# @Email : wxl1999@foxmail.com
import re
from copy import copy
import torch
from tqdm import tqdm
from crslab.data.dataloader.base import BaseDataLoader
from crslab.data.dataloader.utils import padded_tensor, get_onehot, truncate

# matches movie mention tokens of the form '@<movie_id>', e.g. '@111776'
movie_pattern = re.compile(r'^@\d{5,6}$')


class ReDialDataLoader(BaseDataLoader):
"""Dataloader for model ReDial.
Notes:
You can set the following parameters in config:
- ``'utterance_truncate'``: the maximum length of a single utterance.
- ``'conversation_truncate'``: the maximum length of the whole conversation.
The following values must be specified in ``vocab``:
- ``'pad'``
- ``'start'``
- ``'end'``
- ``'unk'``
the above values specify the id of needed special token.
- ``'ind2tok'``: map from index to token.
- ``'n_entity'``: number of entities in the entity KG of dataset.
- ``'vocab_size'``: size of vocab.
"""
def __init__(self, opt, dataset, vocab):
"""
Args:
opt (Config or dict): config for dataloader or the whole system.
dataset: data for model.
            vocab (dict): all kinds of useful sizes, special token indices, and mappings between token and index.
"""
super().__init__(opt, dataset)
self.ind2tok = vocab['ind2tok']
self.n_entity = vocab['n_entity']
self.pad_token_idx = vocab['pad']
self.start_token_idx = vocab['start']
self.end_token_idx = vocab['end']
self.unk_token_idx = vocab['unk']
        # one extra id (= vocab_size) is reserved as the '__item__' placeholder token
        self.item_token_idx = vocab['vocab_size']
self.conversation_truncate = self.opt.get('conversation_truncate', None)
self.utterance_truncate = self.opt.get('utterance_truncate', None)
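
        # Illustrative contents of ``opt`` and ``vocab`` (made-up values, not
        # taken from a real run):
        #   vocab = {'pad': 0, 'start': 1, 'end': 2, 'unk': 3,
        #            'vocab_size': 30000, 'n_entity': 60000,
        #            'ind2tok': {0: '__pad__', 1: '__start__', ...}}
        #   opt = {'utterance_truncate': 80, 'conversation_truncate': 40}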

    def rec_process_fn(self, *args, **kwargs):
dataset = []
for conversation in self.dataset:
if conversation['role'] == 'Recommender':
for item in conversation['items']:
context_entities = conversation['context_entities']
dataset.append({'context_entities': context_entities, 'item': item})
return dataset
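
    # Illustrative record flow for ``rec_process_fn`` (field names follow the
    # processed ReDial data; values are made up). Each ground-truth item in a
    # recommender turn becomes one training sample:
    #   input:  {'role': 'Recommender', 'context_entities': [52, 7],
    #            'items': [1063, 220], ...}
    #   output: {'context_entities': [52, 7], 'item': 1063}
    #           {'context_entities': [52, 7], 'item': 220}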

    def rec_batchify(self, batch):
batch_context_entities = []
batch_item = []
for conversation in batch:
batch_context_entities.append(conversation['context_entities'])
batch_item.append(conversation['item'])
context_entities = get_onehot(batch_context_entities, self.n_entity)
return {'context_entities': context_entities, 'item': torch.tensor(batch_item, dtype=torch.long)}
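
    # Minimal usage sketch for ``rec_batchify`` (toy values; assumes
    # ``get_onehot`` returns a (batch_size, n_entity) multi-hot float tensor):
    #   batch = [{'context_entities': [0, 2], 'item': 5},
    #            {'context_entities': [1], 'item': 3}]
    #   -> {'context_entities': tensor of shape (2, n_entity) with ones at the
    #       listed entity ids, 'item': tensor([5, 3])}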

    def conv_process_fn(self):
dataset = []
for conversation in tqdm(self.dataset):
if conversation['role'] != 'Recommender':
continue
context_tokens = [truncate(utterance, self.utterance_truncate, truncate_tail=True) for utterance in
conversation['context_tokens']]
context_tokens = truncate(context_tokens, self.conversation_truncate, truncate_tail=True)
context_length = len(context_tokens)
utterance_lengths = [len(utterance) for utterance in context_tokens]
request = context_tokens[-1]
response = truncate(conversation['response'], self.utterance_truncate, truncate_tail=True)
dataset.append({'context_tokens': context_tokens, 'context_length': context_length,
'utterance_lengths': utterance_lengths, 'request': request, 'response': response})
return dataset
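
    # Shape of one processed conversation record (a sketch; token ids are made
    # up, and ``request`` is always the last context utterance):
    #   {'context_tokens': [[4, 9, 2], [7, 5]], 'context_length': 2,
    #    'utterance_lengths': [3, 2], 'request': [7, 5], 'response': [8, 6, 2]}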

    def conv_batchify(self, batch):
        max_utterance_length = max(max(conversation['utterance_lengths']) for conversation in batch)
        max_response_length = max(len(conversation['response']) for conversation in batch)
        max_utterance_length = max(max_utterance_length, max_response_length)
        max_context_length = max(conversation['context_length'] for conversation in batch)
batch_context = []
batch_context_length = []
batch_utterance_lengths = []
        batch_request = []
batch_request_length = []
batch_response = []
for conversation in batch:
padded_context = padded_tensor(conversation['context_tokens'], pad_idx=self.pad_token_idx,
pad_tail=True, max_len=max_utterance_length)
if len(conversation['context_tokens']) < max_context_length:
pad_tensor = padded_context.new_full(
(max_context_length - len(conversation['context_tokens']), max_utterance_length), self.pad_token_idx
)
padded_context = torch.cat((padded_context, pad_tensor), 0)
batch_context.append(padded_context)
batch_context_length.append(conversation['context_length'])
batch_utterance_lengths.append(conversation['utterance_lengths'] +
[0] * (max_context_length - len(conversation['context_tokens'])))
request = conversation['request']
batch_request_length.append(len(request))
batch_request.append(request)
response = copy(conversation['response'])
            # replace tokens matching '^@\d{5,6}$' (movie mentions) with the '__item__' token id
for i in range(len(response)):
if movie_pattern.match(self.ind2tok[response[i]]):
response[i] = self.item_token_idx
batch_response.append(response)
context = torch.stack(batch_context, dim=0)
request = padded_tensor(batch_request, self.pad_token_idx, pad_tail=True, max_len=max_utterance_length)
response = padded_tensor(batch_response, self.pad_token_idx, pad_tail=True,
max_len=max_utterance_length) # (bs, utt_len)
return {'context': context, 'context_lengths': torch.tensor(batch_context_length),
'utterance_lengths': torch.tensor(batch_utterance_lengths), 'request': request,
'request_lengths': torch.tensor(batch_request_length), 'response': response}
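
    # Tensor shapes produced above, with bs = batch size, ctx = max_context_length
    # and utt = max_utterance_length: context (bs, ctx, utt), context_lengths (bs,),
    # utterance_lengths (bs, ctx), request (bs, utt), request_lengths (bs,),
    # response (bs, utt).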

    def policy_batchify(self, batch):
        # the ReDial model in CRSLab has no policy component, so no policy
        # batches are built
        pass