Skip to content

Commit

Permalink
(shortfin_sd) Adapt sf.Program usage for context forking changes
Browse files: browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 28, 2024
1 parent 6f667e0 commit 6f65919
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,18 +88,19 @@ def load_inference_parameters(
self.inference_parameters[component].append(p)

def start(self):
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]
self.inference_programs[component] = sf.Program(
modules=component_modules,
fiber=self.fibers[0],
trace_execution=False,
)
for fiber in self.fibers:
for component in self.inference_modules:
component_modules = [
sf.ProgramModule.parameter_provider(
self.sysman.ls, *self.inference_parameters.get(component, [])
),
*self.inference_modules[component],
]
self.inference_programs[component] = sf.Program(
modules=component_modules,
devices=fiber.raw_devices,
trace_execution=False,
)

# TODO: export vmfbs with multiple batch size entrypoints

Expand Down Expand Up @@ -390,7 +391,7 @@ async def _encode(self, device, requests):
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
)
await device
pe, te = await fn(*clip_inputs)
pe, te = await fn(*clip_inputs, fiber=self.fiber)

await device
for i in range(req_bs):
Expand Down Expand Up @@ -484,7 +485,7 @@ async def _denoise(self, device, requests):
"".join([f"\n 0: {latents_shape}"]),
)
(latents, time_ids, timesteps, sigmas) = await fns["init"](
denoise_inputs["sample"], num_steps
denoise_inputs["sample"], num_steps, fiber=self.fiber
)

await device
Expand All @@ -505,7 +506,9 @@ async def _denoise(self, device, requests):
),
)
await device
latent_model_input, t, sigma, next_sigma = await fns["scale"](*scale_inputs)
latent_model_input, t, sigma, next_sigma = await fns["scale"](
*scale_inputs, fiber=self.fiber
)
await device

unet_inputs = [
Expand All @@ -522,7 +525,7 @@ async def _denoise(self, device, requests):
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]),
)
await device
(noise_pred,) = await fns["unet"](*unet_inputs)
(noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
await device

step_inputs = [noise_pred, latents, sigma, next_sigma]
Expand All @@ -532,7 +535,7 @@ async def _denoise(self, device, requests):
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]),
)
await device
(latent_model_output,) = await fns["step"](*step_inputs)
(latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
latents.copy_from(latent_model_output)
await device

Expand Down Expand Up @@ -566,7 +569,7 @@ async def _decode(self, device, requests):

await device
# Decode the denoised latents.
(image,) = await fn(latents)
(image,) = await fn(latents, fiber=self.fiber)

await device
images_shape = [
Expand Down

0 comments on commit 6f65919

Please sign in to comment.