From 16363650d07c8efb6fcb602570b360b2f6e1f710 Mon Sep 17 00:00:00 2001 From: Hang Zhang <8041160+zhanghang1989@users.noreply.github.com> Date: Fri, 15 Jun 2018 13:36:48 -0700 Subject: [PATCH] fix path (#73) --- encoding/models/base.py | 11 +++++++---- encoding/models/encnet.py | 5 +++-- encoding/models/fcn.py | 8 ++++---- encoding/models/psp.py | 6 +++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/encoding/models/base.py b/encoding/models/base.py index 60624e23..79ef0699 100644 --- a/encoding/models/base.py +++ b/encoding/models/base.py @@ -24,7 +24,7 @@ class BaseNet(nn.Module): def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None, - mean=[.485, .456, .406], std=[.229, .224, .225]): + mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'): super(BaseNet, self).__init__() self.nclass = nclass self.aux = aux @@ -33,11 +33,14 @@ def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None self.std = std # copying modules from pretrained models if backbone == 'resnet50': - self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer) + self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, + norm_layer=norm_layer, root=root) elif backbone == 'resnet101': - self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer) + self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, + norm_layer=norm_layer, root=root) elif backbone == 'resnet152': - self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer) + self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, + norm_layer=norm_layer, root=root) else: raise RuntimeError('unknown backbone: {}'.format(backbone)) # bilinear upsample options diff --git a/encoding/models/encnet.py b/encoding/models/encnet.py index 15af2921..7df71913 100644 --- a/encoding/models/encnet.py +++ b/encoding/models/encnet.py @@ -19,7 +19,8 @@ class EncNet(BaseNet): def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False, norm_layer=nn.BatchNorm2d, **kwargs): - super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) + super(EncNet, self).__init__(nclass, backbone, aux, se_loss, + norm_layer=norm_layer, **kwargs) self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss, lateral=lateral, norm_layer=norm_layer, up_kwargs=self._up_kwargs) @@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False # infer number of classes from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation - model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) + model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs) if pretrained: from .model_store import get_model_file model.load_state_dict(torch.load( diff --git a/encoding/models/fcn.py b/encoding/models/fcn.py index 2586267f..3f1dac00 100644 --- a/encoding/models/fcn.py +++ b/encoding/models/fcn.py @@ -39,7 +39,7 @@ class FCN(BaseNet): >>> print(model) """ def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs): - super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) + super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs) self.head = FCNHead(2048, nclass, norm_layer) if aux: self.auxlayer = FCNHead(1024, nclass, norm_layer) @@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, } # infer number of classes from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation - model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) + model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs) if pretrained: from .model_store import get_model_file model.load_state_dict(torch.load( @@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa >>> model = get_fcn_resnet50_pcontext(pretrained=True) >>> print(model) """ - return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs) + return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs) def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation" @@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): >>> model = get_fcn_resnet50_ade(pretrained=True) >>> print(model) """ - return get_fcn('ade20k', 'resnet50', pretrained, **kwargs) + return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs) diff --git a/encoding/models/psp.py b/encoding/models/psp.py index c21ebe4c..31e0e4fb 100644 --- a/encoding/models/psp.py +++ b/encoding/models/psp.py @@ -16,7 +16,7 @@ class PSP(BaseNet): def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs): - super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer) + super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs) self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs) if aux: self.auxlayer = FCNHead(1024, nclass, norm_layer) @@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, } # infer number of classes from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation - model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs) + model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs) if pretrained: from .model_store import get_model_file model.load_state_dict(torch.load( @@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs): >>> model = get_psp_resnet50_ade(pretrained=True) >>> print(model) """ - return get_psp('ade20k', 'resnet50', pretrained) + return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)