crslab.model package

Submodules

class crslab.model.base.BaseModel(opt, device, dpath=None, resource=None)

Bases: abc.ABC, torch.nn.modules.module.Module

Base class for all models.

abstract build_model(*args, **kwargs)

Build the model. Every concrete subclass must implement this method.
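Because build_model is abstract, a subclass that does not override it cannot be instantiated. The sketch below illustrates this with a plain-Python stand-in for BaseModel (the real class also inherits torch.nn.Module and takes dpath and resource). It assumes, as a simplification, that the constructor calls build_model(); the class names SketchBaseModel, IncompleteModel and TinyModel are hypothetical.

```python
from abc import ABC, abstractmethod


class SketchBaseModel(ABC):
    """Hypothetical stand-in for crslab.model.base.BaseModel.

    Only the abstract-method behaviour is mirrored; the real class is also a
    torch.nn.Module and accepts (opt, device, dpath=None, resource=None).
    """

    def __init__(self, opt, device):
        self.opt = opt
        self.device = device
        # Assumption: the constructor builds the submodules right away.
        self.build_model()

    @abstractmethod
    def build_model(self, *args, **kwargs):
        """Create the model's layers; concrete subclasses must override this."""


class IncompleteModel(SketchBaseModel):
    pass  # forgets to implement build_model


class TinyModel(SketchBaseModel):
    def build_model(self, *args, **kwargs):
        # Hypothetical: a real model would create torch layers here.
        self.layers = ["embedding", "encoder"]


try:
    IncompleteModel(opt={}, device="cpu")
except TypeError as err:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    print("cannot instantiate:", err)

model = TinyModel(opt={}, device="cpu")
print(model.layers)  # ['embedding', 'encoder']
```

The check happens at instantiation time, so a missing build_model is reported as soon as the model is constructed rather than later during training.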

converse(batch, mode)

Calculate the loss and prediction of the conversation task for a batch under the given mode.

Parameters
  • batch (dict or tuple) – batch data

  • mode (str, optional) – one of train, valid or test.

guide(batch, mode)

Calculate the loss and prediction of the guidance task for a batch under the given mode.

Parameters
  • batch (dict or tuple) – batch data

  • mode (str, optional) – one of train, valid or test.

recommend(batch, mode)

Calculate the loss and prediction of the recommendation task for a batch under the given mode.

Parameters
  • batch (dict or tuple) – batch data

  • mode (str, optional) – one of train, valid or test.
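converse, guide and recommend share one pattern: each takes a batch and a mode and produces a loss and a prediction. The sketch below shows a subclass implementing recommend, using a plain-Python stand-in for BaseModel and a hypothetical popularity-count model. The dict batch layout ({"history": ..., "target": ...}), the (loss, prediction) return tuple, the 0/1 loss, and raising NotImplementedError for unimplemented tasks are all assumptions made for illustration, not CRSLab's conventions.

```python
from abc import ABC, abstractmethod

VALID_MODES = ("train", "valid", "test")


class SketchBaseModel(ABC):
    """Hypothetical plain-Python stand-in for BaseModel's task-method interface."""

    def __init__(self):
        self.build_model()

    @abstractmethod
    def build_model(self, *args, **kwargs):
        """Create the model's components."""

    # Sketch choice: tasks a model does not support raise an error.
    def converse(self, batch, mode):
        raise NotImplementedError

    def guide(self, batch, mode):
        raise NotImplementedError

    def recommend(self, batch, mode):
        raise NotImplementedError


class PopularityRecommender(SketchBaseModel):
    """Hypothetical model that recommends the most frequently seen item."""

    def build_model(self, *args, **kwargs):
        self.counts = {}  # item -> number of times seen in training batches

    def recommend(self, batch, mode):
        if mode not in VALID_MODES:
            raise ValueError(f"unknown mode: {mode!r}")
        # Only training batches update the model's state.
        if mode == "train":
            for item in batch["history"]:
                self.counts[item] = self.counts.get(item, 0) + 1
        # Predict the most frequent item (ties keep first-seen order).
        prediction = max(self.counts, key=self.counts.get) if self.counts else None
        # 0/1 loss: 0.0 when the prediction matches the target item.
        loss = 0.0 if prediction == batch["target"] else 1.0
        return loss, prediction


model = PopularityRecommender()
print(model.recommend({"history": ["A", "B", "A"], "target": "A"}, "train"))  # (0.0, 'A')
print(model.recommend({"history": ["B"], "target": "B"}, "test"))  # (1.0, 'A')
```

The second call runs in test mode, so the counts are left unchanged and the model still predicts "A". A model that only supports recommendation inherits the raising converse and guide methods.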

Module contents

crslab.model.get_model(config, model_name, device, vocab, side_data=None)
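The signature suggests that get_model selects a model class by model_name and constructs it from the remaining arguments. The following is a minimal sketch of that idea as a name-to-class registry lookup, assuming this is roughly how the factory behaves; MODEL_REGISTRY, DummyModel, get_model_sketch and the ValueError for unknown names are hypothetical and not CRSLab's code.

```python
class DummyModel:
    """Hypothetical model class; real CRSLab models subclass BaseModel."""

    def __init__(self, config, device, vocab, side_data=None):
        self.config = config
        self.device = device
        self.vocab = vocab
        self.side_data = side_data


# Hypothetical registry mapping model names to model classes.
MODEL_REGISTRY = {"Dummy": DummyModel}


def get_model_sketch(config, model_name, device, vocab, side_data=None):
    """Look up model_name and build the model with the remaining arguments."""
    try:
        model_cls = MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(f"model {model_name!r} is not registered") from None
    return model_cls(config, device, vocab, side_data)


model = get_model_sketch({"lr": 1e-3}, "Dummy", "cpu", vocab={"tok2ind": {}})
print(type(model).__name__)  # DummyModel
```

With a registry, supporting a new model only requires adding one entry that maps its name to its class.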