Skip to content

Commit

Permalink
[Torch FX] [SD3] Minor Fixes (#2615)
Browse files Browse the repository at this point in the history
1. Reduced the generated image size to speed up calibration data collection.
2. Minor fixes to the notebook for better UX.
  • Loading branch information
anzr299 authored Jan 6, 2025
1 parent 257bee3 commit 776076c
Showing 1 changed file with 39 additions and 32 deletions.
71 changes: 39 additions & 32 deletions notebooks/stable-diffusion-v3/stable-diffusion-v3-torch-fx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -129,7 +129,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -179,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -209,7 +209,11 @@
"slideshow": {
"slide_type": ""
},
"tags": []
"tags": [],
"test_replace": {
"height=512,": "",
"width=512": ""
}
},
"outputs": [],
"source": [
Expand All @@ -221,18 +225,9 @@
"num_inference_steps = 28\n",
"with torch.no_grad():\n",
" image = pipe(\n",
" prompt=prompt,\n",
" negative_prompt=\"\",\n",
" num_inference_steps=num_inference_steps,\n",
" generator=generator,\n",
" guidance_scale=5,\n",
" prompt=prompt, negative_prompt=\"\", num_inference_steps=num_inference_steps, generator=generator, guidance_scale=5, height=512, width=512\n",
" ).images[0]\n",
"image.resize(\n",
" (\n",
" 512,\n",
" 512,\n",
" )\n",
")"
"image"
]
},
{
Expand Down Expand Up @@ -276,10 +271,10 @@
},
"tags": [],
"test_replace": {
"torch.ones((1, 16, 128, 128))": "torch.ones((1, 16, 32, 32))",
"torch.ones((1, 3, 128, 128))": "torch.ones((1, 3, 32, 32))",
"torch.ones((1, 16, 64, 64))": "torch.ones((1, 16, 32, 32))",
"torch.ones((1, 3, 64, 64))": "torch.ones((1, 3, 32, 32))",
"torch.ones((2, 154, 4096))": "torch.ones((2, 154, 32))",
"torch.ones((2, 16, 128, 128))": "torch.ones((2, 16, 32, 32))",
"torch.ones((2, 16, 64, 64))": "torch.ones((2, 16, 32, 32))",
"torch.ones((2, 2048))": "torch.ones((2, 64))"
}
},
Expand All @@ -292,11 +287,11 @@
"text_encoder_kwargs = {}\n",
"text_encoder_kwargs[\"output_hidden_states\"] = True\n",
"\n",
"vae_encoder_input = torch.ones((1, 3, 128, 128))\n",
"vae_decoder_input = torch.ones((1, 16, 128, 128))\n",
"vae_encoder_input = torch.ones((1, 3, 64, 64))\n",
"vae_decoder_input = torch.ones((1, 16, 64, 64))\n",
"\n",
"unet_kwargs = {}\n",
"unet_kwargs[\"hidden_states\"] = torch.ones((2, 16, 128, 128))\n",
"unet_kwargs[\"hidden_states\"] = torch.ones((2, 16, 64, 64))\n",
"unet_kwargs[\"timestep\"] = torch.from_numpy(np.array([1, 2], dtype=np.float32))\n",
"unet_kwargs[\"encoder_hidden_states\"] = torch.ones((2, 154, 4096))\n",
"unet_kwargs[\"pooled_projections\"] = torch.ones((2, 2048))\n",
Expand Down Expand Up @@ -370,7 +365,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -404,8 +399,9 @@
},
"tags": [],
"test_replace": {
"calibration_dataset_size = 300": "calibration_dataset_size = 1",
"init_pipeline(models_dict, configs_dict)": "init_pipeline(models_dict, configs_dict, \"katuni4ka/tiny-random-sd3\")"
"calibration_dataset_size = 200": "calibration_dataset_size = 1",
"init_pipeline(models_dict, configs_dict)": "init_pipeline(models_dict, configs_dict, \"katuni4ka/tiny-random-sd3\")",
"pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512)": "pipe(prompt, num_inference_steps=num_inference_steps)"
}
},
"outputs": [],
Expand Down Expand Up @@ -464,7 +460,7 @@
" if len(prompt) > pipe.tokenizer.model_max_length:\n",
" continue\n",
" # Run the pipeline\n",
" pipe(prompt, num_inference_steps=num_inference_steps)\n",
" pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512)\n",
" calibration_data.extend(wrapped_unet.captured_args)\n",
" wrapped_unet.captured_args = []\n",
" pbar.update(len(calibration_data) - pbar.n)\n",
Expand All @@ -478,7 +474,7 @@
"\n",
"if to_quantize:\n",
" pipe = init_pipeline(models_dict, configs_dict)\n",
" calibration_dataset_size = 300\n",
" calibration_dataset_size = 200\n",
" unet_calibration_data = collect_calibration_data(\n",
" pipe, calibration_dataset_size=calibration_dataset_size, num_inference_steps=28\n",
" )\n",
Expand Down Expand Up @@ -659,15 +655,19 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"test_replace": {
"prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator, height=512, width=512": "prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator"
}
},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
"\n",
"# Warmup the model for initial compile\n",
"with torch.no_grad():\n",
" image = opt_pipe(\n",
" prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator\n",
" opt_pipe(\n",
" prompt=prompt, negative_prompt=\"\", num_inference_steps=1, generator=generator, height=512, width=512\n",
" ).images[0]"
]
},
Expand All @@ -683,7 +683,12 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"test_replace": {
"height=512,": "",
"width=512": ""
}
},
"outputs": [],
"source": [
"%%skip not $to_quantize.value\n",
Expand All @@ -697,6 +702,8 @@
" num_inference_steps=28,\n",
" guidance_scale=5,\n",
" generator=generator,\n",
" height=512,\n",
" width=512\n",
").images[0]\n",
"\n",
"visualize_results(image, opt_image)"
Expand Down Expand Up @@ -758,7 +765,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -772,7 +779,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
"version": "3.10.12"
},
"openvino_notebooks": {
"imageUrl": "https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/ac99098c-66ec-4b7b-9e01-e80625f1dc3f",
Expand Down

0 comments on commit 776076c

Please sign in to comment.