| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import os
- import logging
# Every (logger name, level) pair that init_log has already configured.
logs = set()

# --- Logger colour configuration ------------------------------------------
# ANSI colour indices 0..7; 30 + index is the ANSI foreground colour code.
(BLACK, RED, GREEN, YELLOW, BLUE,
 MAGENTA, CYAN, WHITE) = range(8)

# SGR reset sequence, and a bold-foreground template expecting 30 + index.
RESET_SEQ = "\033[0m"
COLOR_SEQ = "\033[1;%dm"

# Log level name -> colour index. INFO maps to WHITE, which ColoredFormatter
# treats as "leave uncoloured". CRITICAL currently shares WARNING's yellow.
COLORS = dict(
    WARNING=YELLOW,
    INFO=WHITE,
    DEBUG=BLUE,
    CRITICAL=YELLOW,
    ERROR=RED,
)
class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the level name (and string messages) in ANSI colours.

    Colours come from ``COLORS``; levels mapped to ``WHITE`` (INFO) and levels
    missing from ``COLORS`` are formatted without colour codes.
    """

    def __init__(self, msg, use_color=True):
        """Create the formatter.

        Args:
            msg: Format string, passed to ``logging.Formatter`` as ``fmt``.
            use_color: When False, records are formatted without colour codes.
        """
        super().__init__(msg)
        self.use_color = use_color

    def format(self, record):
        """Format ``record``, colouring its level name and string message.

        The record's ``msg`` and ``levelname`` are only modified for the
        duration of this call and are restored afterwards. The same LogRecord
        instance is passed to every handler, so leaving them coloured would leak
        escape codes into other handlers (e.g. file output) and double-wrap
        them if several ColoredFormatters see the same record.
        """
        original_msg = record.msg
        levelname = record.levelname
        if self.use_color and levelname in COLORS and COLORS[levelname] != WHITE:
            color = COLOR_SEQ % (30 + COLORS[levelname])
            # Non-string messages (e.g. objects with %-args) are left alone so
            # getMessage() still works on them; only the level name is coloured.
            if isinstance(original_msg, str):
                record.msg = color + original_msg + RESET_SEQ
            record.levelname = color + levelname + RESET_SEQ
        try:
            return super().format(record)
        finally:
            record.msg = original_msg
            record.levelname = levelname
def init_log(name, level=logging.INFO):
    """Attach a coloured stream handler to the logger ``name``.

    Calls with an already-seen ``(name, level)`` pair are no-ops. A later call
    for an already-configured ``name`` with a different ``level`` re-levels the
    logger and the handler installed earlier instead of adding a second
    handler, which would print overlapping records twice.

    When ``SLURM_PROCID`` is set, a filter on this logger drops every record
    unless the rank is 0, and the rank is embedded in the format string.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to both the logger and its stream handler.
    """
    if (name, level) in logs:
        return
    name_seen = any(logged_name == name for logged_name, _ in logs)
    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)

    if name_seen:
        # An earlier call already installed our handler; just move it to the
        # new level rather than stacking a duplicate.
        ours = [h for h in logger.handlers
                if isinstance(h.formatter, ColoredFormatter)]
        if ours:
            for handler in ours:
                handler.setLevel(level)
            return

    handler = logging.StreamHandler()
    handler.setLevel(level)
    if 'SLURM_PROCID' in os.environ:
        rank = int(os.environ['SLURM_PROCID'])
        # Only rank 0 emits records logged on this logger.
        logger.addFilter(lambda record: rank == 0)
    else:
        rank = 0
    FORMAT = f'[%(levelname)s]%(asctime)s-rk{rank}-%(filename)s#%(lineno)d:%(message)s'
    handler.setFormatter(ColoredFormatter(FORMAT))
    logger.addHandler(handler)
|