diff --git a/applications/vision/alexnet.py b/applications/vision/alexnet.py index b67f2673714..fab9fbc36a6 100644 --- a/applications/vision/alexnet.py +++ b/applications/vision/alexnet.py @@ -22,6 +22,9 @@ parser.add_argument( '--num-classes', action='store', default=1000, type=int, help='number of ImageNet classes (default: 1000)', metavar='NUM') +parser.add_argument( + '--data-path', action='store', default=None, type=str, + help='Path to top-level imagenet directory. default: None') lbann.contrib.args.add_optimizer_arguments(parser) args = parser.parse_args() @@ -64,7 +67,8 @@ opt = lbann.contrib.args.create_optimizer(args) # Setup data reader -data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes) +data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, + data_path=args.data_path) # Setup trainer trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size) diff --git a/applications/vision/data/imagenet/__init__.py b/applications/vision/data/imagenet/__init__.py index 636f625daf6..c6f05da0883 100644 --- a/applications/vision/data/imagenet/__init__.py +++ b/applications/vision/data/imagenet/__init__.py @@ -5,7 +5,7 @@ import lbann import lbann.contrib.launcher -def make_data_reader(num_classes=1000, small_testing=False): +def make_data_reader(num_classes=1000, small_testing=False, data_path=None): # Load Protobuf message from file current_dir = os.path.dirname(os.path.realpath(__file__)) @@ -18,27 +18,36 @@ def make_data_reader(num_classes=1000, small_testing=False): google.protobuf.text_format.Merge(f.read(), message) message = message.data_reader - # Paths to ImageNet data - # Note: Paths are only known for some compute centers - compute_center = lbann.contrib.launcher.compute_center() - if compute_center == 'lc': - from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels - train_data_dir = imagenet_dir(data_set='train', - num_classes=num_classes) - train_label_file = imagenet_labels(data_set='train', - num_classes=num_classes) - test_data_dir = imagenet_dir(data_set='val', - num_classes=num_classes) - test_label_file = imagenet_labels(data_set='val', + + if data_path is not None: + print("Setting up data reader") + train_data_dir = os.path.join(data_path, 'train') + test_data_dir = os.path.join(data_path, 'val') + train_label_file = os.path.join(data_path, 'labels/train.txt') + test_label_file = os.path.join(data_path, 'labels/val.txt') + + elif lbann.contrib.launcher.compute_center() in ['lc', 'nersc']: + # Paths to ImageNet data + # Note: Paths are only known for some compute centers + compute_center = lbann.contrib.launcher.compute_center() + if compute_center == 'lc': + from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels + train_data_dir = imagenet_dir(data_set='train', num_classes=num_classes) - elif compute_center == 'nersc': - from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels - train_data_dir = imagenet_dir(data_set='train') - train_label_file = imagenet_labels(data_set='train') - test_data_dir = imagenet_dir(data_set='val') - test_label_file = imagenet_labels(data_set='val') + train_label_file = imagenet_labels(data_set='train', + num_classes=num_classes) + test_data_dir = imagenet_dir(data_set='val', + num_classes=num_classes) + test_label_file = imagenet_labels(data_set='val', + num_classes=num_classes) + elif compute_center == 'nersc': + from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels + train_data_dir = imagenet_dir(data_set='train') + train_label_file = imagenet_labels(data_set='train') + test_data_dir = imagenet_dir(data_set='val') + test_label_file = imagenet_labels(data_set='val') else: - raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center})') + raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center}). Set "--data-path" to the location of your dataset.') # Check that data paths are accessible if not os.path.isdir(train_data_dir): diff --git a/applications/vision/densenet.py b/applications/vision/densenet.py index 5e27bdfd517..aae8b25d52f 100644 --- a/applications/vision/densenet.py +++ b/applications/vision/densenet.py @@ -428,6 +428,8 @@ def get_args(): parser.add_argument("--print-matrix-summary", dest="print_matrix_summary", action="store_const", const=True, default=False) + parser.add_argument('--data-path', action='store', default=None, type=str, + help='Path to top-level imagenet directory. default: None') args = parser.parse_args() return args @@ -438,7 +440,7 @@ def set_up_experiment(args, labels): algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs) - + # Set up objective function cross_entropy = lbann.CrossEntropy([probs, labels]) layers = list(lbann.traverse_layer_graph(input_)) @@ -472,7 +474,9 @@ def set_up_experiment(args, callbacks=callbacks) # Set up data reader - data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, small_testing=True) + data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, + small_testing=True, + data_path=args.data_path) percentage = 0.001 * 2 * (args.mini_batch_size / 16) * 2 diff --git a/applications/vision/resnet.py b/applications/vision/resnet.py index 005204421fe..bb0a48fdbe7 100644 --- a/applications/vision/resnet.py +++ b/applications/vision/resnet.py @@ -50,6 +50,9 @@ parser.add_argument( '--random-seed', action='store', default=0, type=int, help='random seed for LBANN RNGs', metavar='NUM') +parser.add_argument( + '--data-path', action='store', default=None, type=str, + help='Path to top-level imagenet directory. default: None') lbann.contrib.args.add_optimizer_arguments(parser, default_learning_rate=0.1) args = parser.parse_args() @@ -145,7 +148,8 @@ opt = lbann.contrib.args.create_optimizer(args) # Setup data reader -data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes) +data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, + data_path=args.data_path) # Setup trainer trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size, random_seed=args.random_seed) diff --git a/docs/data_ingestion.rst b/docs/data_ingestion.rst index 1f29fbee33c..71b96438ae0 100644 --- a/docs/data_ingestion.rst +++ b/docs/data_ingestion.rst @@ -1,3 +1,8 @@ +.. role:: bash(code) + :language: bash +.. role:: python(code) + :language: python + Data Ingestion ============== @@ -27,6 +32,14 @@ Legacy Data Readers Some of the legacy data readers are the ``MNIST``, ``ImageNet``, and ``CIFAR10`` data readers. +.. note:: The imagenet data reader uses a path that may not be known + to all compute centers. If the dataset is not found + :python:`--data-path` may be set to the top level of the data + set in :code:`resnet.py`, :code:`alexnet.py`, and + :code:`densenet.py`. The data set is must contain + :code:`labels/train.txt`, :code:`labels/val.txt`, + :code:`train/`, and :code:`val/`. + "New" Data Readers -------------------