CIFAR-10 evaluation #12

Open
wants to merge 144 commits into base: master

Changes from 10 commits

Commits (144)
4424bda
Initial implementation of CIFAR evaluation, currently runs but haven'…
rohinmshah Jul 22, 2020
53d2f2e
Initial implementation of CIFAR evaluation, currently runs but haven'…
rohinmshah Jul 22, 2020
1fc4d9a
Add support for GPU training, miscellaneous improvements
rohinmshah Aug 5, 2020
5f24dd9
Pull out model training code into its own function
rohinmshah Aug 5, 2020
96dc990
Compatibility with new learn interface
rohinmshah Aug 5, 2020
513b530
Merge master
rohinmshah Aug 5, 2020
7ff2804
Merge branch 'master' into cifar_eval
rohinmshah Aug 6, 2020
89a30dc
Implement the correct augmentations for SimCLR on CIFAR-10
rohinmshah Aug 6, 2020
db49e92
Changes to optimizers and learning rates to be more in line with SimCLR
rohinmshah Aug 9, 2020
c29cf52
Add momentum to optimizer
rohinmshah Aug 10, 2020
66ae3d5
Fix indentation bug
rohinmshah Aug 10, 2020
8303ca5
Address comments on PR, except for LinearWarmupCosine documentation, …
rohinmshah Aug 17, 2020
0d1622f
Merge branch 'master' into cifar_eval
rohinmshah Aug 17, 2020
474b6f9
Rewrote LinearWarmupCosine to be more understandable
rohinmshah Aug 17, 2020
375e76f
Merge
rohinmshah Aug 25, 2020
4950c8e
Merge
rohinmshah Aug 25, 2020
294d436
Miscellaneous small fixes
rohinmshah Aug 25, 2020
9cae4a7
Make things more parameterizable
rohinmshah Aug 26, 2020
89d5864
Update .gitignore
RPC2 Apr 13, 2021
51dad87
Merge branch 'master' into cifar_eval
RPC2 Apr 13, 2021
3fef45e
update model setting
RPC2 Apr 19, 2021
a93d017
Make CIFAR runnable for RepL!
RPC2 Apr 20, 2021
f2fc56b
classification + cleanup
RPC2 Apr 20, 2021
988ec4a
some cleanup
RPC2 Apr 20, 2021
fa647ad
Hardcode warmup_epochs to 2
decodyng Apr 21, 2021
8fe17b5
Import time
decodyng Apr 21, 2021
95a01d0
Make testloader exist
decodyng Apr 21, 2021
6869714
Is RepL training?
decodyng Apr 21, 2021
d977571
Hardcode dataset length
decodyng Apr 21, 2021
574d8fc
Remove excess logging
decodyng Apr 21, 2021
45b285a
Remove Cosine Annealing to be consistent with repo
decodyng Apr 21, 2021
f0b8f71
Put their loss in for ours
decodyng Apr 21, 2021
fc10cd4
Comment out their loss, which is NaN for some reason
decodyng Apr 21, 2021
97bf11a
Add breakpoint
decodyng Apr 21, 2021
c22329f
Fix config name
decodyng Apr 21, 2021
55963e0
Switch to running their loss
decodyng Apr 21, 2021
b51f49c
Add another breakpoint
decodyng Apr 21, 2021
2c3b69a
Fix numpy call
decodyng Apr 21, 2021
9e6c792
What if you used their loss but normalized first to maybe avoid infin…
decodyng Apr 21, 2021
09254b1
Add ability to do K means evaluation
decodyng Apr 21, 2021
d464fca
Accidentally called encoder.encoder
decodyng Apr 21, 2021
44e6213
double-import tqdm
decodyng Apr 21, 2021
a35ba0a
Maybe avoid needing traj_info
decodyng Apr 21, 2021
8033682
Remove unused feature, out
decodyng Apr 21, 2021
720c15d
Allow passing in a pretrained model
decodyng Apr 21, 2021
382e9fd
Make it easier to switch between our loss and repo loss
decodyng Apr 21, 2021
2f221c0
Normalize our features before using them in KNN
decodyng Apr 21, 2021
81932ee
Unbreak torch.nn.functional import
decodyng Apr 21, 2021
5724cef
Examine image scale before augmentations
decodyng Apr 22, 2021
8592c66
Explicitly use their model class
decodyng Apr 22, 2021
6bda628
Use SimCLR model for encoder at least
decodyng Apr 22, 2021
d1463f5
No longer expect a tuple in KNN code
decodyng Apr 22, 2021
767511d
Modify decoder kwargs to be closer to SimCLR
decodyng Apr 22, 2021
6e55630
Add comma back in
decodyng Apr 22, 2021
0fc1975
Add code to save images out
decodyng Apr 22, 2021
592a278
Remove image preprocessing to avoid double-normalizing
decodyng Apr 22, 2021
fc7388d
Add more image saving and warnings
decodyng Apr 22, 2021
56ebd6f
Add more image saving and warnings
decodyng Apr 22, 2021
618f4ea
Save out more images
decodyng Apr 22, 2021
8b6f75a
Try to get augmentations to match SimCLR
decodyng Apr 22, 2021
01b6a6a
Still convert to PILImage
decodyng Apr 22, 2021
c5aee28
Add back numpy conversion without x255
decodyng Apr 22, 2021
94f0b9a
Add 255x back in
decodyng Apr 22, 2021
be29ef9
Normalize in the same way as SimCLR
decodyng Apr 22, 2021
9270bbc
For some reason getting a dimension error
decodyng Apr 22, 2021
7039fbf
Go back to other normalization
decodyng Apr 22, 2021
cc24d47
Cleanup and final push for the evening
decodyng Apr 23, 2021
d201c9a
Switch from bilinear to bicubic interpolation
decodyng Apr 23, 2021
fda28dd
No longer convert to numpy array before PIL image
decodyng Apr 23, 2021
292df58
Transpose numpy array so PILImage has the right shape
decodyng Apr 23, 2021
5bdada6
Breakpoint before augmentation
Apr 23, 2021
8b7ac9e
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
40d8e7a
Uniform_ contexts and target
decodyng Apr 23, 2021
28cdd83
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
43c2f57
log every interval
decodyng Apr 23, 2021
0049873
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
d5bdbb0
Make zs uniform
decodyng Apr 23, 2021
90fb7c0
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
4446621
Set seed to 10
decodyng Apr 23, 2021
1b16a55
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
Apr 23, 2021
9884d1b
Add seed back in
Apr 23, 2021
432c7b0
No longer have random zs
Apr 23, 2021
1224059
Remove breakpoint
Apr 23, 2021
e6bff32
Remove random seed
decodyng Apr 23, 2021
877ecc5
Examine distribution after encoder
decodyng Apr 23, 2021
fc9dad2
No longer randomize images
decodyng Apr 23, 2021
7bbb153
Remove extraneous breakpoint
decodyng Apr 23, 2021
cb06dbc
Add parameter check to repl
decodyng Apr 24, 2021
71e129f
Swap our data loader for theirs
decodyng Apr 26, 2021
5553d2a
Add dataloader back in and add breakpoint
decodyng Apr 26, 2021
294bab4
Try to get .next() to work
decodyng Apr 26, 2021
2884db6
Swap in new contexts/targets temporarily
decodyng Apr 26, 2021
9eb90b9
If we double augment that should break things... right?
decodyng Apr 26, 2021
31864d6
Switch back to using our data loader
decodyng Apr 26, 2021
502c223
Skip the decoding step entirely
decodyng Apr 26, 2021
c4096a1
Don't calculate norm on decoder while we're testing out not using it
decodyng Apr 26, 2021
3dd1e80
Remove decoder from _calculate_norms
decodyng Apr 26, 2021
f2645d5
try to use direct network output instead of a distribution
Apr 27, 2021
29a9083
return to using multivariate normal and adjust loss and temperature
Apr 27, 2021
31d04e5
test linear head
Apr 27, 2021
77afc23
Try to fully use SimCLR repo's linear evaluation code
RPC2 Apr 27, 2021
130c127
select test method
RPC2 Apr 27, 2021
869f3bd
Add comment
decodyng Apr 27, 2021
feffcf0
Merge branch 'cifar_eval' of github.com:HumanCompatibleAI/il-represen…
decodyng Apr 27, 2021
1a537bf
Specifically ablate change to decoder
decodyng Apr 27, 2021
f769190
Switch ReLU back to be after BatchNorm
decodyng Apr 27, 2021
5d2ff3f
Remove breakpoint on Github
decodyng Apr 27, 2021
ea8deae
config for running few trajs
RPC2 Apr 30, 2021
7319c1e
Merge branch 'master' into gcp-cyn
RPC2 Apr 30, 2021
6a2842e
Update chain_configs.py
RPC2 Apr 30, 2021
93783a1
Finding a good gpu number balance
RPC2 Apr 30, 2021
4c1c2b8
Add SimCLR model to default SimCLR settings
RPC2 May 6, 2021
01473d8
Try to use 3e-4 lr for SimCLR repl
RPC2 May 6, 2021
05805a0
Merge branch 'gcp-cyn' into cifar_eval
RPC2 May 6, 2021
701c691
update config
RPC2 May 6, 2021
17af15f
comment out context saving code
RPC2 May 6, 2021
766b160
Try augmenting with SimCLR default
RPC2 May 6, 2021
a398ed8
adjust augmenter
May 6, 2021
ae814c3
Try to use multiple GPUs
RPC2 May 6, 2021
2491d70
Add a script for running simclr
May 6, 2021
d992c52
adjust decoder input dim
RPC2 May 6, 2021
111a97d
Merge branch 'cifar_eval' of https://github.com/HumanCompatibleAI/il-…
RPC2 May 6, 2021
8d4ebd5
Adjust decoder shape and normalization
May 6, 2021
872101b
Update run_il.sh for long dmc runs with few trajs
May 6, 2021
56c724e
Merge branch 'gcp-cyn' of ssh://github.com/HumanCompatibleAI/il-repre…
May 6, 2021
3834b88
Setting up loading procgen dataset
RPC2 May 7, 2021
d57fcbf
Merge branch 'procgen' of github.com:HumanCompatibleAI/il-representat…
May 7, 2021
1f2a6ad
Merge branch 'gcp-cyn' into procgen
RPC2 May 7, 2021
c679534
Merge branch 'procgen' of github.com:HumanCompatibleAI/il-representat…
May 7, 2021
2c70645
Adding support for procgen (loading env)
May 7, 2021
2f0659a
Set Procgen env names
May 7, 2021
fe9558e
Update loading procgen envs
May 11, 2021
dcd44ca
Maybe we don't need next_obs?
May 11, 2021
e236df2
Env wrapper is already handled by Procgen
May 11, 2021
9faea2d
More clean up
May 11, 2021
19d1dc3
Add framestack
May 11, 2021
7e67b54
Adjust encoder network channel
RPC2 May 12, 2021
0234acf
Merge branch 'procgen' into cifar_eval
RPC2 May 12, 2021
f2c9be2
Update simclr running script
May 12, 2021
ad9a9d1
Try a smaller network
RPC2 May 12, 2021
4365e02
See if it can run end to end
RPC2 May 12, 2021
5eb9283
Update encoder kwargs
RPC2 May 12, 2021
5fdf57a
Current script to train simclr as repl
May 12, 2021
9983123
Use default augmentation
May 12, 2021
Files changed
11 changes: 6 additions & 5 deletions algos/decoders.py
@@ -60,11 +60,12 @@ class ProjectionHead(LossDecoder):
     def __init__(self, representation_dim, projection_shape, sample=False, learn_scale=False):
         super(ProjectionHead, self).__init__(representation_dim, projection_shape, sample)
 
-        self.shared_mlp = nn.Sequential(nn.Linear(self.representation_dim, 256),
-                                        nn.ReLU(),
-                                        nn.Linear(256, 256),
-                                        nn.ReLU())
-        self.mean_layer = nn.Linear(256, self.projection_dim)
+        dim = self.representation_dim
+        self.shared_mlp = nn.Sequential(nn.Linear(dim, dim),
+                                        nn.ReLU(),
+                                        nn.Linear(dim, dim),
+                                        nn.ReLU())
+        self.mean_layer = nn.Linear(dim, self.projection_dim, bias=False)
 
         if learn_scale:
             self.scale_layer = nn.Linear(256, self.projection_dim)
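In plain terms, the hidden width of the projection MLP now tracks representation_dim instead of a hard-coded 256, and the final projection drops its bias, which is closer to SimCLR's projection head. A minimal standalone sketch of the resulting module (editorial addition; the class name and forward method are hypothetical simplifications of the real LossDecoder subclass):

import torch.nn as nn

class ProjectionHeadSketch(nn.Module):
    # Hypothetical standalone rendering of the hunk above.
    def __init__(self, representation_dim, projection_dim):
        super().__init__()
        dim = representation_dim
        self.shared_mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(),
                                        nn.Linear(dim, dim), nn.ReLU())
        self.mean_layer = nn.Linear(dim, projection_dim, bias=False)

    def forward(self, z):
        # Map an encoder representation z into the projection space used by the loss.
        return self.mean_layer(self.shared_mlp(z))

Review note: the unchanged context line below the hunk still builds scale_layer with a 256-unit input; if learn_scale is ever used with the new widths, that looks like a latent shape mismatch.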
4 changes: 2 additions & 2 deletions algos/encoders.py
@@ -33,7 +33,6 @@ def __init__(self, obs_space, representation_dim):
         super().__init__()
         self.input_channel = obs_space.shape[0]
         self.representation_dim = representation_dim
-        shared_network_layers = []
 
         for layer_spec in DEFAULT_CNN_ARCHITECTURE['CONV']:
             shared_network_layers.append(nn.Conv2d(self.input_channel, layer_spec['out_dim'],
@@ -47,7 +46,7 @@ def __init__(self, obs_space, representation_dim):
             shared_network_layers.append(nn.Linear(in_dim, out_dim))
             shared_network_layers.append(nn.ReLU())
 
-            self.shared_network = nn.Sequential(*shared_network_layers)
+        self.shared_network = nn.Sequential(*shared_network_layers)
 
         self.mean_layer = nn.Linear(DEFAULT_CNN_ARCHITECTURE['DENSE'][-1]['in_dim'], self.representation_dim)
         self.scale_layer = nn.Linear(DEFAULT_CNN_ARCHITECTURE['DENSE'][-1]['in_dim'], self.representation_dim)
@@ -80,6 +79,7 @@ def __init__(self, obs_space, representation_dim, architecture_module_cls=None,
         representing the mean representation z of a fixed-variance representation distribution
         """
         super(DeterministicEncoder, self).__init__()
+        self.representation_dim = representation_dim
        if architecture_module_cls is None:
             architecture_module_cls = NatureCNN
         self.network = architecture_module_cls(obs_space, representation_dim)
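The second hunk above is the fix from commit 66ae3d5 ("Fix indentation bug"): self.shared_network was previously assigned inside the dense-layer loop, so the nn.Sequential was rebuilt on every iteration; dedented, it is constructed once after all layers are appended. The final network happens to be identical either way, since the last iteration sees the full list, but the in-loop rebuild is wasteful and misleading. A tiny illustration (editorial addition, hypothetical widths):

import torch.nn as nn

layers = []
for width in (64, 64):
    layers.append(nn.Linear(width, width))
    net = nn.Sequential(*layers)  # before: reassigned on every iteration
net = nn.Sequential(*layers)      # after: constructed once, post-loop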
7 changes: 6 additions & 1 deletion algos/representation_learner.py
@@ -28,6 +28,7 @@ def __init__(self, env, log_dir, encoder, decoder, loss_calculator, target_pair_
                  representation_dim=512,
                  projection_dim=None,
                  device=None,
+                 normalize=True,
                  shuffle_batches=True,
                  batch_size=256,
                  preprocess_extra_context=True,
@@ -51,6 +52,7 @@ def __init__(self, env, log_dir, encoder, decoder, loss_calculator, target_pair_
         else:
             self.device = device
 
+        self.normalize = normalize
         self.shuffle_batches = shuffle_batches
         self.batch_size = batch_size
         self.preprocess_extra_context = preprocess_extra_context
@@ -135,14 +137,17 @@ def _preprocess(self, input_data):
             input_data = input_data.permute(self.permutation_tuple)
 
         # Normalization to range [-1, 1]
-        if isinstance(self.observation_space, Box):
+        if self.normalize:
+            assert isinstance(self.observation_space, Box)
             low, high = self.observation_space.low, self.observation_space.high
             low_min, low_max, high_min, high_max = low.min(), low.max(), high.min(), high.max()
             assert low_min == low_max and high_min == high_max
             low, high = low_min, high_max
             mid = (low + high) / 2
             delta = high - mid
             input_data = (input_data - mid) / delta
 
+        assert input_data.shape[1:] == self.observation_shape
         return input_data
 
     def _preprocess_extra_context(self, extra_context):
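To make the normalization arithmetic in _preprocess concrete, a quick worked example (editorial addition): for an image space of Box(low=0, high=255), every pixel is mapped linearly into [-1, 1]:

import numpy as np

low, high = 0.0, 255.0
mid = (low + high) / 2           # 127.5
delta = high - mid               # 127.5
pixels = np.array([0.0, 127.5, 255.0])
print((pixels - mid) / delta)    # [-1.  0.  1.]

The new normalize flag lets callers skip this step when their own transforms already normalize; run_cifar.py (below) passes normalize=False, avoiding the double-normalization that several commits in this PR chased down.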
12 changes: 6 additions & 6 deletions algos/utils.py
@@ -101,23 +101,23 @@ def log(self, msg):
 
 class LinearWarmupCosine(_LRScheduler):
     def __init__(self, optimizer, warmup_epoch, T_max, eta_min=0, last_epoch=-1):
         self.T_max = T_max
         self.eta_min = eta_min
         self.warmup_epoch = warmup_epoch
+        self.cosine_epochs = T_max - warmup_epoch
         super(LinearWarmupCosine, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
         if self.warmup_epoch > 0:
             if self.last_epoch <= self.warmup_epoch:
                 return [base_lr / self.warmup_epoch * self.last_epoch for base_lr in self.base_lrs]
-        if ((self.last_epoch - self.warmup_epoch) - 1 - (self.T_max - self.warmup_epoch)) % (2 * (self.T_max - self.warmup_epoch)) == 0:
+        if ((self.last_epoch - self.warmup_epoch) - 1 - self.cosine_epochs) % (2 * self.cosine_epochs) == 0:
             return [group['lr'] + (base_lr - self.eta_min) *
-                    (1 - math.cos(math.pi / (self.T_max - self.warmup_epoch))) / 2
+                    (1 - math.cos(math.pi / self.cosine_epochs)) / 2
                     for base_lr, group in
                     zip(self.base_lrs, self.optimizer.param_groups)]
         else:
-            return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epoch) / (self.T_max - self.warmup_epoch))) /
-                    (1 + math.cos(math.pi * ((self.last_epoch - self.warmup_epoch) - 1) / (self.T_max - self.warmup_epoch))) *
+            return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_epoch) / self.cosine_epochs)) /
+                    (1 + math.cos(math.pi * ((self.last_epoch - self.warmup_epoch) - 1) / self.cosine_epochs)) *
                     (group['lr'] - self.eta_min) + self.eta_min
                     for group in self.optimizer.param_groups]
@@ -166,4 +166,4 @@ def update(self, val, n=1):
         self.val = val
         self.sum += val * n
         self.count += n
-        self.avg = self.sum / self.count
+        self.avg = self.sum / self.count
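A hedged usage sketch of the scheduler refactored above (editorial addition; assumes LinearWarmupCosine is importable from algos.utils): the learning rate ramps linearly from 0 to the base value over warmup_epoch epochs, then follows a cosine decay toward eta_min over the remaining cosine_epochs = T_max - warmup_epoch epochs, the quantity the new attribute caches.

import torch
from algos.utils import LinearWarmupCosine

model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=2.0)
sched = LinearWarmupCosine(opt, warmup_epoch=10, T_max=100)

for epoch in range(100):
    # ... train one epoch, calling opt.step() per batch ...
    sched.step()  # epochs 0-10: linear warmup to 2.0; epochs 10-100: cosine decay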
200 changes: 200 additions & 0 deletions run_cifar.py
@@ -0,0 +1,200 @@
from algos import *
from gym.spaces import Discrete, Box
from sacred import Experiment
from sacred.observers import FileStorageObserver
from algos.utils import gaussian_blur

import numpy as np
import os
import PIL
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import resnet18
from algos.utils import LinearWarmupCosine


class MockGymEnv(object):
"""A mock Gym env for a supervised learning dataset pretending to be an RL
task. Action space is set to Discrete(1), observation space corresponds to
the original supervised learning task.
"""
def __init__(self, obs_space):
self.observation_space = obs_space
self.action_space = Discrete(1)

def seed(self, seed):
pass

def close(self):
pass
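
# Editorial note (hedged): the representation learner appears to need an env only
# for its observation/action spaces plus seed()/close() calls, so this stub
# stands in for a real Gym environment.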


def transform_to_rl(dataset):
"""Transforms the input supervised learning dataset into an "RL dataset", by
adding dummy 'actions' (always 0) and 'dones' (always False), and pretending
that everything is from the same 'trajectory'.
"""
states = [img for img, label in dataset]
data_dict = {
'states': states,
'actions': [0.0] * len(states),
'dones': [False] * len(states),
}
return data_dict
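
# Hedged illustration: for the CIFAR-10 train split this yields 50,000 image
# tensors in 'states', with matching all-zero 'actions' and all-False 'dones'.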


class LinearHead(nn.Module):
def __init__(self, encoder, output_dim):
super().__init__()
self.encoder = encoder
self.output_dim = output_dim
self.layer = nn.Linear(encoder.representation_dim, output_dim)

def forward(self, x):
encoding = self.encoder.encode_context(x, None).loc.detach()
return self.layer(encoding)
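
# Editorial note (hedged): .loc.detach() takes the mean of the encoded
# distribution and blocks gradients, so only self.layer trains below;
# this is a standard linear-probe evaluation.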


def train_classifier(classifier, data_dir, num_epochs, device):
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomResizedCrop(32, interpolation=PIL.Image.BICUBIC),
transforms.RandomHorizontalFlip(),
# No color jitter or grayscale for finetuning
# SimCLR doesn't use blur for CIFAR-10
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(classifier.layer.parameters(), lr=0.2, momentum=0.9, weight_decay=0.0, nesterov=True)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

for epoch in range(num_epochs):
print(f"Epoch {epoch}/{num_epochs} with lr {optimizer.param_groups[0]['lr']}")
running_loss = 0.0
for i, (inputs, labels) in enumerate(trainloader, 0):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = classifier(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

# print statistics
running_loss += loss.item()
if i % 20 == 19: # print every 20 mini-batches
print('[Epoch %d, Batch %3d] Average loss: %.3f' %
(epoch + 1, i + 1, running_loss / 20))
running_loss = 0.0

scheduler.step()


def evaluate_classifier(classifier, data_dir, device):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
correct = 0
total = 0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)
outputs = classifier(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print('Accuracy: %d %%' % (100 * correct / total))


def representation_learning(algo, data_dir, num_epochs, device, log_dir):
print('Creating model for representation learning')

if isinstance(algo, str):
algo = globals()[algo]
assert issubclass(algo, RepresentationLearner)

rep_learning_augmentations = [
transforms.Lambda(torch.tensor),
transforms.ToPILImage(),
transforms.RandomResizedCrop(32, interpolation=PIL.Image.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
# SimCLR doesn't use blur for CIFAR-10
]
env = MockGymEnv(Box(low=0.0, high=1.0, shape=(3, 32, 32), dtype=np.float32))
# Note that the resnet18 model used here has an architecture meant for
# ImageNet, not CIFAR-10. The SimCLR implementation uses a version
# specialized for CIFAR, see https://github.com/google-research/simclr/blob/37ad4e01fb22e3e6c7c4753bd51a1e481c2d992e/resnet.py#L531
# It seems that SimCLR does not include the final fully connected layer for ResNets, so we set it to the identity.
resnet_without_fc = resnet18()
resnet_without_fc.fc = torch.nn.Identity()
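    # Editorial note (hedged): with fc replaced by Identity, torchvision's
    # resnet18 outputs 512-dim features, matching representation_dim=512 below.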
    # Note: SimCLR uses LARSOptimizer, which we currently do not use
model = algo(
env, log_dir=log_dir, batch_size=512, representation_dim=512, projection_dim=128,
device=device, normalize=False, shuffle_batches=True,
encoder_kwargs={'architecture_module_cls': lambda *args: resnet_without_fc},
augmenter_kwargs={'augmentations': rep_learning_augmentations},
optimizer_kwargs={'lr': 2.0, 'weight_decay': 1e-4, 'momentum': 0.9},
scheduler=LinearWarmupCosine,
scheduler_kwargs={'warmup_epoch': 10, 'T_max': num_epochs},
loss_calculator_kwargs={'temp': 0.5},
)

print('Train representation learner')
transform = transforms.ToTensor()
trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
rep_learning_data = transform_to_rl(trainset)
model.learn(rep_learning_data, num_epochs)
env.close()
return model


cifar_ex = Experiment('cifar')


@cifar_ex.config
def default_config():
seed = 1
algo = SimCLR
data_dir = 'cifar10/'
pretrain_epochs = 1000
finetune_epochs = 100
_ = locals()
del _


@cifar_ex.main
def run(seed, algo, data_dir, pretrain_epochs, finetune_epochs, _config):
# TODO fix this hacky nonsense
log_dir = os.path.join(cifar_ex.observers[0].dir, 'training_logs')
os.mkdir(log_dir)
os.makedirs(data_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = representation_learning(algo, data_dir, pretrain_epochs, device, log_dir)

print('Train linear head')
classifier = LinearHead(model.encoder, 10).to(device)
train_classifier(classifier, data_dir, num_epochs=finetune_epochs, device=device)

print('Evaluate accuracy on test set')
evaluate_classifier(classifier, data_dir, device=device)


if __name__ == '__main__':
cifar_ex.observers.append(FileStorageObserver('cifar_runs'))
cifar_ex.run_commandline()
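
A minimal usage sketch (editorial addition; assumes Sacred is installed and the algos package is on PYTHONPATH). Sacred's "with" syntax overrides the values in default_config(), and algo may be passed as a string because representation_learning resolves strings via globals():

python run_cifar.py with pretrain_epochs=10 finetune_epochs=5
python run_cifar.py with algo=SimCLR data_dir=cifar10/ seed=3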