From 2b6b22094b545e74b05c075f3daac9c14f16414d Mon Sep 17 00:00:00 2001
From: Qiuhui Liu
Date: Fri, 12 Jun 2020 14:41:49 +0800
Subject: [PATCH] add missing update for utils

---
 utils/base.py | 45 +++++++++++++++++++++++++++++++++++++++++----
 utils/comm.py | 16 ++++++++++++++++
 2 files changed, 57 insertions(+), 4 deletions(-)
 create mode 100644 utils/comm.py

diff --git a/utils/base.py b/utils/base.py
index 087c830..531096f 100644
--- a/utils/base.py
+++ b/utils/base.py
@@ -5,14 +5,27 @@
 
 from threading import Thread
 
+from functools import wraps
+
 from random import sample
 from random import seed as rpyseed
 
+from math import ceil
+
 import logging
 
 from utils.h5serial import h5save, h5load
 
-mask_tensor_type = torch.uint8 if torch.__version__ < "1.2.0" else torch.bool
+secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64}
+
+# handling torch.bool
+if torch.__version__ < "1.2.0":
+	mask_tensor_type = torch.uint8
+	nccl_type_map = None
+else:
+	mask_tensor_type = torch.bool
+	secure_type_map[mask_tensor_type] = torch.int64
+	nccl_type_map = {torch.bool:torch.uint8}
 
 def pad_tensors(tensor_list, dim=-1):
 
@@ -256,14 +269,15 @@ def ModuleList2Dict(modin):
 
 def add_module(m, strin, m_add):
 
-	if strin.find(".") < 0:
+	_name_list = strin.split(".")
+	if len(_name_list) == 1:
 		m.add_module(strin, m_add)
 	else:
-		_m, _name_list = m, strin.split(".")
+		_m = m
 		# update _modules with pytorch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module
 		for _tmp in _name_list[:-1]:
 			_m = _m._modules[_tmp]
-		_m._modules[_name_list[-1]] = m_add
+		_m.add_module(_name_list[-1], m_add)
 
 	return m
 
@@ -324,3 +338,26 @@ def report_parameters(modin):
 			rs += _para.numel()
 
 	return rs
+
+def float2odd(fin):
+
+	_rs = ceil(fin)
+	if _rs % 2 == 1:
+		_rs -= 1
+
+	return _rs
+
+def wrap_float2odd(func):
+	@wraps(func)
+	def wrapper(*args, **kwargs):
+		return float2odd(func(*args, **kwargs))
+	return wrapper
+
+def iternext(iterin):
+
+	try:
+		rs = next(iterin)
+	except:
+		rs = None
+
+	return rs
diff --git a/utils/comm.py b/utils/comm.py
new file mode 100644
index 0000000..0519da2
--- /dev/null
+++ b/utils/comm.py
@@ -0,0 +1,16 @@
+#encoding: utf-8
+
+import torch.cuda.comm as comm
+from utils.base import nccl_type_map
+
+def secure_broadcast_coalesced(tensors, devices, buffer_size=10485760):
+
+	if nccl_type_map is None:
+
+		return comm.broadcast_coalesced(tensors, devices, buffer_size=buffer_size)
+	else:
+		src_type = [para.dtype for para in tensors]
+		map_type = [nccl_type_map[para.dtype] if para.dtype in nccl_type_map else None for para in tensors]
+		rs = comm.broadcast_coalesced([para if typ is None else para.to(typ) for para, typ in zip(tensors, map_type)], devices, buffer_size=buffer_size)
+
+		return list(zip(*[para if mtyp is None else [pu.to(styp) for pu in para] for para, mtyp, styp in zip(list(zip(*rs)), map_type, src_type)]))
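
Note (not part of the patch): a minimal sketch of how the revised add_module helper resolves a dotted sub-module path; the toy network and the "1.1" path below are hypothetical and only illustrate the lookup-then-register behaviour.

#encoding: utf-8

# illustrative sketch only, assuming the patched utils/base.py is importable;
# the module layout and the dotted path "1.1" are made up for this example
from torch import nn

from utils.base import add_module

_net = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.ReLU(), nn.Linear(4, 2)))
# "1.1" names the innermost Linear: the helper walks _modules along "1",
# then registers the replacement via nn.Module.add_module on that parent
add_module(_net, "1.1", nn.Linear(4, 8))
print(_net)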
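
Note (not part of the patch): a minimal usage sketch of the new secure_broadcast_coalesced helper, assuming a host with at least two CUDA devices; the tensor shapes and device ids are made up for illustration.

#encoding: utf-8

# illustrative sketch only; requires the patched utils package and >= 2 GPUs
import torch

from utils.comm import secure_broadcast_coalesced

if torch.cuda.is_available() and torch.cuda.device_count() > 1:
	# a float tensor plus a mask tensor; on torch >= 1.2 the mask is torch.bool,
	# which NCCL cannot broadcast directly, so the helper casts it to uint8 for
	# the transfer and restores the original dtype on every target device
	_tensors = [torch.randn(4, 4, device="cuda:0"), torch.ones(4, 4, device="cuda:0") > 0]
	_copies = secure_broadcast_coalesced(_tensors, [0, 1])
	for _per_device in _copies:
		print([_t.dtype for _t in _per_device])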