Source code for crslab.config.config

# @Time   : 2020/11/22
# @Author : Kun Zhou
# @Email  : francis_kun_zhou@163.com

# UPDATE:
# @Time   : 2020/11/23, 2021/1/9
# @Author : Kun Zhou, Xiaolei Wang
# @Email  : francis_kun_zhou@163.com, wxl1999@foxmail.com

import json
import os
import time
from pprint import pprint

import yaml
import torch
from loguru import logger
from tqdm import tqdm


[docs]class Config: """Configurator module that load the defined parameters.""" def __init__(self, config_file, gpu='-1', debug=False): """Load parameters and set log level. Args: config_file (str): path to the config file, which should be in ``yaml`` format. You can use default config provided in the `Github repo`_, or write it by yourself. debug (bool, optional): whether to enable debug function during running. Defaults to False. .. _Github repo: https://github.com/RUCAIBox/CRSLab """ self.opt = self.load_yaml_configs(config_file) # gpu os.environ['CUDA_VISIBLE_DEVICES'] = gpu if gpu != '-1': self.opt['gpu'] = [i for i in range(len(gpu.split(',')))] else: self.opt['gpu'] = [-1] # dataset dataset = self.opt['dataset'] tokenize = self.opt['tokenize'] if isinstance(tokenize, dict): tokenize = ', '.join(tokenize.values()) # model model = self.opt.get('model', None) rec_model = self.opt.get('rec_model', None) conv_model = self.opt.get('conv_model', None) policy_model = self.opt.get('policy_model', None) if model: model_name = model else: models = [] if rec_model: models.append(rec_model) if conv_model: models.append(conv_model) if policy_model: models.append(policy_model) model_name = '_'.join(models) self.opt['model_name'] = model_name # log log_name = self.opt.get("log_name", dataset + '_' + model_name + '_' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) + ".log" if not os.path.exists("log"): os.makedirs("log") logger.remove() if debug: level = 'DEBUG' else: level = 'INFO' logger.add(os.path.join("log", log_name), level=level) logger.add(lambda msg: tqdm.write(msg, end=''), colorize=True, level=level) logger.info(f"[Dataset: {dataset} tokenized in {tokenize}]") if model: logger.info(f'[Model: {model}]') if rec_model: logger.info(f'[Recommendation Model: {rec_model}]') if conv_model: logger.info(f'[Conversation Model: {conv_model}]') if policy_model: logger.info(f'[Policy Model: {policy_model}]') logger.info("[Config]" + '\n' + json.dumps(self.opt, indent=4))
[docs] @staticmethod def load_yaml_configs(filename): """This function reads ``yaml`` file to build config dictionary Args: filename (str): path to ``yaml`` config Returns: dict: config """ config_dict = dict() with open(filename, 'r', encoding='utf-8') as f: config_dict.update(yaml.safe_load(f.read())) return config_dict
def __setitem__(self, key, value): if not isinstance(key, str): raise TypeError("index must be a str.") self.opt[key] = value def __getitem__(self, item): if item in self.opt: return self.opt[item] else: return None
[docs] def get(self, item, default=None): """Get value of corrsponding item in config Args: item (str): key to query in config default (optional): default value for item if not found in config. Defaults to None. Returns: value of corrsponding item in config """ if item in self.opt: return self.opt[item] else: return default
def __contains__(self, key): if not isinstance(key, str): raise TypeError("index must be a str.") return key in self.opt def __str__(self): return str(self.opt) def __repr__(self): return self.__str__()
if __name__ == '__main__': opt_dict = Config('../../config/crs/kbrd/redial.yaml') pprint(opt_dict)