diff --git a/notebooks/lora_decoder.ipynb b/notebooks/lora_decoder.ipynb index 5e79969..56cecaa 100644 --- a/notebooks/lora_decoder.ipynb +++ b/notebooks/lora_decoder.ipynb @@ -3127,7 +3127,7 @@ "id": "Q-Sa97ezzBWo", "outputId": "63a2f6ef-6399-4bc6-8be3-9aecda18a652" }, - "execution_count": 1, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -3261,7 +3261,7 @@ "metadata": { "id": "h196ZQ66kOah" }, - "execution_count": 2, + "execution_count": null, "outputs": [] }, { @@ -3276,7 +3276,7 @@ "metadata": { "id": "qMBAU2yOkttM" }, - "execution_count": 3, + "execution_count": null, "outputs": [] }, { @@ -3292,7 +3292,7 @@ "id": "RJw0fSdF1U0j", "outputId": "1874d2ff-7694-4ada-9ae9-5254150fe368" }, - "execution_count": 4, + "execution_count": null, "outputs": [ { "output_type": "execute_result", @@ -3583,7 +3583,7 @@ "metadata": { "id": "_o3t6oXKy8-s" }, - "execution_count": 5, + "execution_count": null, "outputs": [] }, { @@ -3602,7 +3602,7 @@ "metadata": { "id": "sLjwl1RtyzHO" }, - "execution_count": 6, + "execution_count": null, "outputs": [] }, { @@ -3636,7 +3636,7 @@ "metadata": { "id": "HMPOrc1Y47Y7" }, - "execution_count": 7, + "execution_count": null, "outputs": [] }, { @@ -3651,7 +3651,7 @@ "id": "jk_8ykSi56Bh", "outputId": "62d490a7-476c-4f21-d2eb-0b94d1a683d5" }, - "execution_count": 8, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -3694,7 +3694,7 @@ "metadata": { "id": "Ca-K7XRMO_wC" }, - "execution_count": 9, + "execution_count": null, "outputs": [] }, { @@ -3709,7 +3709,7 @@ "id": "VGaiksGgPFGP", "outputId": "f62ceac1-27af-4df1-fcf5-5d4fbe28c268" }, - "execution_count": 10, + "execution_count": null, "outputs": [ { "output_type": "stream", @@ -3748,7 +3748,7 @@ "metadata": { "id": "Ig17twUvQ8cr" }, - "execution_count": 11, + "execution_count": null, "outputs": [] }, { @@ -3837,7 +3837,7 @@ "id": "Qap8MnWtO9FS", "outputId": "89a2ff09-3170-404d-c01d-a3b4b4877e37" }, - "execution_count": 12, + "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -3941,7 +3941,20 @@ }, { "cell_type": "code", - "source": [], + "source": [ + "from safetensors import safe_open\n", + "\n", + "# Loading the weights from the `safetensors` format to a plain dictionary.\n", + "decoder_lora_state_dict = {}\n", + "with safe_open(\"/content/decoder_lora_saves/checkpoint-500/model.safetensors\", framework=\"pt\", device=\"cpu\") as f:\n", + " for key in f.keys():\n", + " decoder_lora_state_dict[key] = f.get_tensor(key)\n", + "\n", + "decoder_prior_state_dict = {}\n", + "with safe_open(\"/content/decoder_prior_saves/checkpoint-500/model.safetensors\", framework=\"pt\", device=\"cpu\") as f:\n", + " for key in f.keys():\n", + " decoder_prior_state_dict[key] = f.get_tensor(key)" + ], "metadata": { "id": "xEZNDvvintn_" }, @@ -3953,7 +3966,6 @@ "source": [ "from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnAddedKVProcessor\n", "lora_attn_procs = {}\n", - "d = torch.load('/content/decoder_lora_saves/checkpoint-500/pytorch_model.bin')\n", "for name in decoder.unet.attn_processors.keys():\n", " cross_attention_dim = None if name.endswith(\"attn1.processor\") else decoder.unet.config.cross_attention_dim\n", " if name.startswith(\"mid_block\"):\n", @@ -3971,13 +3983,13 @@ " ).to('cuda')\n", "\n", "decoder.unet.set_attn_processor(lora_attn_procs)\n", - "decoder.unet.load_state_dict(d, strict=False)\n", + "decoder.unet.load_state_dict(decoder_lora_state_dict, strict=False)\n", "None" ], "metadata": { "id": "cQaJGtebasLy" }, - "execution_count": 16, + "execution_count": null, "outputs": [] }, { @@ -3988,13 +4000,13 @@ "for name in prior.prior.attn_processors.keys():\n", " lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048).to('cuda')\n", "prior.prior.set_attn_processor(lora_attn_procs)\n", - "prior.prior.load_state_dict(torch.load('/content/decoder_prior_saves/checkpoint-500/pytorch_model.bin'), strict=False)\n", + "prior.prior.load_state_dict(decoder_prior_state_dict, strict=False)\n", "None" ], "metadata": { "id": "5lfV9j-9atJt" }, - "execution_count": 17, + "execution_count": null, "outputs": [] }, { @@ -4049,7 +4061,7 @@ "id": "gqnaHUCZnq3Z", "outputId": "fb69eb33-feb1-440d-91e6-a194d0f9d974" }, - "execution_count": 20, + "execution_count": null, "outputs": [ { "output_type": "display_data", @@ -4115,7 +4127,7 @@ "id": "AMuPm5bma-Fq", "outputId": "5a6b959a-f742-44bc-b31a-056f9b6d9c7c" }, - "execution_count": 21, + "execution_count": null, "outputs": [ { "output_type": "execute_result",