-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
36 lines (30 loc) · 1.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import numpy as np
import logging
from mmcv.utils import get_logger
def get_total_grad_norm(parameters, norm_type=2):
    """Compute the combined gradient norm over a collection of parameters.

    The ``norm_type``-norm of every available gradient is computed
    individually, and the same norm is then taken over the stacked
    per-parameter norms.

    Args:
        parameters (Iterable[torch.Tensor] | torch.Tensor): Parameters whose
            ``.grad`` attributes are inspected. Any iterable (including a
            generator such as ``model.parameters()``) or a single tensor is
            accepted. Entries whose ``.grad`` is ``None`` are skipped.
        norm_type (float | int): Order of the norm; converted to ``float``
            before use. Defaults to 2.

    Returns:
        torch.Tensor: Scalar tensor with the total gradient norm, placed on
        the device of the first available gradient. If no parameter has a
        gradient, a scalar tensor holding ``0.0`` is returned.
    """
    # Accept a bare tensor so callers need not wrap a single parameter.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        # Nothing to measure (e.g. called before backward() or on a fully
        # frozen model); indexing grads[0] below would otherwise fail.
        return torch.tensor(0.0)

    norm_type = float(norm_type)
    # Gradients may live on different devices; gather the per-parameter
    # norms on the first gradient's device before stacking them.
    device = grads[0].device
    per_param_norms = [torch.norm(g, norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_param_norms), norm_type)
def image_graph_collate_scene_graph(batch):
    """Transpose a batch of samples into per-field tuples.

    Args:
        batch (Iterable[Sequence]): Samples, each a sequence of fields.

    Returns:
        tuple: One tuple per field position, each holding that field's value
        from every sample in order. An empty batch yields an empty tuple.
    """
    # zip(*batch) regroups the i-th field of every sample together.
    return tuple(zip(*batch))
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the package-level logger obtained through mmcv's ``get_logger``.

    The logger name is the first dotted component of this module's
    ``__name__``, i.e. the top-level package name. According to mmcv's
    ``get_logger``, the logger is initialized on first use with a
    StreamHandler, and a FileHandler is added as well when ``log_file`` is
    given.

    Args:
        log_file (str | None): Path of the log file. When provided, a
            FileHandler is attached to the logger.
        log_level (int): Logging level for the logger. Only the rank-0
            process uses it; other processes are set to "Error" and stay
            mostly silent.

    Returns:
        :obj:`logging.Logger`: The root logger.
    """
    # Everything before the first '.' is the top-level package name.
    package_name, _, _ = __name__.partition('.')
    return get_logger(package_name, log_file, log_level)