From 6f659199d5f9f1dd24dcf899add63468da15f2f0 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Sat, 26 Oct 2024 11:39:39 -0500 Subject: [PATCH] (shortfin_sd) Adapt sf.Program usage for context forking changes --- .../shortfin_apps/sd/components/service.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 7c6af7092..d6cf71e48 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -88,18 +88,19 @@ def load_inference_parameters( self.inference_parameters[component].append(p) def start(self): - for component in self.inference_modules: - component_modules = [ - sf.ProgramModule.parameter_provider( - self.sysman.ls, *self.inference_parameters.get(component, []) - ), - *self.inference_modules[component], - ] - self.inference_programs[component] = sf.Program( - modules=component_modules, - fiber=self.fibers[0], - trace_execution=False, - ) + for fiber in self.fibers: + for component in self.inference_modules: + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + self.inference_programs[component] = sf.Program( + modules=component_modules, + devices=fiber.raw_devices, + trace_execution=False, + ) # TODO: export vmfbs with multiple batch size entrypoints @@ -390,7 +391,7 @@ async def _encode(self, device, requests): "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) await device - pe, te = await fn(*clip_inputs) + pe, te = await fn(*clip_inputs, fiber=self.fiber) await device for i in range(req_bs): @@ -484,7 +485,7 @@ async def _denoise(self, device, requests): "".join([f"\n 0: {latents_shape}"]), ) (latents, time_ids, timesteps, sigmas) = await fns["init"]( - denoise_inputs["sample"], num_steps + denoise_inputs["sample"], num_steps, fiber=self.fiber ) await device @@ -505,7 +506,9 @@ async def _denoise(self, device, requests): ), ) await device - latent_model_input, t, sigma, next_sigma = await fns["scale"](*scale_inputs) + latent_model_input, t, sigma, next_sigma = await fns["scale"]( + *scale_inputs, fiber=self.fiber + ) await device unet_inputs = [ @@ -522,7 +525,7 @@ async def _denoise(self, device, requests): "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]), ) await device - (noise_pred,) = await fns["unet"](*unet_inputs) + (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber) await device step_inputs = [noise_pred, latents, sigma, next_sigma] @@ -532,7 +535,7 @@ async def _denoise(self, device, requests): "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]), ) await device - (latent_model_output,) = await fns["step"](*step_inputs) + (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber) latents.copy_from(latent_model_output) await device @@ -566,7 +569,7 @@ async def _decode(self, device, requests): await device # Decode the denoised latents. - (image,) = await fn(latents) + (image,) = await fn(latents, fiber=self.fiber) await device images_shape = [