feat: support for mcore optimizer (to enable MoE) #380
base: dev
Changes from 3 commits
1c98a08
6214591
4f60e87
7598f24
@@ -101,31 +101,52 @@ def prepare_for_training_step(ptl_model, zero_grad=True):
         param.data_ptr()


+# TODO: Delete this once API introduced in NeMo (https://github.com/NVIDIA/NeMo/pull/10803)
+# TODO: Update PR to move this logic into staticmethod in nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
 def grad_reductions(ptl_model):
     # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
     if ptl_model.cfg.get("tensor_model_parallel_size", 1) > 1 and ptl_model.cfg.get("sequence_parallel", False):
-        ptl_model.allreduce_sequence_parallel_gradients()
-
-    if ptl_model.with_distributed_adam:
-        # synchronize asynchronous grad reductions
-        # note: not necessary, but reduces performance degradation
-        # from multiple simultaneous NCCL calls
-        ptl_model._optimizer._finish_bucket_grad_sync()
+        # Mcore DistOpt handles this, so we don't have to
+        if not ptl_model.use_mcore_dist_optim:
Review comment: how do you feel about dropping support for non mcore dist optim? are they equivalent to apex now?

Reply: yea, i want to do that in a follow up PR (would help our build times immensely). This just adds the feature without breaking apex.
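For readers following the `use_mcore_dist_optim` branches in this diff: a minimal sketch of where that flag typically comes from, assuming NeMo's convention of selecting the optimizer by name in the model config (the actual derivation lives in NeMo, not in this PR, and the config values below are illustrative):

```python
from omegaconf import OmegaConf

# Hypothetical model config selecting the Megatron-core distributed optimizer by name.
cfg = OmegaConf.create({"optim": {"name": "mcore_distributed_optim", "lr": 1e-4}})

# Paraphrase of how the flag is commonly derived in NeMo; every branch in this diff
# then keys off this single boolean.
use_mcore_dist_optim = cfg.optim.get("name") == "mcore_distributed_optim"
assert use_mcore_dist_optim
```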
+            ptl_model.megatron_timer_start("allreduce_sequence_parallel_gradients", log_level=1)
+            ptl_model.allreduce_sequence_parallel_gradients()
+            ptl_model.megatron_timer_stop("allreduce_sequence_parallel_gradients")
+
+    ptl_model.megatron_timer_start("gradient_allreduce", log_level=1)
+    if ptl_model.use_fsdp:
+        # Reduce the gradients omitted from FSDP-sharding
+        ptl_model.allreduce_fsdp_sharding_omitted_gradients()
+    elif ptl_model.with_distributed_adam:
+        if not ptl_model.use_mcore_dist_optim:
+            # synchronize asynchronous grad reductions
+            # note: not necessary, but reduces performance degradation
+            # from multiple simultaneous NCCL calls
+            ptl_model._optimizer._finish_bucket_grad_sync()
+        # else: Mcore distributed optim calls finalize_model_grads to finish grad sync
     elif ptl_model.megatron_amp_O2:
         # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
-        if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 or ptl_model.cfg.get("sequence_parallel", False):
+        if (
+            ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1
+            or ptl_model.cfg.get("sequence_parallel", False)
+            or not ptl_model.cfg.get("async_grad_allreduce", True)
+        ):
             # main grads are stored in the MainParamsOptimizer wrapper
             ptl_model._optimizer.allreduce_main_grads()
     else:
         # async grad allreduce is not currently implemented for O1/autocasting mixed precision training
         # so we all-reduce gradients after the pipeline
         ptl_model.allreduce_gradients()  # @sangkug we think this is causing memory to blow up (hurts perf)
+    ptl_model.megatron_timer_stop("gradient_allreduce")
 
-    if ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1 and ptl_model.cfg.get(
-        "share_embeddings_and_output_weights", True
+    if (
+        not ptl_model.use_mcore_dist_optim
+        and ptl_model.cfg.get("pipeline_model_parallel_size", 1) > 1
+        and ptl_model.cfg.get("share_embeddings_and_output_weights", True)
     ):
+        ptl_model.megatron_timer_start("allreduce_first_last_embeddings", log_level=1)
         # when using pipeline parallelism the first and last stage must keep embeddings in sync
         ptl_model.allreduce_first_last_embeddings()
+        ptl_model.megatron_timer_stop("allreduce_first_last_embeddings")


 def prepare_for_validation_step(ptl_model):
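The `# else:` comment in the hunk above says the Mcore distributed optimizer relies on `finalize_model_grads` to finish grad sync instead of `_finish_bucket_grad_sync()`. A minimal sketch of how that hook is wired in Megatron-core terms, assuming the Megatron-LM convention of assigning it to the model-parallel config consumed by the forward-backward function; the exact NeMo wiring may differ:

```python
# Sketch only: assumes megatron.core is installed and that the model's
# ModelParallelConfig exposes finalize_model_grads_func, as in Megatron-LM training.
from megatron.core.distributed import finalize_model_grads


def enable_mcore_grad_finalization(model_parallel_config):
    # finalize_model_grads completes the pending DDP bucket reductions (plus the
    # embedding / layernorm-grad all-reduces) before the optimizer step, which is
    # why grad_reductions() above can skip the manual sync on this path.
    model_parallel_config.finalize_model_grads_func = finalize_model_grads
    return model_parallel_config
```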
@@ -155,14 +176,26 @@ def set_eval(ptl_model):
     ptl_model.eval()


 # TODO: adapt the version in /opt/NeMo/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
 def clip_gradients(ptl_model, clip_val):
     """PTL hook to configure gradients.
     We use gradient clipping implementation from megatron-lm.
     """
     if clip_val is None:
         return

     clip_val = float(clip_val)
     if clip_val <= 0:
         return

+    if ptl_model.with_megatron_fused_adam or ptl_model.use_mcore_dist_optim:
+        # Gradient clipping is done in optimizer step
+        return
+
     if ptl_model.grad_clip_pl_default:
         # use the default behavior
         return super().configure_gradient_clipping(*args, **kwargs)

     if ptl_model.with_distributed_adam:
         grad_norm = clip_grad_norm_distributed_optimizer(ptl_model._optimizer, clip_val)
     else:

@@ -171,6 +204,5 @@ def clip_gradients(ptl_model, clip_val):
         parameters = ptl_model._optimizer.get_parameters_with_grad()
     else:
         parameters = ptl_model.get_parameters_with_grad()
-    grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val)
-
+    grad_norm = clip_grad_norm_fp32(parameters=parameters, max_norm=clip_val, use_fsdp=ptl_model.use_fsdp,)
     return grad_norm
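For context on what the `clip_grad_norm_fp32` call above is doing: a rough, single-rank sketch of clip-by-global-norm in plain PyTorch. The real megatron-lm helper additionally reduces the norm across model-parallel ranks (and, with the new `use_fsdp` flag, across FSDP-sharded ranks), so this shows only the local concept, not the implementation in the diff:

```python
import torch


def clip_grad_norm_local(parameters, max_norm, eps=1.0e-6):
    # collect the grads that exist and compute the global L2 norm over all of them
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
    # scale every grad down when the global norm exceeds max_norm
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm
```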
Review comment: do we have to specify the closure now? i thought this was optional?

Reply: for mcore dist opt it's required, so i just set it everywhere.
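To illustrate the closure contract being discussed, here is a sketch of the standard `torch.optim.Optimizer.step(closure)` convention, with made-up model and batch names rather than the NeMo-Aligner call site: passing the closure unconditionally is harmless for optimizers that simply evaluate it, and required for wrappers that need it to re-run forward and backward.

```python
import torch


def step_with_closure(model, optimizer, batch):
    # The closure re-evaluates the loss and backpropagates; torch.optim optimizers
    # call it inside step() when it is provided, so passing it everywhere is safe.
    def closure():
        optimizer.zero_grad()
        loss = model(batch).sum()  # placeholder loss for the sketch
        loss.backward()
        return loss

    loss = optimizer.step(closure)
    return loss
```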