
Commit f2317ce: smol vocab size

v-chen_data committed Nov 30, 2024
1 parent 0736c58
Showing 2 changed files with 19 additions and 24 deletions.
36 changes: 15 additions & 21 deletions tests/algorithms/test_algorithm_resumption.py
@@ -20,8 +20,8 @@
 @pytest.mark.parametrize('alg_cls', get_algs_with_marks())
 @pytest.mark.filterwarnings(
     'ignore:Detected call of `lr_scheduler.step()',
-)  # optimizer.step() sometimes skipped when NaN/inf on low batch size
-@pytest.mark.filterwarnings(r'ignore:.*Plan failed with a cudnnException.*:UserWarning')  # Torch 2.3 regression
+)
+@pytest.mark.filterwarnings(r'ignore:.*Plan failed with a cudnnException.*:UserWarning')
 @world_size(1, 2)
 def test_algorithm_resumption(
     tmp_path: pathlib.Path,
@@ -54,14 +54,14 @@ def test_algorithm_resumption(
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

     shared_config = {
-        'max_duration': '2ep',
-        'save_filename': 'ep{epoch}-rank{rank}',
-        'save_interval': '1ep',
-        'train_subset_num_batches': 2,
+        'max_duration': '2ba',
+        'save_filename': 'ba{batch}-rank{rank}',
+        'save_interval': '1ba',
+        'train_subset_num_batches': 1,
         'precision': 'amp_bf16',
     }
     train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
-    # train model once, saving checkpoints every epoch
+    # train model once, saving checkpoints every batch
     trainer1 = Trainer(
         model=model,
         train_dataloader=train_dataloader,
@@ -73,24 +73,19 @@
     )
     trainer1.fit()

-    # create second trainer, load an intermediate checkpoint
-    # and continue training
-
+    # create second trainer, load from the first batch checkpoint, and continue training
    optimizer = torch.optim.Adam(copied_model.parameters(), lr=0.01)
     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

     alg = alg_cls(**alg_kwargs)
-    # SeqLengthWarmup has a call to ._activate_model() that happens on the first call to the algorithm
-    # in order to get complete matching of the rng state, we have to cause that extra call to be skipped
-    # when reloading.
     if alg_cls is SeqLengthWarmup:
         alg._activated = True  # type: ignore

     train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True)
     trainer2 = Trainer(
         model=copied_model,
         train_dataloader=train_dataloader,
-        load_path=os.path.join(folder1, 'ep1-rank{rank}'),
+        load_path=os.path.join(folder1, 'ba1-rank{rank}'),
         load_weights_only=False,
         load_strict_model_weights=False,
         optimizers=optimizer,
@@ -100,20 +95,19 @@
         **shared_config,
     )
     trainer2.fit()
-    # check that the checkpoints are equal
+    # check that the checkpoints after the second batch are equal
     if world_size == 1 or dist.get_global_rank() == 0:
         _assert_checkpoints_equal(
-            file1=os.path.join(folder1, 'ep2-rank0'),
-            file2=os.path.join(folder2, 'ep2-rank0'),
+            file1=os.path.join(folder1, 'ba2-rank0'),
+            file2=os.path.join(folder2, 'ba2-rank0'),
         )

-    # check that different epoch checkpoints are _not_ equal
-    # this ensures that the model weights are being updated.
+    # ensure that the first and second batch checkpoints are not equal
     if world_size == 1 or dist.get_global_rank() == 0:
         with pytest.raises(AssertionError):
             _assert_model_weights_equal(
-                file1=os.path.join(folder1, 'ep1-rank0'),
-                file2=os.path.join(folder1, 'ep2-rank0'),
+                file1=os.path.join(folder1, 'ba1-rank0'),
+                file2=os.path.join(folder1, 'ba2-rank0'),
             )

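For readers unfamiliar with the pattern, the save/resume round trip now operates at batch granularity rather than epoch granularity. The sketch below is a minimal, self-contained version of that round trip, assuming Composer's public Trainer API; the toy model, dataloader, and folder names are illustrative stand-ins for the test's fixtures, not code from this repository.

```python
# Minimal sketch of the batch-granularity save/resume round trip the test
# exercises. Assumes the mosaicml 'composer' package; the toy model,
# dataloader, and folder names are illustrative stand-ins.
import copy

import torch
from torch.utils.data import DataLoader, TensorDataset

from composer import Trainer
from composer.models import ComposerClassifier

model = ComposerClassifier(
    torch.nn.Sequential(torch.nn.Linear(8, 2)),
    num_classes=2,
)
copied_model = copy.deepcopy(model)  # untrained copy for the resumed run
dataloader = DataLoader(
    TensorDataset(torch.randn(16, 8), torch.randint(0, 2, (16,))),
    batch_size=4,
)

# First run: two batches total, checkpointing after every batch ('1ba').
trainer1 = Trainer(
    model=model,
    train_dataloader=dataloader,
    max_duration='2ba',
    save_folder='folder1',
    save_filename='ba{batch}-rank{rank}',
    save_interval='1ba',
)
trainer1.fit()

# Second run: resume from the first-batch checkpoint and train one more
# batch. With seeding and determinism handled as in the test, its 'ba2'
# checkpoint matches the first run's.
trainer2 = Trainer(
    model=copied_model,
    train_dataloader=dataloader,
    max_duration='2ba',
    save_folder='folder2',
    save_filename='ba{batch}-rank{rank}',
    save_interval='1ba',
    load_path='folder1/ba1-rank0',
)
trainer2.fit()
```

Because everything is expressed in batches ('2ba', '1ba') rather than epochs, each run performs two optimizer steps instead of two full passes over the dataloader, which is what makes the new configuration cheaper.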
7 changes: 4 additions & 3 deletions tests/test_full_nlp.py
@@ -260,11 +260,12 @@ def test_full_nlp_pipeline(
     algorithms = [algorithm() for algorithm in algorithms]
     device = get_device(device)
     config = None
+    small_vocab_size = 1024
     tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', model_max_length=128)
     if model_type == 'tinybert_hf':
         # Updated minimal BERT configuration
         config = BertConfig(
-            vocab_size=30522,
+            vocab_size=small_vocab_size,
             hidden_size=16,
             num_hidden_layers=2,
             num_attention_heads=2,
@@ -280,7 +281,7 @@ def test_full_nlp_pipeline(
             metrics=pretraining_metrics,
         )
     elif model_type == 'simpletransformer':
-        pretraining_model = SimpleTransformerMaskedLM(vocab_size=30522)
+        pretraining_model = SimpleTransformerMaskedLM(vocab_size=small_vocab_size)
     else:
         raise ValueError('Unsupported model type')
     pretraining_output_path = pretraining_test_helper(
@@ -302,7 +303,7 @@ def test_full_nlp_pipeline(
         )
     elif model_type == 'simpletransformer':
         finetuning_model = SimpleTransformerClassifier(
-            vocab_size=30522,
+            vocab_size=small_vocab_size,
             num_classes=3,
         )
     else:
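The vocabulary change is a size optimization: in a BERT-style model the token-embedding table has vocab_size x hidden_size entries, so at these tiny hidden dimensions it dominates the parameter count. Shrinking the vocabulary from 30522 (the full bert-base-uncased vocabulary) to 1024 removes most of those parameters. Below is a rough check, assuming the transformers library; config fields not shown in the diff (e.g. intermediate_size) are left at their defaults here, so the printed totals are only indicative.

```python
# Rough parameter-count comparison for the tiny test config at the old and
# new vocabulary sizes. Assumes the 'transformers' package; fields not shown
# in the diff are left at their defaults.
from transformers import BertConfig, BertForMaskedLM

def n_params(vocab_size: int) -> int:
    config = BertConfig(
        vocab_size=vocab_size,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=2,
    )
    return sum(p.numel() for p in BertForMaskedLM(config).parameters())

# The word-embedding table alone is vocab_size * hidden_size weights:
# 30522 * 16 = 488,352 versus 1024 * 16 = 16,384.
print(n_params(30522), n_params(1024))
```

One general caveat: a model with vocab_size=1024 can only consume token ids below 1024, which is worth keeping in mind when pairing it with a tokenizer whose vocabulary is larger.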
