# @Time   : 2020/12/14
# @Author : Yuanhang Zhou
# @Email  : sdzyh002@gmail.com

# UPDATE
# @Time   : 2021/1/7
# @Author : Xiaolei Wang
# @email  : wxl1999@foxmail.com

r"""
GPT2
====
References:
    Radford, Alec, et al. `"Language Models are Unsupervised Multitask Learners."`_.

.. _`"Language Models are Unsupervised Multitask Learners."`:
   https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf

"""

import os

import torch
from torch.nn import CrossEntropyLoss
from transformers import GPT2LMHeadModel

from crslab.config import PRETRAIN_PATH
from crslab.data import dataset_language_map
from crslab.model.base import BaseModel
from crslab.model.pretrained_models import resources


class GPT2Model(BaseModel):
    """

    Attributes:
        context_truncate: An integer indicating the length of the dialogue context.
        response_truncate: An integer indicating the length of the dialogue response.
        pad_id: An integer indicating the id of the padding token.

    """

    def __init__(self, opt, device, vocab, side_data):
        """

        Args:
            opt (dict): A dictionary recording the hyperparameters.
            device (torch.device): A variable indicating which device to place the data and model.
            vocab (dict): A dictionary recording the vocabulary information.
            side_data (dict): A dictionary recording the side data.

        """
        self.context_truncate = opt['context_truncate']
        self.response_truncate = opt['response_truncate']
        self.pad_id = vocab['pad']

        language = dataset_language_map[opt['dataset']]
        resource = resources['gpt2'][language]
        dpath = os.path.join(PRETRAIN_PATH, "gpt2", language)
        super(GPT2Model, self).__init__(opt, device, dpath, resource)

    def build_model(self):
        """build model"""
        self.model = GPT2LMHeadModel.from_pretrained(self.dpath)
        self.loss = CrossEntropyLoss(ignore_index=self.pad_id)

    def forward(self, batch, mode):
        _, _, input_ids, context, _, _, y = batch

        if mode != 'test':
            # lm_logits shape = (bs, seq_len, v_s); the model also returns
            # past_key_values, a tuple with one entry per layer (12 for GPT-2 base)
            lm_logits = self.model(input_ids).logits

            # the last self.response_truncate tokens are the response; logits are
            # shifted one step left of their targets, so position t predicts token t + 1
            loss = self.calculate_loss(
                lm_logits[:, -self.response_truncate:-1, :],
                input_ids[:, -self.response_truncate + 1:])
            pred = torch.max(lm_logits, dim=2)[1]  # (bs, seq_len)
            pred = pred[:, -self.response_truncate:]

            return loss, pred
        else:
            return self.generate(context)

    def generate(self, context):
        """

        Args:
            context: torch.tensor, shape=(bs, context_truncate)

        Returns:
            generated_response: torch.tensor, shape=(bs, response_truncate-1)

        """
        generated_response = []
        former_hidden_state = None
        context = context[..., -self.response_truncate + 1:]

        for i in range(self.response_truncate - 1):
            outputs = self.model(context, former_hidden_state)
            # logits: (bs, c_t, v_s); past_key_values caches the processed prefix
            last_hidden_state, former_hidden_state = outputs.logits, outputs.past_key_values

            next_token_logits = last_hidden_state[:, -1, :]  # (bs, v_s)
            preds = next_token_logits.argmax(dim=-1).long()  # (bs)

            # thanks to the cached past_key_values, only the newly generated
            # token needs to be fed in the next step
            context = preds.unsqueeze(1)
            generated_response.append(preds)

        generated_response = torch.stack(generated_response).T

        return generated_response

    def calculate_loss(self, logit, labels):
        """

        Args:
            logit: torch.FloatTensor, shape=(bs, response_truncate - 1, vocab_size)
            labels: torch.LongTensor, shape=(bs, response_truncate - 1)

        """
        loss = self.loss(logit.reshape(-1, logit.size(-1)), labels.reshape(-1))
        return loss

    def generate_bs(self, context, beam=4):
        context = context[..., -self.response_truncate + 1:]
        context_former = context
        batch_size = context.shape[0]
        sequences = [[[list(), 1.0]]] * batch_size

        for _ in range(self.response_truncate - 1):
            if sequences != [[[list(), 1.0]]] * batch_size:
                context = []
                for i in range(batch_size):
                    for cand in sequences[i]:
                        # no past state is kept across steps, so each candidate
                        # is re-concatenated with the former context
                        text = torch.cat((context_former[i], torch.stack(cand[0]).to(self.device)))
                        context.append(text)
                context = torch.stack(context)
            with torch.no_grad():
                outputs = self.model(context)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
            topk = torch.topk(next_token_probs, beam, dim=-1)
            probs = topk.values.reshape([batch_size, -1, beam])  # (bs, candidate, beam)
            preds = topk.indices.reshape([batch_size, -1, beam])  # (bs, candidate, beam)

            # expand every candidate by its top-k continuations, then keep the
            # `beam` highest-scoring sequences per batch entry
            for j in range(batch_size):
                all_candidates = []
                for n in range(len(sequences[j])):
                    for k in range(beam):
                        seq = sequences[j][n][0]
                        prob = sequences[j][n][1]
                        seq_tmp = seq.copy()
                        seq_tmp.append(preds[j][n][k])
                        candidate = [seq_tmp, prob * probs[j][n][k]]
                        all_candidates.append(candidate)
                ordered = sorted(all_candidates, key=lambda tup: tup[1], reverse=True)
                sequences[j] = ordered[:beam]

        res = []
        for i in range(batch_size):
            res.append(torch.stack(sequences[i][0][0]))
        res = torch.stack(res)
        return res
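
# ---------------------------------------------------------------------------
# Minimal standalone sketch of the greedy decoding loop that
# ``GPT2Model.generate`` implements above. It bypasses CRSLab's config and
# resource machinery and loads the public Hugging Face ``gpt2`` checkpoint
# directly; the checkpoint name, the prompt, and the 10-token budget are
# illustrative assumptions, not part of CRSLab.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')  # assumption: hub checkpoint, not CRSLab's local copy
    gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
    gpt2.eval()

    context = tokenizer('Hello, how are', return_tensors='pt').input_ids  # (1, seq_len)
    past, generated = None, []
    with torch.no_grad():
        for _ in range(10):  # decode at most 10 new tokens
            outputs = gpt2(context, past_key_values=past)
            past = outputs.past_key_values  # cache, so only the new token is fed next step
            next_token = outputs.logits[:, -1, :].argmax(dim=-1)  # greedy pick, (bs,)
            generated.append(next_token)
            context = next_token.unsqueeze(1)
    print(tokenizer.decode(torch.stack(generated, dim=1)[0]))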