# @Time : 2020/11/22
# @Author : Kun Zhou
# @Email : francis_kun_zhou@163.com
# UPDATE:
# @Time : 2020/11/24, 2020/12/29
# @Author : Kun Zhou, Xiaolei Wang
# @Email : francis_kun_zhou@163.com, wxl1999@foxmail.com
from abc import ABC, abstractmethod
from torch import nn
from crslab.download import build
class BaseModel(ABC, nn.Module):
    """Base class for all models.

    Subclasses must implement :meth:`build_model`, which is invoked at the
    end of :meth:`__init__`. The task hooks :meth:`recommend`,
    :meth:`converse` and :meth:`guide` are no-ops here (they return ``None``)
    and are meant to be overridden by the models that support them.
    """

    def __init__(self, opt, device, dpath=None, resource=None):
        """Store the configuration, optionally build the resource, then build the model.

        Args:
            opt: configuration object, stored unchanged as ``self.opt``.
            device: device identifier, stored unchanged as ``self.device``.
            dpath: path forwarded to ``build`` together with the resource
                description; stored as ``self.dpath``. Defaults to ``None``.
            resource: optional mapping with the keys ``'file'`` and
                ``'version'``. When given, ``build(dpath, resource['file'],
                version=resource['version'])`` is called before the model is
                built (presumably to fetch/prepare the files under ``dpath``
                -- confirm against ``crslab.download.build``).

        Raises:
            KeyError: if ``resource`` is a dict lacking ``'file'`` or
                ``'version'``.
        """
        super().__init__()
        self.opt = opt
        self.device = device
        # Always define ``dpath`` so reading it never raises AttributeError
        # on models that are constructed without a resource.
        self.dpath = dpath

        if resource is not None:
            build(dpath, resource['file'], version=resource['version'])

        # NOTE: runs inside the constructor, so a subclass must set any
        # attribute that ``build_model`` needs *before* calling
        # ``super().__init__``.
        self.build_model()

    @abstractmethod
    def build_model(self, *args, **kwargs):
        """Build the model; called without arguments at the end of ``__init__``."""
        pass

    def recommend(self, batch, mode):
        """Calculate loss and prediction of recommendation for batch under certain mode.

        The base implementation does nothing and returns ``None``.

        Args:
            batch (dict or tuple): batch data
            mode (str): train/valid/test.
        """
        pass

    def converse(self, batch, mode):
        """Calculate loss and prediction of conversation for batch under certain mode.

        The base implementation does nothing and returns ``None``.

        Args:
            batch (dict or tuple): batch data
            mode (str): train/valid/test.
        """
        pass

    def guide(self, batch, mode):
        """Calculate loss and prediction of guidance for batch under certain mode.

        The base implementation does nothing and returns ``None``.

        Args:
            batch (dict or tuple): batch data
            mode (str): train/valid/test.
        """
        pass