#!/usr/bin/env python
import yaml
import argparse
import pprint
import time
import os
[docs]class Config:
"""
Class to handle configuration from input file and command line args.
"""
def __init__(self):
self.parseArgs()
# set the default config
self.defaultConfig()
self.parseConfig()
[docs] def defaultConfig(self):
"""
This method generates default config variables. Any new variables meant to go in the YAML
config file should go here as well. One should use the set() method.
"""
self.config = {}
self.config['cpus_per_task'] = 1
self.config['batch_size'] = 128
self.config['learning_rate'] = 1e-3
self.config['input_label'] = 'X'
self.config['output_label'] = 'Y'
self.config['model'] = 'dnn.py'
self.config['mixed_precision'] = True
self.config['twin'] = False
self.config['n_training_samples'] = 0
self.config['n_testing_samples'] = 0
self.config['use_hist'] = False
[docs] def parseArgs(self):
"""
Set up the argument parser with a few additions to handle multi-node training.
"""
parser = argparse.ArgumentParser(description='Machine learning with Pytorch. Change dnn.py and your YAML input file to modify training.')
parser.add_argument('-lr', '--local_rank', default=0, type=int,
help='ranking within the nodes')
parser.add_argument('-nr', '--node_rank', default=0, type=int, help='Node number')
parser.add_argument('-ng', '--gpus_per_node', default=1, type=int, help='Number of GPUs on node.')
parser.add_argument('-i', '--input', default='input.yaml', type=str)
parser.add_argument('-cp','--checkpoint_path', default='./', type=str )
self.args = parser.parse_args()
[docs] def parseConfig(self):
"""
Parse the configuration from the input yaml file.
"""
yaml_config = yaml.safe_load(open(self.args.input, 'r'))
for key in yaml_config:
self.config[key] = yaml_config[key]
[docs] def printConfig(self):
"""
Print out the configuration.
"""
pp = pprint.PrettyPrinter(indent=4)
print()
print('Here is your configuration:')
pp.pprint(self.config)
print()
[docs] def getConfig(self):
"""
Get the input configuration.
Returns
-------
A Dictionary which holds the configuration.
"""
return self.config
[docs] def getArgs(self):
"""
Get the command line configuration.
Returns
-------
A argparse object.
"""
return self.args
[docs] def set(self, key, val):
"""
Add new variables to the configuration.
"""
self.config[key] = val