Skip to content

Commit

Permalink
Set C++17 for latest pytorch versions. Add flags for CUDA 12 and 11.8 (
Browse files Browse the repository at this point in the history
…#641)

* Set C++17 for latest pytorch versions. Add flags for CUDA 12 and 11.8

* Update setup.py

* remove import subprocess

* more robust way to compare version

---------

Co-authored-by: Jinze Xue <[email protected]>
  • Loading branch information
RaulPPelaez and Jinze Xue authored Nov 14, 2023
1 parent 40cf334 commit 17204c6
Showing 1 changed file with 12 additions and 29 deletions.
41 changes: 12 additions & 29 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import subprocess
from packaging import version
from setuptools import setup, find_packages
from distutils import log
import sys
Expand All @@ -24,32 +24,6 @@
long_description = fh.read()


def maybe_download_cub():
import torch
dirs = torch.utils.cpp_extension.include_paths(cuda=True)
for d in dirs:
cubdir = os.path.join(d, 'cub')
log.info(f'Searching for cub at {cubdir}...')
if os.path.isdir(cubdir):
log.info(f'Found cub in {cubdir}')
return []
# if no cub, download it to include dir from github
if not os.path.isdir('./include/cub'):
if not os.path.exists('./include'):
os.makedirs('include')
commands = """
echo "Downloading CUB library";
wget -q https://github.com/NVIDIA/cub/archive/refs/tags/1.11.0.zip;
unzip -q 1.11.0.zip -d include;
mv include/cub-1.11.0/cub include;
echo "Removing unnecessary files";
rm 1.11.0.zip;
rm -rf include/cub-1.11.0;
"""
subprocess.run(commands, shell=True, check=True, universal_newlines=True)
return [os.path.abspath("./include")]


def cuda_extension(build_all=False):
import torch
from torch.utils.cpp_extension import CUDAExtension
Expand Down Expand Up @@ -87,15 +61,24 @@ def cuda_extension(build_all=False):
nvcc_args.append("-gencode=arch=compute_80,code=sm_80")
if cuda_version >= 11.1:
nvcc_args.append("-gencode=arch=compute_86,code=sm_86")
if cuda_version >= 11.8:
nvcc_args.append("-gencode=arch=compute_89,code=sm_89")
if cuda_version >= 12.0:
nvcc_args.append("-gencode=arch=compute_90,code=sm_90")

print("nvcc_args: ", nvcc_args)
print('-' * 75)
include_dirs = [*maybe_download_cub(), os.path.abspath("torchani/cuaev/")]
include_dirs = [os.path.abspath("torchani/cuaev/")]
# Update C++ standard based on PyTorch version
pytorch_version = version.parse(torch.__version__)
cxx_args = ['-std=c++17'] if pytorch_version >= version.parse("2.1.0") else ['-std=c++14']

return CUDAExtension(
name='torchani.cuaev',
pkg='torchani.cuaev',
sources=["torchani/cuaev/cuaev.cpp", "torchani/cuaev/aev.cu"],
include_dirs=include_dirs,
extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args})
extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})


def cuaev_kwargs():
Expand Down

0 comments on commit 17204c6

Please sign in to comment.