Skip to content

Commit

Permalink
fix: pylint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
abdullah-ibm committed Nov 7, 2024
1 parent 1feb135 commit da624f4
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,28 @@
from accelerate import Accelerator

try:
    # Third Party
    from deepspeed.ops.adam import DeepSpeedCPUAdam
except ImportError:
    DeepSpeedCPUAdam = None
    # BUG FIX: the previous code did `int(os.getenv("LOCAL_RANK", "None"))`,
    # which raises ValueError whenever LOCAL_RANK is unset (int("None") is
    # not parseable). Read the raw value and only convert when it exists.
    local_rank = os.getenv("LOCAL_RANK")
    # Announce the missing optimizer only from the launch process: either no
    # distributed rank is set at all, or this is rank 0.
    if __name__ == "__main__" and (local_rank is None or int(local_rank) == 0):
        print(
            "DeepSpeed CPU Optimizer is not available. Some features may be unavailable."
        )

try:
    # Third Party
    from deepspeed.ops.adam import FusedAdam
    from deepspeed.runtime.zero.utils import ZeRORuntimeException
except ImportError:
    FusedAdam = None
    ZeRORuntimeException = None
    # BUG FIX: `int(os.getenv("LOCAL_RANK", "None"))` raises ValueError when
    # LOCAL_RANK is unset. Read the raw value and convert only when present.
    local_rank = os.getenv("LOCAL_RANK")
    # Print only from the launch process (no rank set, or rank 0) so the
    # warning is not duplicated once per worker.
    if __name__ == "__main__" and (local_rank is None or int(local_rank) == 0):
        print("DeepSpeed is not available. Some features may be unavailable.")

# pylint: disable=no-name-in-module
from instructlab.training.confg import DistributedBackend
# Third Party
from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM
from tqdm import tqdm
from transformers import AutoModelForCausalLM, get_scheduler
Expand All @@ -43,6 +46,8 @@
# First Party
from instructlab.training import config
from instructlab.training.async_logger import AsyncStructuredLogger

# pylint: disable=no-name-in-module
from instructlab.training.config import (
DataProcessArgs,
DistributedBackend,
Expand Down Expand Up @@ -533,11 +538,19 @@ def main(args):
# Third Party
import yaml

if args.distributed_training_framework == 'deepspeed' and not FusedAdam:
raise ImportError("DeepSpeed was selected but we cannot import the `FusedAdam` optimizer")
if args.distributed_training_framework == "deepspeed" and not FusedAdam:
raise ImportError(
"DeepSpeed was selected but we cannot import the `FusedAdam` optimizer"
)

if args.distributed_training_framework == 'deepspeed' and args.cpu_offload_optimizer and not DeepSpeedCPUAdam:
raise ImportError("DeepSpeed was selected and CPU offloading was requested, but DeepSpeedCPUAdam could not be imported. This likely means you need to build DeepSpeed with the CPU adam flags.")
if (
args.distributed_training_framework == "deepspeed"
and args.cpu_offload_optimizer
and not DeepSpeedCPUAdam
):
raise ImportError(
"DeepSpeed was selected and CPU offloading was requested, but DeepSpeedCPUAdam could not be imported. This likely means you need to build DeepSpeed with the CPU adam flags."
)

metric_logger = AsyncStructuredLogger(
args.output_dir
Expand Down Expand Up @@ -761,11 +774,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
# deepspeed options
if train_args.distributed_backend == DistributedBackend.DeepSpeed:
if not FusedAdam:
raise ImportError("DeepSpeed was selected as the distributed backend, but FusedAdam could not be imported. Please double-check that DeepSpeed is installed correctly")
raise ImportError(
"DeepSpeed was selected as the distributed backend, but FusedAdam could not be imported. Please double-check that DeepSpeed is installed correctly"
)

if train_args.deepspeed_options.cpu_offload_optimizer and not DeepSpeedCPUAdam:
raise ImportError("DeepSpeed CPU offloading was enabled, but DeepSpeedCPUAdam could not be imported. This is most likely because DeepSpeed was not built with CPU Adam. Please rebuild DeepSpeed to have CPU Adam, or disable CPU offloading.")

raise ImportError(
"DeepSpeed CPU offloading was enabled, but DeepSpeedCPUAdam could not be imported. This is most likely because DeepSpeed was not built with CPU Adam. Please rebuild DeepSpeed to have CPU Adam, or disable CPU offloading."
)
if train_args.deepspeed_options.save_samples:
command.append(f"--save_samples_ds={train_args.deepspeed_options.save_samples}")
if train_args.deepspeed_options.cpu_offload_optimizer:
Expand Down

0 comments on commit da624f4

Please sign in to comment.