-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple formatting with Black, CPU support for inference and forgotten main function in training script #3
base: master
Are you sure you want to change the base?
Conversation
__pycache__ | ||
*/__pycache__ | ||
**/__pycache__ | ||
saved_models/ | ||
.vscode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simple .gitignore
import glob | ||
import time | ||
|
||
import numpy as np | ||
from PIL import Image | ||
from skimage import io, transform | ||
|
||
import torch | ||
import torchvision | ||
from torch.autograd import Variable | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.autograd import Variable | ||
from torch.utils.data import Dataset, DataLoader | ||
from torchvision import transforms#, utils | ||
# import torch.optim as optim | ||
|
||
import numpy as np | ||
from PIL import Image | ||
import glob | ||
# import torch.optim as optim | ||
import torchvision | ||
from torchvision import transforms # , utils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prettier import statements
u2net_test.py
Outdated
if torch.cuda.is_available(): | ||
net.load_state_dict(torch.load(model_dir)) | ||
net.cuda() | ||
else: | ||
net.load_state_dict(torch.load(model_dir, map_location=torch.device("cpu"))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you try to load_state_dict on CPU without mapping location to CPU, you will have a RuntimeError
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have tried this with torch==0.4.0
and it gives me TypeError: 'torch.Device' object is not callable
. Solved it by upgrading to torch==0.4.1
, so this should also be updated in the README
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MatiasConTilde this should also work on torch>=0.4 with map_location="cpu"
start = time.time() | ||
|
||
inputs_test = data_test['image'] | ||
inputs_test = data_test["image"] | ||
inputs_test = inputs_test.type(torch.FloatTensor) | ||
|
||
if torch.cuda.is_available(): | ||
inputs_test = Variable(inputs_test.cuda()) | ||
else: | ||
inputs_test = Variable(inputs_test) | ||
|
||
d1,d2,d3,d4,d5,d6,d7= net(inputs_test) | ||
d1, d2, d3, d4, d5, d6, d7 = net(inputs_test) | ||
|
||
print( | ||
f"Predicted {os.path.basename(img_name_list[i_test])} in {time.time() - start:.2f}s" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return loss0, loss | ||
|
||
|
||
def main(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that you forgot to add the main function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
except for the commented lines, the rest is just reformatting with black.
should work with torch>=0.4
thanks for your contribution we are reviewing and testing it. Will update later. |
Also, adding |
Hello, thank you for uploading your code.
I've made a few small changes.