From 76a313151834f17a7d386d95961b11aff936f08e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Mon, 28 Oct 2024 16:48:53 +0100 Subject: [PATCH] fix inclusion of vllm_flash_attn python/compiled files --- setup.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 7566eeb5dc444..82c37d32dfd7d 100644 --- a/setup.py +++ b/setup.py @@ -496,15 +496,23 @@ def _read_requirements(filename: str) -> List[str]: f"Failed to get vLLM wheel from {wheel_location}") from exc with zipfile.ZipFile(wheel_filename) as wheel: - for lib in filter(lambda file: file.filename.endswith(".so"), - wheel.filelist): + for lib in filter( + lambda file: file.filename.endswith(".so") or file.filename. + startswith("vllm/vllm_flash_attn"), wheel.filelist): + print("Extracting and including {lib.filename} from existing wheel") package_name = os.path.dirname(lib.filename).replace("/", ".") + file_name = os.path.basename(lib.filename) + if package_name not in package_data: package_data[package_name] = [] wheel.extract(lib) - package_data[package_name].append(lib.filename) - print(f"Added {lib.filename} to package_data[\"{package_name}\"]") + if file_name.endswith(".py"): + # python files shouldn't be added to package_data + continue + + package_data[package_name].append(file_name) + if _no_device(): ext_modules = []