crslab.model.crs.tgredial package

Submodules

TGReDial_Conv

References

Zhou, Kun, et al. “Towards Topic-Guided Conversational Recommender System.” in COLING 2020.

class crslab.model.crs.tgredial.tg_conv.TGConvModel(opt, device, vocab, side_data)

Bases: crslab.model.base.BaseModel

context_truncate

An integer indicating the maximum length of the dialogue context.

response_truncate

An integer indicating the maximum length of the dialogue response.

pad_id

An integer indicating the id of the padding token.

Parameters
  • opt (dict) – A dictionary recording the hyperparameters.

  • device (torch.device) – The device on which to place the data and model.

  • vocab (dict) – A dictionary recording the vocabulary information.

  • side_data (dict) – A dictionary recording the side data.
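The truncation attributes above can be illustrated with a small pure-Python sketch. The helper `truncate_and_pad` is hypothetical and not part of CRSLab; it assumes the context keeps its most recent tokens, the response keeps its first tokens, and both are right-padded with `pad_id`. TGConvModel's actual preprocessing may differ.

```python
from typing import List


def truncate_and_pad(token_ids: List[int], max_len: int, pad_id: int,
                     keep_recent: bool = False) -> List[int]:
    """Hypothetical helper: clip a token-id sequence to max_len, then pad it."""
    # keep_recent=True keeps the last max_len tokens (e.g. the latest turns of
    # a dialogue context); otherwise the first max_len tokens are kept.
    if keep_recent:
        clipped = token_ids[max(0, len(token_ids) - max_len):]
    else:
        clipped = token_ids[:max_len]
    # Right-pad with pad_id so every sequence in a batch has length max_len.
    return clipped + [pad_id] * (max_len - len(clipped))


# Hypothetical values for context_truncate, response_truncate and pad_id.
context_truncate, response_truncate, pad_id = 5, 4, 0
context = truncate_and_pad([11, 12, 13, 14, 15, 16, 17], context_truncate,
                           pad_id, keep_recent=True)
# → [13, 14, 15, 16, 17]
response = truncate_and_pad([21, 22], response_truncate, pad_id)
# → [21, 22, 0, 0]
```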

build_model()

Build the model.

calculate_loss(logit, labels)
Parameters
  • logit – torch.FloatTensor, shape=(bs, response_truncate, vocab_size)

  • labels – torch.LongTensor, shape=(bs, response_truncate)
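A minimal pure-Python sketch of a token-level loss over these shapes, assuming `calculate_loss` computes cross-entropy and skips padding positions. The masking by `pad_id` is an assumption, `masked_cross_entropy` is a hypothetical name, and nested lists stand in for tensors.

```python
import math
from typing import List


def masked_cross_entropy(logit: List[List[List[float]]],
                         labels: List[List[int]],
                         pad_id: int) -> float:
    """Mean cross-entropy over non-padding positions.

    logit:  nested lists shaped (bs, response_truncate, vocab_size)
    labels: nested lists shaped (bs, response_truncate)
    """
    total, count = 0.0, 0
    for seq_logits, seq_labels in zip(logit, labels):
        for scores, target in zip(seq_logits, seq_labels):
            if target == pad_id:
                continue  # padding positions do not contribute to the loss
            # -log softmax(scores)[target], via log-sum-exp for numerical stability
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[target]
            count += 1
    return total / count if count else 0.0


# Hypothetical batch: bs=1, response_truncate=2, vocab_size=2, pad_id=0.
loss = masked_cross_entropy([[[0.0, 0.0], [2.0, 0.0]]], [[1, 0]], pad_id=0)
# Only the first position counts, so loss equals log(2), about 0.693
```

With real tensors, PyTorch computes the same quantity with `torch.nn.functional.cross_entropy(logit.reshape(-1, logit.size(-1)), labels.reshape(-1), ignore_index=pad_id)`.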

forward(batch, mode)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
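The note above is inherited from `torch.nn.Module`. The toy class below is a simplified, hypothetical imitation, not PyTorch's implementation. It shows why calling the instance differs from calling `forward` directly: only `__call__` runs the registered forward hooks. Real PyTorch hooks receive the inputs as a tuple, and this sketch simplifies that detail.

```python
from typing import Any, Callable, List


class ToyModule:
    """Hypothetical, stripped-down imitation of torch.nn.Module's hook handling."""

    def __init__(self) -> None:
        self._forward_hooks: List[Callable[..., None]] = []

    def register_forward_hook(self, hook: Callable[..., None]) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation here

    def __call__(self, x: Any) -> Any:
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)  # hooks only run when the instance itself is called
        return out


class Doubler(ToyModule):
    def forward(self, x: Any) -> Any:
        return 2 * x


seen: List[Any] = []
model = Doubler()
model.register_forward_hook(lambda module, inp, out: seen.append(out))
model(3)          # runs forward and the hook: seen == [6]
model.forward(4)  # bypasses the hook: seen is still [6]
```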

generate(context)
Parameters

context – torch.Tensor, shape=(bs, context_truncate)

Returns

generated_response – torch.Tensor, shape=(bs, response_truncate - 1)

Return type

torch.Tensor
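A pure-Python sketch of greedy decoding consistent with the shapes above. It assumes `generate` decodes greedily, choosing one argmax token per step, from a start token for `response_truncate - 1` steps. `greedy_generate`, `start_id`, and the scorer are hypothetical stand-ins for the model's decoder, and end-of-sequence handling is omitted.

```python
from typing import Callable, List

# Hypothetical decoder interface: (context, prefix so far) -> one score per vocabulary token.
Scorer = Callable[[List[int], List[int]], List[float]]


def greedy_generate(contexts: List[List[int]], scorer: Scorer,
                    response_truncate: int, start_id: int) -> List[List[int]]:
    """Pick the highest-scoring token at each of response_truncate - 1 steps."""
    outputs = []
    for ctx in contexts:
        prefix = [start_id]
        for _ in range(response_truncate - 1):
            scores = scorer(ctx, prefix)
            best = max(range(len(scores)), key=scores.__getitem__)  # argmax
            prefix.append(best)
        outputs.append(prefix[1:])  # drop the start token: response_truncate - 1 tokens remain
    return outputs


def toy_scorer(ctx: List[int], prefix: List[int]) -> List[float]:
    # Toy rule over a 10-token vocabulary: favour (last token + last context token) mod 10.
    vocab_size = 10
    target = (prefix[-1] + ctx[-1]) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


responses = greedy_generate([[5, 6], [7]], toy_scorer, response_truncate=4, start_id=1)
# → [[7, 3, 9], [8, 5, 2]]
```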

generate_bs(context, beam=4)
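`generate_bs` is undocumented here, but its `beam` argument suggests beam search. The following pure-Python sketch shows basic beam search under that assumption. `beam_search` and `toy_log_probs` are hypothetical, and the sketch omits end-of-sequence handling and length normalization.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical interface: prefix -> log-probabilities over the vocabulary for the next token.
LogProbFn = Callable[[List[int]], List[float]]


def beam_search(log_prob_fn: LogProbFn, start_id: int, steps: int,
                beam: int = 4) -> List[int]:
    """Keep the `beam` highest-scoring prefixes at each step; return the best one."""
    beams: List[Tuple[float, List[int]]] = [(0.0, [start_id])]
    for _ in range(steps):
        candidates = []
        for score, seq in beams:
            for tok, lp in enumerate(log_prob_fn(seq)):
                candidates.append((score + lp, seq + [tok]))
        # Prune to the top-`beam` candidates by cumulative log-probability.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = candidates[:beam]
    return beams[0][1][1:]  # best sequence without the start token


def toy_log_probs(seq: List[int]) -> List[float]:
    # Two-token vocabulary; the locally best first token is not globally best.
    if len(seq) == 1:      # only the start token so far
        probs = [0.6, 0.4]
    elif seq[1] == 0:
        probs = [0.55, 0.45]
    else:
        probs = [0.1, 0.9]
    return [math.log(p) for p in probs]


greedy = beam_search(toy_log_probs, start_id=2, steps=2, beam=1)  # → [0, 0]
wide = beam_search(toy_log_probs, start_id=2, steps=2, beam=4)    # → [1, 1]
```

With `beam=1` the search reduces to greedy decoding and commits to the locally best first token. A wider beam finds the more probable sequence (0.4 × 0.9 = 0.36 versus 0.6 × 0.55 = 0.33).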

TGReDial_Policy

References

Zhou, Kun, et al. “Towards Topic-Guided Conversational Recommender System.” in COLING 2020.

class crslab.model.crs.tgredial.tg_policy.TGPolicyModel(opt, device, vocab, side_data)

Bases: crslab.model.base.BaseModel

Parameters
  • opt (dict) – A dictionary recording the hyperparameters.

  • device (torch.device) – The device on which to place the data and model.

  • vocab (dict) – A dictionary recording the vocabulary information.

  • side_data (dict) – A dictionary recording the side data.

build_model(*args, **kwargs)

Build the model.

forward(batch, mode)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

TGReDial_Rec

References

Zhou, Kun, et al. “Towards Topic-Guided Conversational Recommender System.” in COLING 2020.

class crslab.model.crs.tgredial.tg_rec.TGRecModel(opt, device, vocab, side_data)

Bases: crslab.model.base.BaseModel

hidden_dropout_prob

A float indicating the dropout rate applied to hidden states in SASRec.

initializer_range

A float indicating the range used for parameter initialization in SASRec.

hidden_size

An integer indicating the size of hidden states in SASRec.

max_seq_length

An integer indicating the maximum length of the interaction history.

item_size

An integer indicating the number of items.

num_attention_heads

An integer indicating the number of attention heads in SASRec.

attention_probs_dropout_prob

A float indicating the dropout rate in attention layers.

hidden_act

A string indicating the type of activation function in SASRec.

num_hidden_layers

An integer indicating the number of hidden layers in SASRec.

Parameters
  • opt (dict) – A dictionary recording the hyperparameters.

  • device (torch.device) – The device on which to place the data and model.

  • vocab (dict) – A dictionary recording the vocabulary information.

  • side_data (dict) – A dictionary recording the side data.
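The SASRec hyperparameters listed above can be gathered into a configuration and sanity-checked. The values, `check_config`, `prepare_history`, and the left-padding convention for the interaction history are all hypothetical. They illustrate the attributes and do not reproduce TGRecModel's code or defaults.

```python
from typing import List

# Hypothetical values; the keys mirror the attribute names documented above.
sasrec_config = {
    "hidden_size": 64,
    "num_attention_heads": 2,
    "num_hidden_layers": 2,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.2,
    "attention_probs_dropout_prob": 0.2,
    "initializer_range": 0.02,
    "max_seq_length": 5,
    "item_size": 100,
}


def check_config(cfg: dict) -> None:
    # Multi-head attention splits hidden_size evenly across the heads.
    if cfg["hidden_size"] % cfg["num_attention_heads"] != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    for key in ("hidden_dropout_prob", "attention_probs_dropout_prob"):
        if not 0.0 <= cfg[key] < 1.0:
            raise ValueError(f"{key} must be in [0, 1)")


def prepare_history(item_ids: List[int], max_seq_length: int, pad_id: int = 0) -> List[int]:
    # Keep the most recent interactions and left-pad, so the latest item always
    # sits in the last position (an assumed convention for this sketch).
    recent = item_ids[max(0, len(item_ids) - max_seq_length):]
    return [pad_id] * (max_seq_length - len(recent)) + recent


check_config(sasrec_config)
short = prepare_history([7, 8, 9], sasrec_config["max_seq_length"])             # → [0, 0, 7, 8, 9]
long = prepare_history([1, 2, 3, 4, 5, 6, 7], sasrec_config["max_seq_length"])  # → [3, 4, 5, 6, 7]
```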

build_model()

Build the model.

forward(batch, mode)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Module contents