Cherry-pick refit error fix (#3170) from main to the release/2.5 branch (#3236)
Co-authored-by: cehongwang <[email protected]>
Co-authored-by: Evan Li <[email protected]>
3 people authored Oct 15, 2024
1 parent cd12eb4 commit e314ad6
Showing 4 changed files with 54 additions and 21 deletions.
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
@@ -70,6 +70,7 @@
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refittable=True,
+ reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
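Reviewer note: the new flag matters because a previously cached engine may have been built without make_refittable, and reusing it would make the later refit fail. A minimal sketch of the intended workflow, assuming the torch_tensorrt 2.5 dynamo API (argument names such as arg_inputs follow the refit example in this repo; exact signatures may differ):

    import torch
    import torch_tensorrt
    import torchvision.models as models

    model = models.resnet18(pretrained=False).eval().to("cuda")
    inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

    # Build a refittable engine and skip the engine cache, so a stale
    # (non-refittable) cached engine can never be silently reused.
    trt_gm = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        make_refittable=True,
        reuse_cached_engines=False,
    )

    # Later, swap in new weights without recompiling.
    model2 = models.resnet18(pretrained=True).eval().to("cuda")
    exp_program2 = torch.export.export(model2, tuple(inputs))
    new_trt_gm = torch_tensorrt.dynamo.refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=exp_program2,
        arg_inputs=inputs,
    )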
12 changes: 9 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -476,12 +476,18 @@ def _save_weight_mapping(self) -> None:
# Retrieve each weight name(s) in state_dict
if layer_type == "CONSTANT":
if "embedding" in suffix:
- sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+ sd_weight_name = f"{sd_weight_name}.weight"
elif "weight" in suffix or "mm_other" in suffix:
# Linear layer weight
- sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+ sd_weight_name = f"{sd_weight_name}.weight"
+ elif "running_mean" in suffix:
+ # Batch norm running mean
+ sd_weight_name = f"{sd_weight_name}.running_mean"
+ elif "running_var" in suffix:
+ # Batch norm running variance
+ sd_weight_name = f"{sd_weight_name}.running_var"
else:
- sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
+ sd_weight_name = f"{sd_weight_name}.bias"
elif layer_type == "SCALE":
# Batch norm needs all weights to calculate scale and shift
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
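Reviewer note: the change hard-codes the state_dict attribute names instead of indexing into torch_attr, and adds branches for the batch-norm running statistics. A standalone re-implementation of just that suffix-to-key logic, for illustration only (the function name is hypothetical):

    def map_constant_to_sd_key(sd_weight_name: str, suffix: str) -> str:
        # Mirrors the CONSTANT branch of _save_weight_mapping above
        if "embedding" in suffix:
            return f"{sd_weight_name}.weight"
        if "weight" in suffix or "mm_other" in suffix:
            return f"{sd_weight_name}.weight"        # linear layer weight
        if "running_mean" in suffix:
            return f"{sd_weight_name}.running_mean"  # batch norm statistic
        if "running_var" in suffix:
            return f"{sd_weight_name}.running_var"   # batch norm statistic
        return f"{sd_weight_name}.bias"

    assert map_constant_to_sd_key("layer1.0.bn1", "running_mean") == "layer1.0.bn1.running_mean"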
29 changes: 21 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -50,14 +50,27 @@ def batch_norm(
# Save the original output shape for later use
output_shape = input.shape

- if weight is None:
-     weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
- if bias is None:
-     bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
- if running_mean is None:
-     running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
- if running_var is None:
-     running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+ # We name the weights here according to their state_dict names
+ weight = (
+     get_trt_tensor(ctx, 1.0, f"{name}_weight")
+     if weight is None
+     else get_trt_tensor(ctx, weight, f"{name}_weight")
+ )
+ bias = (
+     get_trt_tensor(ctx, 0.0, f"{name}_bias")
+     if bias is None
+     else get_trt_tensor(ctx, bias, f"{name}_bias")
+ )
+ running_mean = (
+     get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+     if running_mean is None
+     else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+ )
+ running_var = (
+     get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+     if running_var is None
+     else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+ )

# eps_tensor for numerical stability
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
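Reviewer note: downstream, the SCALE path folds these four tensors into a single scale and shift (see the "Batch norm needs all weights to calculate scale and shift" comment in _TRTInterpreter above). In plain PyTorch the fold looks like the sketch below; this is for intuition only, not the TensorRT code path:

    import torch

    def folded_batch_norm(x, weight, bias, running_mean, running_var, eps=1e-5):
        # Inference-mode batch norm, y = (x - mean) / sqrt(var + eps) * weight + bias,
        # collapsed into a single multiply-add
        scale = weight / torch.sqrt(running_var + eps)
        shift = bias - running_mean * scale
        # Broadcast over the channel dimension of an NCHW tensor
        return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)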
33 changes: 23 additions & 10 deletions tests/py/dynamo/models/test_model_refit.py
@@ -35,8 +35,8 @@
@pytest.mark.unit
def test_mapping():

- model = models.resnet18(pretrained=True).eval().to("cuda")
- model2 = models.resnet18(pretrained=False).eval().to("cuda")
+ model = models.resnet18(pretrained=False).eval().to("cuda")
+ model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
trt_input = [
torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -58,6 +58,7 @@ def test_mapping():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)
settings = trt_gm._run_on_acc_0.settings
runtime = trt.Runtime(TRT_LOGGER)
@@ -110,6 +111,7 @@ def test_refit_one_engine_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -141,8 +143,8 @@ def test_refit_one_engine_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_no_map_with_weightmap():

- model = models.resnet18(pretrained=True).eval().to("cuda")
- model2 = models.resnet18(pretrained=False).eval().to("cuda")
+ model = models.resnet18(pretrained=False).eval().to("cuda")
+ model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -160,6 +162,7 @@ def test_refit_one_engine_no_map_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

trt_gm._run_on_acc_0.weight_name_map = None
@@ -192,8 +195,8 @@ def test_refit_one_engine_no_map_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_with_wrong_weightmap():

- model = models.resnet18(pretrained=True).eval().to("cuda")
- model2 = models.resnet18(pretrained=False).eval().to("cuda")
+ model = models.resnet18(pretrained=False).eval().to("cuda")
+ model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -211,6 +214,7 @@ def test_refit_one_engine_with_wrong_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)
# Manually delete all batch norm layers. This is expected to make the fast refit fail.
trt_gm._run_on_acc_0.weight_name_map = {
@@ -268,6 +272,7 @@ def test_refit_one_engine_bert_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -302,8 +307,8 @@ def test_refit_one_engine_bert_with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_inline_runtime__with_weightmap():
trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
- model = models.resnet18(pretrained=True).eval().to("cuda")
- model2 = models.resnet18(pretrained=False).eval().to("cuda")
+ model = models.resnet18(pretrained=False).eval().to("cuda")
+ model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -321,6 +326,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
@@ -348,8 +354,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
@pytest.mark.unit
def test_refit_one_engine_python_runtime_with_weightmap():

- model = models.resnet18(pretrained=True).eval().to("cuda")
- model2 = models.resnet18(pretrained=False).eval().to("cuda")
+ model = models.resnet18(pretrained=False).eval().to("cuda")
+ model2 = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
enabled_precisions = {torch.float}
debug = False
@@ -367,6 +373,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -438,6 +445,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refittable=True,
torch_executed_ops=torch_executed_ops,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -487,6 +495,7 @@ def test_refit_one_engine_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -538,6 +547,7 @@ def test_refit_one_engine_bert_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -591,6 +601,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)
torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
trt_gm = torch.export.load(trt_ep_path)
@@ -637,6 +648,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
@@ -708,6 +720,7 @@ def forward(self, x):
min_block_size=min_block_size,
make_refittable=True,
torch_executed_ops=torch_executed_ops,
+ reuse_cached_engines=False,
)

new_trt_gm = refit_module_weights(
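Reviewer note: the truncated test bodies all end the same way, running the refitted engine and comparing it against the PyTorch module. Roughly, and with illustrative names (the real tests use a shared cosine-similarity helper and threshold):

    expected = model2(*inputs)
    refitted = new_trt_gm(*inputs)  # may be a tuple, depending on the runtime
    if isinstance(refitted, (list, tuple)):
        refitted = refitted[0]
    cos_sim = torch.nn.functional.cosine_similarity(
        expected.flatten(), refitted.flatten(), dim=0
    )
    assert cos_sim > 0.99, "refitted engine diverged from the PyTorch model"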
