# Source code for crslab.model

# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/24, 2020/12/24
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

# @Time   : 2021/10/06
# @Author : Zhipeng Zhao
# @Email  : oran_official@outlook.com

import torch
from loguru import logger

from .conversation import *
from .crs import *
from .policy import *
from .recommendation import *

# Registry mapping the model name (as given by the caller of ``get_model``)
# to the class that implements it. ``get_model`` only builds names listed here.
# The classes are not defined in this module; they arrive through the
# ``from .<subpackage> import *`` lines above.
Model_register_table = {
    "KGSF": KGSFModel,
    "KBRD": KBRDModel,
    "TGRec": TGRecModel,
    "TGConv": TGConvModel,
    "TGPolicy": TGPolicyModel,
    "ReDialRec": ReDialRecModel,
    "ReDialConv": ReDialConvModel,
    "InspiredRec": InspiredRecModel,
    "InspiredConv": InspiredConvModel,
    "GPT2": GPT2Model,
    "Transformer": TransformerModel,
    "ConvBERT": ConvBERTModel,
    "ProfileBERT": ProfileBERTModel,
    "TopicBERT": TopicBERTModel,
    "PMI": PMIModel,
    "MGCG": MGCGModel,
    "BERT": BERTModel,
    "SASREC": SASRECModel,
    "GRU4REC": GRU4RECModel,
    "Popularity": PopularityModel,
    "TextCNN": TextCNNModel,
    "NTRD": NTRDModel,
}


def get_model(config, model_name, device, vocab, side_data=None):
    """Build a registered model and place it according to the GPU setting.

    Args:
        config: run configuration. ``config.opt["gpu"]`` selects placement:
            ``[-1]`` means CPU, a one-element list means a single GPU, and a
            longer list means multiple GPUs via ``torch.nn.DataParallel``.
        model_name (str): key into ``Model_register_table``.
        device: device forwarded to the model constructor and used as the
            target of ``.to()`` on the GPU paths.
        vocab: vocabulary object forwarded to the model constructor.
        side_data: optional extra data forwarded to the model constructor.

    Returns:
        The model itself when ``config.opt["gpu"] == [-1]``. Otherwise the model
        moved to ``device``, wrapped in ``torch.nn.DataParallel`` when more than
        one GPU is configured and the model supports it.

    Raises:
        NotImplementedError: if ``model_name`` is not in ``Model_register_table``.
    """
    model_cls = Model_register_table.get(model_name)
    if model_cls is None:
        raise NotImplementedError('Model [{}] has not been implemented'.format(model_name))

    model = model_cls(config, device, vocab, side_data)
    logger.info(f'[Build model {model_name}]')

    # Read the GPU list once and use it for every decision below; previously
    # the DataParallel call used ``config["gpu"]`` while the checks used
    # ``config.opt["gpu"]``.
    gpus = config.opt["gpu"]
    if gpus == [-1]:
        # CPU-only run: return the model as constructed.
        return model

    model = model.to(device)
    if len(gpus) <= 1:
        return model

    if model_name in ('PMI', 'KBRD'):
        logger.info('[PMI/KBRD model does not support multi GPUs yet, using single GPU now]')
        return model

    # DataParallel requires the module's parameters and buffers to live on
    # device_ids[0] before forward(), so the model is moved to ``device``
    # first (above). NOTE(review): assumes ``device`` corresponds to gpus[0];
    # confirm against how the config builds ``device``.
    return torch.nn.DataParallel(model, device_ids=gpus)