Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump PyTorch version #150

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
keywords=['meta-learning', 'pytorch', 'few-shot', 'few-shot learning'],
packages=find_packages(exclude=['data', 'contrib', 'docs', 'tests', 'examples']),
install_requires=[
'torch>=1.4.0,<1.10.0',
'torchvision>=0.5.0,<0.11.0',
'torch>=1.4.0,<1.15.0',
'torchvision>=0.5.0,<0.16.0',
'numpy>=1.14.0',
'Pillow>=7.0.0',
'h5py',
Expand Down
87 changes: 47 additions & 40 deletions torchmeta/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,58 +30,65 @@ def get_asset(*args, dtype=None):
# is currently no protection against exceeded quotas. If you get an integrity error in Torchmeta
# (e.g. "MiniImagenet integrity check failed" for MiniImagenet), then this means that the quota
# has exceeded for this dataset. See also: https://github.com/tristandeleu/pytorch-meta/issues/54
#
#
# See also: https://github.com/pytorch/vision/issues/2992
#
#
# The following functions are taken from
# https://github.com/pytorch/vision/blob/cd0268cd408d19d91f870e36fdffd031085abe13/torchvision/datasets/utils.py

from torchvision.datasets.utils import _get_confirm_token, _save_response_content
try:
from torchvision.datasets.utils import _get_confirm_token, _save_response_content

def _quota_exceeded(response: "requests.models.Response"):
return False
# See https://github.com/pytorch/vision/issues/2992 for details
# return "Google Drive - Quota exceeded" in response.text
except ImportError: # `_get_confirm_token` does not exist in torchvision 0.14.0
# assume it is fixed upstream, and use it
from torchvision.datasets.utils import download_file_from_google_drive

else:

def download_file_from_google_drive(file_id, root, filename=None, md5=None):
"""Download a Google Drive file from and place it in root.
def _quota_exceeded(response: "requests.models.Response"):
return False
# See https://github.com/pytorch/vision/issues/2992 for details
# return "Google Drive - Quota exceeded" in response.text

Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"

root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
"""Download a Google Drive file from and place it in root.

os.makedirs(root, exist_ok=True)
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"

if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)

os.makedirs(root, exist_ok=True)

if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()

response = session.get(url, params={'id': file_id}, stream=True)
token = _get_confirm_token(response)
response = session.get(url, params={'id': file_id}, stream=True)
token = _get_confirm_token(response)

if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)
if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)

if _quota_exceeded(response):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)
if _quota_exceeded(response):
msg = (
f"The daily quota of the file {filename} is exceeded and it "
f"can't be downloaded. This is a limitation of Google Drive "
f"and can only be overcome by trying again later."
)
raise RuntimeError(msg)

_save_response_content(response, fpath)
_save_response_content(response, fpath)