Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip converting .safetensors to .bin #1853

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions litgpt/scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from litgpt.config import Config
from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config
from safetensors.torch import load_file as load_safetensors


def copy_weights_gpt_neox(
Expand Down Expand Up @@ -556,13 +557,13 @@ def convert_hf_checkpoint(
elif model_safetensor_map_json_path.is_file():
with open(model_safetensor_map_json_path, encoding="utf-8") as json_map:
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / Path(bin).with_suffix(".bin") for bin in bin_index["weight_map"].values()}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
bin_files = set(checkpoint_dir.glob("*.bin"))
bin_files = set(checkpoint_dir.glob("*.bin")) | set(checkpoint_dir.glob("*.safetensors"))
# some checkpoints serialize the training arguments
bin_files = {f for f in bin_files if f.name != "training_args.bin"}
if not bin_files:
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files")
raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin or .safetensors files")

with incremental_save(checkpoint_dir / "lit_model.pth") as saver:
# for checkpoints that split the QKV across several files, we need to keep all the bin files
Expand All @@ -584,16 +585,8 @@ def convert_hf_checkpoint(
current_file_size = os.path.getsize(bin_file)
progress_per_file = (current_file_size / total_size) * total_progress

hf_weights = lazy_load(bin_file)
copy_fn(
sd,
hf_weights,
saver=saver,
dtype=dtype,
pbar=pbar,
progress_per_file=progress_per_file,
debug_mode=debug_mode,
)
hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
copy_fn(sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode)
gc.collect()

if pbar.n < total_progress:
Expand All @@ -602,7 +595,7 @@ def convert_hf_checkpoint(
else:
# Handling files without progress bar in debug mode
for bin_file in sorted(bin_files):
hf_weights = lazy_load(bin_file)
hf_weights = load_safetensors(bin_file) if bin_file.suffix == ".safetensors" else lazy_load(bin_file)
copy_fn(sd, hf_weights, saver=saver, dtype=dtype, debug_mode=debug_mode)

print(f"Saving converted checkpoint to {checkpoint_dir}")
Expand Down
28 changes: 0 additions & 28 deletions litgpt/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def download_from_hub(
from huggingface_hub import snapshot_download

download_files = ["tokenizer*", "generation_config.json", "config.json"]
from_safetensors = False
if not tokenizer_only:
bins, safetensors = find_weight_files(repo_id, access_token)
if bins:
Expand All @@ -68,7 +67,6 @@ def download_from_hub(
if not _SAFETENSORS_AVAILABLE:
raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE))
download_files.append("*.safetensors*")
from_safetensors = True
else:
raise ValueError(f"Couldn't find weight files for {repo_id}")

Expand All @@ -93,37 +91,11 @@ def download_from_hub(
constants.HF_HUB_ENABLE_HF_TRANSFER = previous
download.HF_HUB_ENABLE_HF_TRANSFER = previous

if from_safetensors:
print("Converting .safetensor files to PyTorch binaries (.bin)")
safetensor_paths = list(directory.glob("*.safetensors"))
with ProcessPoolExecutor() as executor:
executor.map(convert_safetensors_file, safetensor_paths)

if convert_checkpoint and not tokenizer_only:
print("Converting checkpoint files to LitGPT format.")
convert_hf_checkpoint(checkpoint_dir=directory, dtype=dtype, model_name=model_name)


def convert_safetensors_file(safetensor_path: Path) -> None:
    """Re-serialize one ``.safetensors`` file as a PyTorch ``.bin`` file.

    The tensors are loaded from ``safetensor_path`` and written with
    ``torch.save`` to the same path with its suffix replaced by ``.bin``.
    Afterwards the original file is removed; if removal is refused with a
    ``PermissionError``, a notice is printed instead of failing.

    Args:
        safetensor_path: Location of the ``.safetensors`` file to convert.

    Raises:
        RuntimeError: If loading fails with ``SafetensorError``; the original
            error is chained as the cause.
    """
    from safetensors import SafetensorError
    from safetensors.torch import load_file

    target = safetensor_path.with_suffix(".bin")
    try:
        tensors = load_file(safetensor_path)
    except SafetensorError as err:
        raise RuntimeError(f"{safetensor_path} is likely corrupted. Please try to re-download it.") from err
    else:
        print(f"{safetensor_path} --> {target}")
        torch.save(tensors, target)

    try:
        os.remove(safetensor_path)
    except PermissionError:
        # Best effort only: the converted file already exists, so just tell the user.
        notice = (
            f"Unable to remove {safetensor_path} file. "
            "This file is no longer needed and you may want to delete it manually to save disk space."
        )
        print(notice)


def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[str], List[str]]:
from huggingface_hub import repo_info
from huggingface_hub.utils import filter_repo_objects
Expand Down
Loading