diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index fbcfa871a6..2b41bea9bf 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -16,6 +16,7 @@ from litgpt.config import Config from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config +from safetensors.torch import load_file as load_safetensors def copy_weights_gpt_neox( @@ -556,13 +557,13 @@ def convert_hf_checkpoint( elif model_safetensor_map_json_path.is_file(): with open(model_safetensor_map_json_path, encoding="utf-8") as json_map: bin_index = json.load(json_map) - bin_files = {checkpoint_dir / Path(bin).with_suffix(".bin") for bin in bin_index["weight_map"].values()} + bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} else: - bin_files = set(checkpoint_dir.glob("*.bin")) + bin_files = set(checkpoint_dir.glob("*.bin")) | set(checkpoint_dir.glob("*.safetensors")) # some checkpoints serialize the training arguments bin_files = {f for f in bin_files if f.name != "training_args.bin"} if not bin_files: - raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") + raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files") with incremental_save(checkpoint_dir / "lit_model.pth") as saver: # for checkpoints that split the QKV across several files, we need to keep all the bin files @@ -584,16 +585,8 @@ def convert_hf_checkpoint( current_file_size = os.path.getsize(bin_file) progress_per_file = (current_file_size / total_size) * total_progress - hf_weights = lazy_load(bin_file) - copy_fn( - sd, - hf_weights, - saver=saver, - dtype=dtype, - pbar=pbar, - progress_per_file=progress_per_file, - debug_mode=debug_mode, - ) + hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file) + copy_fn(sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode) gc.collect() if pbar.n < total_progress: @@ -602,7 +595,7 @@ def convert_hf_checkpoint( else: # Handling files without progress bar in debug mode for bin_file in sorted(bin_files): - hf_weights = lazy_load(bin_file) + hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file) copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode) print(f"Saving converted checkpoint to {checkpoint_dir}") diff --git a/litgpt/scripts/download.py b/litgpt/scripts/download.py index fc6c153fad..7ab609b30f 100644 --- a/litgpt/scripts/download.py +++ b/litgpt/scripts/download.py @@ -58,7 +58,6 @@ def download_from_hub( from huggingface_hub import snapshot_download download_files = ["tokenizer*", "generation_config.json", "config.json"] - from_safetensors = False if not tokenizer_only: bins, safetensors = find_weight_files(repo_id, access_token) if bins: @@ -68,7 +67,6 @@ def download_from_hub( if not _SAFETENSORS_AVAILABLE: raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE)) download_files.append("*.safetensors*") - from_safetensors = True else: raise ValueError(f"Couldn't find weight files for {repo_id}") @@ -93,37 +91,11 @@ def download_from_hub( constants.HF_HUB_ENABLE_HF_TRANSFER = previous download.HF_HUB_ENABLE_HF_TRANSFER = previous - if from_safetensors: - print("Converting .safetensor files to PyTorch binaries (.bin)") - safetensor_paths = list(directory.glob("*.safetensors")) - with ProcessPoolExecutor() as executor: - executor.map(convert_safetensors_file, safetensor_paths) - if convert_checkpoint and not tokenizer_only: print("Converting checkpoint files to LitGPT format.") convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name) -def convert_safetensors_file(safetensor_path: Path) -> None: - from safetensors import SafetensorError - from safetensors.torch import load_file as safetensors_load - - bin_path = safetensor_path.with_suffix(".bin") - try: - result = safetensors_load(safetensor_path) - except SafetensorError as e: - raise RuntimeError(f"{safetensor_path} is likely corrupted. Please try to re-download it.") from e - print(f"{safetensor_path} --> {bin_path}") - torch.save(result, bin_path) - try: - os.remove(safetensor_path) - except PermissionError: - print( - f"Unable to remove {safetensor_path} file. " - "This file is no longer needed and you may want to delete it manually to save disk space." - ) - - def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]: from huggingface_hub import repo_info from huggingface_hub.utils import filter_repo_objects