diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index a5c53b222..3fdd1c1ed 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -10,7 +10,18 @@ """ import logging import os -from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload +from typing import ( + Dict, + List, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) import einops import numpy as np @@ -67,6 +78,8 @@ "bf16": torch.bfloat16, } +T = TypeVar("T", bound="HookedTransformer") + class Output(NamedTuple): """Output Named Tuple. @@ -1053,7 +1066,7 @@ def move_model_modules_to_device(self): @classmethod def from_pretrained( - cls, + cls: Type[T], model_name: str, fold_ln: bool = True, center_writing_weights: bool = True, @@ -1072,7 +1085,7 @@ def from_pretrained( dtype="float32", first_n_layers: Optional[int] = None, **from_pretrained_kwargs, - ) -> "HookedTransformer": + ) -> T: """Load in a Pretrained Model. Load in pretrained model weights to the HookedTransformer format and optionally to do some diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index 421d35e15..d68bc561f 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -59,7 +59,7 @@ def download_file_from_hf( ) if file_path.endswith(".pth") or force_is_torch: - return torch.load(file_path, map_location="cpu") + return torch.load(file_path, map_location="cpu", weights_only=False) elif file_path.endswith(".json"): return json.load(open(file_path, "r")) else: