diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 3eaa388..16c1582 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1127,9 +1127,8 @@ def _test_exception_decrypts(self): # even in the event of an exception encryption = EncryptionParams.random() with torch.device("cpu"): - model = AutoModelForCausalLM.from_pretrained( - model_name, device_map="cpu" - ) + model = AutoModelForCausalLM.from_pretrained(model_name) + self.assertEqual(model.device.type, "cpu") model_sd = model.state_dict() model_clone = {k: v.detach().clone() for k, v in model_sd.items()}