Skip to content

Commit

Permalink
added test for bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
SumanthRH committed Jan 30, 2024
1 parent 68d8931 commit 0b27bf9
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions tests/lm_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ecco.lm import LM, _one_hot, sample_output_token, activations_dict_to_array
import ecco
import torch
import numpy as np
from transformers import PreTrainedModel
import torch
from transformers import PreTrainedModel

import ecco
from ecco.lm import _one_hot
from ecco.lm import activations_dict_to_array
from ecco.lm import sample_output_token


class TestLM:
Expand Down Expand Up @@ -58,6 +61,10 @@ def test_call_dummy_bert(self):
# If we do require padding, CUDA complains with this model for some reason.
assert output.activations['encoder'].shape == (2, 1, 40, 3)

def test_half_prec(self):
    """Verify that extra model kwargs (torch_dtype here) reach the wrapped HF model."""
    # from_pretrained should forward unknown kwargs straight to the underlying model loader.
    wrapper = ecco.from_pretrained('sshleifer/tiny-gpt2', activations=True, torch_dtype=torch.bfloat16)
    observed_dtype = wrapper.model.dtype
    assert observed_dtype == torch.bfloat16, f"Model dtype should be Bfloat16, got {observed_dtype}"

# TODO: Test LM Generate with Activation. Tweak to support batch dimension.
# def test_generate_token_no_attribution(self, mocker):
Expand Down

0 comments on commit 0b27bf9

Please sign in to comment.