# log_helper.py — colored console logging helpers.
  1. import os
  2. import logging
# Registry of (name, level) pairs already passed to init_log(); lets a repeated
# call with the same arguments return early instead of attaching another handler.
logs = set()
  4. # LOGGER
  5. BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8)
  6. RESET_SEQ = "\033[0m"
  7. COLOR_SEQ = "\033[1;%dm"
  8. COLORS = {
  9. 'WARNING': YELLOW,
  10. 'INFO': WHITE,
  11. 'DEBUG': BLUE,
  12. 'CRITICAL': YELLOW,
  13. 'ERROR': RED
  14. }
  15. class ColoredFormatter(logging.Formatter):
  16. def __init__(self, msg, use_color=True):
  17. logging.Formatter.__init__(self, msg)
  18. self.use_color = use_color
  19. def format(self, record):
  20. msg = record.msg
  21. levelname = record.levelname
  22. if self.use_color and levelname in COLORS and COLORS[levelname] != WHITE:
  23. if isinstance(msg, str):
  24. msg_color = COLOR_SEQ % (30 + COLORS[levelname]) + msg + RESET_SEQ
  25. record.msg = msg_color
  26. levelname_color = COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ
  27. record.levelname = levelname_color
  28. return logging.Formatter.format(self, record)
  29. def init_log(name, level=logging.INFO):
  30. if (name, level) in logs:
  31. return
  32. logs.add((name, level))
  33. logger = logging.getLogger(name)
  34. logger.setLevel(level)
  35. ch = logging.StreamHandler()
  36. ch.setLevel(level)
  37. if 'SLURM_PROCID' in os.environ:
  38. rank = int(os.environ['SLURM_PROCID'])
  39. logger.addFilter(lambda record: rank == 0)
  40. else:
  41. rank = 0
  42. FORMAT = f'[%(levelname)s]%(asctime)s-rk{rank}-%(filename)s#%(lineno)d:%(message)s'
  43. formatter = ColoredFormatter(FORMAT)
  44. ch.setFormatter(formatter)
  45. logger.addHandler(ch)