Skip to content

Commit

Permalink
Pull request itlab-vision#153: [AI-9579] Fixes for pytorch cpp
Browse files Browse the repository at this point in the history
Merge in AT/dl-benchmark from malibekov/AI-9579/pytorch_cpp_fix to develop
  • Loading branch information
Murad Alibekov authored and AlibekovMurad5202 committed Nov 21, 2023
1 parent 4fd74e4 commit 87f7421
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ void PytorchLauncher::read(const std::string& model_file, const std::string& wei
}

torch::Device torch_device(device_type, 0);
module = torch::jit::load(model_file, torch_device);
try {
module = torch::jit::load(model_file, torch_device);
} catch (const c10::Error& e) {
throw std::runtime_error("Failed to read model " + model_file);
}
module.eval();
}

Expand Down
6 changes: 2 additions & 4 deletions src/inference/inference_pytorch_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def create_dict_from_args_for_process(args):

def prepare_images_for_benchmark(inputs, tmp_dir):
if os.path.isdir(inputs[0]):
return inputs
return inputs[0]
for path in inputs[0].split(','):
shutil.copy2(path, tmp_dir)
return tmp_dir
Expand Down Expand Up @@ -241,9 +241,7 @@ def main():
io.prepare_input(compiled_model, args.input)

log.info('Preparing images for benchmark in temporary directory')
prepare_images_for_benchmark(args.input, tmp_input.name)

args.input = tmp_input.name
args.input = prepare_images_for_benchmark(args.input, tmp_input.name)

log.info('Initializing PyTorch process')
proc = PyTorchProcess()
Expand Down

0 comments on commit 87f7421

Please sign in to comment.