Skip to content

Commit

Permalink
add unit tests for AWSBatchRunner
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Oct 19, 2023
1 parent 3c298c9 commit ca70b22
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
Empty file added tests/aws_batch/__init__.py
Empty file.
49 changes: 49 additions & 0 deletions tests/aws_batch/test_aws_batch_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os
import unittest

from rastervision.pipeline import rv_config_ as rv_config
from rastervision.aws_batch.aws_batch_runner import AWSBatchRunner


class MockPipeline:
commands = ['test_cpu', 'test_gpu']
split_commands = ['test_cpu']
gpu_commands = ['test_gpu']


class TestAWSBatchRunner(unittest.TestCase):
def test_build_cmd(self):
pipeline = MockPipeline()
runner = AWSBatchRunner()
rv_config.set_verbosity(4)
cmd, args = runner.build_cmd(
'config.json',
pipeline, ['predict'],
num_splits=2,
pipeline_run_name='test')
cmd_expected = [
'python', '-m', 'rastervision.pipeline.cli', '-vvv', 'run_command',
'config.json', 'predict', '--runner', 'batch'
]
args_expected = {
'parent_job_ids': [],
'num_array_jobs': None,
'use_gpu': False,
'job_queue': None,
'job_def': None
}
self.assertListEqual(cmd, cmd_expected)
self.assertTrue(args['job_name'].startswith('test'))
del args['job_name']
self.assertDictEqual(args, args_expected)

def test_get_split_ind(self):
runner = AWSBatchRunner()
os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = '1'
self.assertEqual(runner.get_split_ind(), 1)
del os.environ['AWS_BATCH_JOB_ARRAY_INDEX']
self.assertEqual(runner.get_split_ind(), 0)


if __name__ == '__main__':
unittest.main()

0 comments on commit ca70b22

Please sign in to comment.