Skip to content

Commit

Permalink
Update CLIP model for new testing & linting
Browse files (browse the repository at this point in the history)
  • Loading branch information
ProGamerGov authored May 13, 2022
1 parent: e9598ea; commit: 599d8e1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 61 deletions.
37 changes: 15 additions & 22 deletions captum/optim/models/_image/clip_resnet50x4_image.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Optional, Type
from typing import Any, Optional, Type
from warnings import warn

import torch
from torch import nn

import torch.nn as nn
from captum.optim.models._common import RedirectedReluLayer, SkipLayer

GS_SAVED_WEIGHTS_URL = (
Expand All @@ -15,7 +14,7 @@ def clip_resnet50x4_image(
pretrained: bool = False,
progress: bool = True,
model_path: Optional[str] = None,
**kwargs
**kwargs: Any,
) -> "CLIP_ResNet50x4Image":
"""
The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
Expand All @@ -24,9 +23,8 @@ def clip_resnet50x4_image(
This model can be combined with the CLIP ResNet 50x4 Text model to create the full
CLIP ResNet 50x4 model.
AvgPool2d layers were replaced with AdaptiveAvgPool2d to allow for any input height
and width size, though the best results are obtained by using the model's intended
input height and width of 288x288.
Note that model inputs are expected to have a shape of: [B, 3, 288, 288] or
[3, 288, 288].
See here for more details:
https://github.com/openai/CLIP
Expand Down Expand Up @@ -82,6 +80,7 @@ class CLIP_ResNet50x4Image(nn.Module):
The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
"""

__constants__ = ["transform_input"]

def __init__(
Expand Down Expand Up @@ -124,13 +123,13 @@ def __init__(
self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(80)
self.relu3 = activ()
self.avgpool = nn.AdaptiveAvgPool2d(72)
self.avgpool = nn.AvgPool2d(2)

# Residual layers
self.layer1 = self._build_layer(80, 80, 4, stride=1, pooling=72, activ=activ)
self.layer2 = self._build_layer(320, 160, 6, stride=2, pooling=36, activ=activ)
self.layer3 = self._build_layer(640, 320, 10, stride=2, pooling=18, activ=activ)
self.layer4 = self._build_layer(1280, 640, 6, stride=2, pooling=9, activ=activ)
self.layer1 = self._build_layer(80, 80, blocks=4, stride=1, activ=activ)
self.layer2 = self._build_layer(320, 160, blocks=6, stride=2, activ=activ)
self.layer3 = self._build_layer(640, 320, blocks=10, stride=2, activ=activ)
self.layer4 = self._build_layer(1280, 640, blocks=6, stride=2, activ=activ)

# Attention Pooling
self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)
Expand All @@ -141,7 +140,6 @@ def _build_layer(
planes: int = 80,
blocks: int = 4,
stride: int = 1,
pooling: int = 72,
activ: Type[nn.Module] = nn.ReLU,
) -> nn.Module:
"""
Expand All @@ -160,18 +158,16 @@ def _build_layer(
Default: 4
stride (int, optional): The stride value to use for the Bottleneck layers.
Default: 1
pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
Default: 72
activ (type of nn.Module, optional): The nn.Module class type to use for
activation layers.
Default: nn.ReLU
Returns:
residual_layer (nn.Sequential): A full residual layer.
"""
layers = [Bottleneck(inplanes, planes, stride, pooling=pooling, activ=activ)]
layers = [Bottleneck(inplanes, planes, stride, activ=activ)]
for _ in range(blocks - 1):
layers += [Bottleneck(planes * 4, planes, pooling=pooling, activ=activ)]
layers += [Bottleneck(planes * 4, planes, activ=activ)]
return nn.Sequential(*layers)

def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -230,7 +226,6 @@ def __init__(
inplanes: int = 80,
planes: int = 80,
stride: int = 1,
pooling: int = 72,
activ: Type[nn.Module] = nn.ReLU,
) -> None:
"""
Expand All @@ -244,8 +239,6 @@ def __init__(
Default: 80
stride (int, optional): The stride value to use for the Bottleneck layers.
Default: 1
pooling (int, optional): The output size used for nn.AdaptiveAvgPool2d.
Default: 72
activ (type of nn.Module, optional): The nn.Module class type to use for
activation layers.
Default: nn.ReLU
Expand All @@ -259,15 +252,15 @@ def __init__(
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = activ()

self.avgpool = nn.AdaptiveAvgPool2d(pooling)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu3 = activ()

if stride > 1 or inplanes != planes * 4:
self.downsample = nn.Sequential(
nn.AdaptiveAvgPool2d(pooling),
nn.AvgPool2d(stride),
nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(planes * 4),
)
Expand Down
9 changes: 5 additions & 4 deletions captum/optim/models/_image/clip_resnet50x4_text.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional

import math
from typing import Any, Optional

import torch
from torch import nn
import torch.nn as nn


GS_SAVED_WEIGHTS_URL = (
Expand All @@ -14,7 +14,7 @@ def clip_resnet50x4_text(
pretrained: bool = False,
progress: bool = True,
model_path: Optional[str] = None,
**kwargs
**kwargs: Any,
) -> "CLIP_ResNet50x4Text":
"""
The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
Expand Down Expand Up @@ -72,6 +72,7 @@ class CLIP_ResNet50x4Text(nn.Module):
The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
"""

def __init__(
self,
width: int = 640,
Expand Down
62 changes: 36 additions & 26 deletions tests/optim/models/test_clip_resnet50x4_image.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
#!/usr/bin/env python3
import unittest
from typing import Type

import torch

from captum.optim.models import clip_resnet50x4_image
from captum.optim.models._common import RedirectedReluLayer, SkipLayer
from packaging import version
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
from tests.optim.helpers.models import check_layer_in_model


class TestCLIPResNet50x4Image(BaseTest):
def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping load pretrained CLIP ResNet 50x4 Image due to insufficient"
+ " Torch version."
Expand All @@ -23,7 +22,7 @@ def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
self.assertTrue(check_layer_in_model(model, RedirectedReluLayer))

def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping load pretrained CLIP ResNet 50x4 Image RedirectedRelu test"
+ " due to insufficient Torch version."
Expand All @@ -35,7 +34,7 @@ def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
self.assertTrue(check_layer_in_model(model, torch.nn.ReLU))

def test_load_clip_resnet50x4_image_linear(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping load pretrained CLIP ResNet 50x4 Image linear test due to"
+ " insufficient Torch version."
Expand All @@ -46,7 +45,7 @@ def test_load_clip_resnet50x4_image_linear(self) -> None:
self.assertTrue(check_layer_in_model(model, SkipLayer))

def test_clip_resnet50x4_image_transform(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping CLIP ResNet 50x4 Image internal transform test due to"
+ " insufficient Torch version."
Expand All @@ -63,20 +62,20 @@ def test_clip_resnet50x4_image_transform(self) -> None:
assertTensorAlmostEqual(self, output, expected_output, 0)

def test_clip_resnet50x4_image_transform_warning(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping CLIP ResNet 50x4 Image internal transform warning test due"
+ " to insufficient Torch version."
)
x = torch.stack(
[torch.ones(3, 112, 112) * -1, torch.ones(3, 112, 112) * 2], dim=0
[torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
)
model = clip_resnet50x4_image(pretrained=True)
with self.assertWarns(UserWarning):
model._transform_input(x)

def test_clip_resnet50x4_image_load_and_forward(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping basic pretrained CLIP ResNet 50x4 Image forward test due to"
+ " insufficient Torch version."
Expand All @@ -87,7 +86,7 @@ def test_clip_resnet50x4_image_load_and_forward(self) -> None:
self.assertEqual(list(output.shape), [1, 640])

def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping basic untrained CLIP ResNet 50x4 Image forward test due to"
+ " insufficient Torch version."
Expand All @@ -97,24 +96,21 @@ def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
output = model(x)
self.assertEqual(list(output.shape), [1, 640])

def test_clip_resnet50x4_image_load_and_forward_diff_sizes(self) -> None:
if torch.__version__ <= "1.6.0":
def test_clip_resnet50x4_image_warning(self) -> None:
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Image forward with different"
+ " sized inputs test due to insufficient Torch version."
"Skipping pretrained CLIP ResNet 50x4 Image transform input"
+ " warning test due to insufficient Torch version."
)
x = torch.zeros(1, 3, 512, 512)
x2 = torch.zeros(1, 3, 126, 224)
x = torch.stack(
[torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
)
model = clip_resnet50x4_image(pretrained=True)

output = model(x)
output2 = model(x2)

self.assertEqual(list(output.shape), [1, 640])
self.assertEqual(list(output2.shape), [1, 640])
with self.assertWarns(UserWarning):
_ = model._transform_input(x)

def test_clip_resnet50x4_image_forward_cuda(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
+ " insufficient Torch version."
Expand All @@ -124,23 +120,37 @@ def test_clip_resnet50x4_image_forward_cuda(self) -> None:
"Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
+ " not supporting CUDA."
)
x = torch.zeros(1, 3, 224, 224).cuda()
x = torch.zeros(1, 3, 288, 288).cuda()
model = clip_resnet50x4_image(pretrained=True).cuda()
output = model(x)

self.assertTrue(output.is_cuda)
self.assertEqual(list(output.shape), [1, 640])

def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
if torch.__version__ <= "1.8.0":
if version.parse(torch.__version__) <= version.parse("1.8.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+ " no redirected relu test due to insufficient Torch version."
)
x = torch.zeros(1, 3, 224, 224)
x = torch.zeros(1, 3, 288, 288)
model = clip_resnet50x4_image(
pretrained=True, replace_relus_with_redirectedrelu=False
)
jit_model = torch.jit.script(model)
output = jit_model(x)
self.assertEqual(list(output.shape), [1, 640])

def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
if version.parse(torch.__version__) <= version.parse("1.8.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+ " redirected relu test due to insufficient Torch version."
)
x = torch.zeros(1, 3, 288, 288)
model = clip_resnet50x4_image(
pretrained=True, replace_relus_with_redirectedrelu=True
)
jit_model = torch.jit.script(model)
output = jit_model(x)
self.assertEqual(list(output.shape), [1, 640])
18 changes: 9 additions & 9 deletions tests/optim/models/test_clip_resnet50x4_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,37 @@
import unittest

import torch

from captum.optim.models import clip_resnet50x4_text
from packaging import version
from tests.helpers.basic import BaseTest, assertTensorAlmostEqual


class TestCLIPResNet50x4Text(BaseTest):
def test_clip_resnet50x4_text_logit_scale(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping basic pretrained CLIP ResNet 50x4 Text logit scale test due"
+ " to insufficient Torch version."
)
model = clip_resnet50x4_text(pretrained=True)
expected_logit_scale = torch.tensor([4.605170249938965])
expected_logit_scale = torch.tensor(4.605170249938965)
assertTensorAlmostEqual(self, model.logit_scale, expected_logit_scale)

def test_clip_resnet50x4_text_load_and_forward(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping basic pretrained CLIP ResNet 50x4 Text forward test due to"
+ " insufficient Torch version."
)
# Start & End tokens: 49405, 49406
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
x = x.int()[None, :]
x = x[None, :].long()
model = clip_resnet50x4_text(pretrained=True)
output = model(x)
self.assertEqual(list(output.shape), [1, 640])

def test_clip_resnet50x4_text_forward_cuda(self) -> None:
if torch.__version__ <= "1.6.0":
if version.parse(torch.__version__) <= version.parse("1.6.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Text forward CUDA test due to"
+ " insufficient Torch version."
Expand All @@ -43,21 +43,21 @@ def test_clip_resnet50x4_text_forward_cuda(self) -> None:
+ " not supporting CUDA."
)
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)]).cuda()
x = x.int()[None, :]
x = x[None, :].long()
model = clip_resnet50x4_text(pretrained=True).cuda()
output = model(x)

self.assertTrue(output.is_cuda)
self.assertEqual(list(output.shape), [1, 640])

def test_clip_resnet50x4_text_jit_module(self) -> None:
if torch.__version__ <= "1.8.0":
if version.parse(torch.__version__) <= version.parse("1.8.0"):
raise unittest.SkipTest(
"Skipping pretrained CLIP ResNet 50x4 Text load & JIT module"
+ " test due to insufficient Torch version."
)
x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
x = x.int()[None, :]
x = x[None, :].long()
model = clip_resnet50x4_text(pretrained=True)
jit_model = torch.jit.script(model)
output = jit_model(x)
Expand Down

0 comments on commit 599d8e1

Please sign in to comment.