Skip to content

Commit

Permalink
Adapt for program isolation and fiber flexibility
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 29, 2024
1 parent 264075b commit 5014b1f
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 77 deletions.
1 change: 0 additions & 1 deletion shortfin/python/shortfin_apps/sd/components/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ async def run(self):
gen_process = GenerateImageProcess(self, self.gen_req, index)
gen_processes.append(gen_process)
gen_process.launch()

await asyncio.gather(*gen_processes)

# TODO: stream image outputs
Expand Down
134 changes: 77 additions & 57 deletions shortfin/python/shortfin_apps/sd/components/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@

logger = logging.getLogger(__name__)

# Lookup table from the string names accepted by GenerateService's
# `prog_isolation` argument to the matching `sf.ProgramIsolation` members.
# Each key is simply the lower-cased name of the enum member it selects.
prog_isolations = {
    mode: getattr(sf.ProgramIsolation, mode.upper())
    for mode in ("none", "per_fiber", "per_call")
}


class GenerateService:
"""Top level service interface for image generation."""
Expand All @@ -39,6 +45,8 @@ def __init__(
sysman: SystemManager,
tokenizers: list[Tokenizer],
model_params: ModelParams,
fibers_per_device: int,
prog_isolation: str = "per_fiber",
):
self.name = name

Expand All @@ -50,17 +58,19 @@ def __init__(
self.inference_modules: dict[str, sf.ProgramModule] = {}
self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
self.inference_programs: dict[str, sf.Program] = {}
self.procs_per_device = 1
self.trace_execution = False
self.fibers_per_device = fibers_per_device
self.prog_isolation = prog_isolations[prog_isolation]
self.workers = []
self.fibers = []
self.locks = []
self.fiber_status = []
for idx, device in enumerate(self.sysman.ls.devices):
for i in range(self.procs_per_device):
for i in range(self.fibers_per_device):
worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
fiber = sysman.ls.create_fiber(worker, devices=[device])
self.workers.append(worker)
self.fibers.append(fiber)
self.locks.append(asyncio.Lock())
self.fiber_status.append(0)

# Scope dependent objects.
self.batcher = BatcherProcess(self)
Expand Down Expand Up @@ -99,7 +109,8 @@ def start(self):
self.inference_programs[component] = sf.Program(
modules=component_modules,
devices=fiber.raw_devices,
trace_execution=False,
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)

# TODO: export vmfbs with multiple batch size entrypoints
Expand Down Expand Up @@ -169,6 +180,7 @@ def __init__(self, service: GenerateService):
self.strobe_enabled = True
self.strobes: int = 0
self.ideal_batch_size: int = max(service.model_params.max_batch_size)
self.num_fibers = len(service.fibers)

def shutdown(self):
self.batcher_infeed.close()
Expand Down Expand Up @@ -199,6 +211,7 @@ async def run(self):
logger.error("Illegal message received by batcher: %r", item)

self.board_flights()

self.strobe_enabled = True
await strober_task

Expand All @@ -210,28 +223,40 @@ def board_flights(self):
logger.info("Waiting a bit longer to fill flight")
return
self.strobes = 0
batches = self.sort_batches()
for idx, batch in batches.items():
for fidx, status in enumerate(self.service.fiber_status):
if (
status == 0
or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
):
self.board(batch["reqs"], index=fidx)
break

batches = self.sort_pending()
for idx in batches.keys():
self.board(batches[idx]["reqs"], index=idx)

def sort_pending(self):
"""Returns pending requests as sorted batches suitable for program invocations."""
def sort_batches(self):
"""Files pending requests into sorted batches suitable for program invocations."""
reqs = self.pending_requests
next_key = 0
batches = {}
for req in self.pending_requests:
for req in reqs:
is_sorted = False
req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()]
next_key = 0

for idx_key, data in batches.items():
if not isinstance(data, dict):
logger.error(
"Expected to find a dictionary containing a list of requests and their shared metadatas."
)
if data["meta"] == req_metas:
batches[idx_key]["reqs"].append(req)
if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size:
# Batch is full
next_key = idx_key + 1
continue
elif data["meta"] == req_metas:
batches[idx_key]["reqs"].extend([req])
is_sorted = True
break
next_key = idx_key + 1
else:
next_key = idx_key + 1
if not is_sorted:
batches[next_key] = {
"reqs": [req],
Expand All @@ -251,7 +276,8 @@ def board(self, request_bundle, index):
if exec_process.exec_requests:
for flighted_request in exec_process.exec_requests:
self.pending_requests.remove(flighted_request)
print(f"launching exec process for {exec_process.exec_requests}")
if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
self.service.fiber_status[index] = 1
exec_process.launch()


Expand Down Expand Up @@ -284,22 +310,22 @@ async def run(self):
phases = self.exec_requests[0].phases

req_count = len(self.exec_requests)
async with self.service.locks[self.worker_index]:
device0 = self.fiber.device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
await self._encode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DENOISE]["required"]:
await self._denoise(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DECODE]["required"]:
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)
device0 = self.service.fibers[self.worker_index].device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
await self._encode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DENOISE]["required"]:
await self._denoise(device=device0, requests=self.exec_requests)
if phases[InferencePhase.DECODE]["required"]:
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)

for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
self.service.fiber_status[self.worker_index] = 0

except Exception:
logger.exception("Fatal error in image generation")
Expand Down Expand Up @@ -345,7 +371,6 @@ async def _prepare(self, device, requests):
sfnp.fill_randn(sample_host, generator=generator)

request.sample.copy_from(sample_host)
await device
return

async def _encode(self, device, requests):
Expand Down Expand Up @@ -385,15 +410,13 @@ async def _encode(self, device, requests):
clip_inputs[idx].copy_from(host_arrs[idx])

# Encode tokenized inputs.
logger.info(
logger.debug(
"INVOKE %r: %s",
fn,
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
)
await device
pe, te = await fn(*clip_inputs, fiber=self.fiber)

await device
for i in range(req_bs):
cfg_mult = 2
requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
Expand Down Expand Up @@ -477,35 +500,34 @@ async def _denoise(self, device, requests):
ns_host.items = [step_count]
num_steps.copy_from(ns_host)

await device
init_inputs = [
denoise_inputs["sample"],
num_steps,
]

# Initialize scheduler.
logger.info(
"INVOKE %r: %s",
logger.debug(
"INVOKE %r",
fns["init"],
"".join([f"\n 0: {latents_shape}"]),
)
(latents, time_ids, timesteps, sigmas) = await fns["init"](
denoise_inputs["sample"], num_steps, fiber=self.fiber
*init_inputs, fiber=self.fiber
)

await device
for i, t in tqdm(
enumerate(range(step_count)),
disable=False,
desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
):
step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
s_host = step.for_transfer()
with s_host.map(write=True) as m:
s_host.items = [i]
step.copy_from(s_host)
scale_inputs = [latents, step, timesteps, sigmas]
logger.info(
"INVOKE %r: %s",
logger.debug(
"INVOKE %r",
fns["scale"],
"".join(
[f"\n {i}: {ary.shape}" for i, ary in enumerate(scale_inputs)]
),
)
await device
latent_model_input, t, sigma, next_sigma = await fns["scale"](
*scale_inputs, fiber=self.fiber
)
Expand All @@ -519,32 +541,25 @@ async def _denoise(self, device, requests):
time_ids,
denoise_inputs["guidance_scale"],
]
logger.info(
"INVOKE %r: %s",
logger.debug(
"INVOKE %r",
fns["unet"],
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]),
)
await device
(noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
await device

step_inputs = [noise_pred, latents, sigma, next_sigma]
logger.info(
"INVOKE %r: %s",
logger.debug(
"INVOKE %r",
fns["step"],
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]),
)
await device
(latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
latents.copy_from(latent_model_output)
await device

for idx, req in enumerate(requests):
req.denoised_latents = sfnp.device_array.for_device(
device, latents_shape, self.service.model_params.vae_dtype
)
req.denoised_latents.copy_from(latents.view(idx))
await device
return

async def _decode(self, device, requests):
Expand All @@ -569,6 +584,11 @@ async def _decode(self, device, requests):

await device
# Decode the denoised latents.
logger.debug(
"INVOKE %r: %s",
fn,
"".join([f"\n 0: {latents.shape}"]),
)
(image,) = await fn(latents, fiber=self.fiber)

await device
Expand Down
41 changes: 35 additions & 6 deletions shortfin/python/shortfin_apps/sd/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from .components.tokenizer import Tokenizer


logger = logging.getLogger(__name__)
from shortfin.support.logging_setup import configure_main_logger

logger = configure_main_logger("server")


@asynccontextmanager
Expand Down Expand Up @@ -87,7 +89,12 @@ def configure(args) -> SystemManager:

model_params = ModelParams.load_json(args.model_config)
sm = GenerateService(
name="sd", sysman=sysman, tokenizers=tokenizers, model_params=model_params
name="sd",
sysman=sysman,
tokenizers=tokenizers,
model_params=model_params,
fibers_per_device=args.fibers_per_device,
prog_isolation=args.isolation,
)
sm.load_inference_module(args.clip_vmfb, component="clip")
sm.load_inference_module(args.unet_vmfb, component="unet")
Expand Down Expand Up @@ -188,10 +195,35 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
nargs="*",
help="Parameter archives to load",
)
parser.add_argument(
"--fibers_per_device",
type=int,
default=1,
help="Concurrency control -- how many fibers are created per device to run inference.",
)
parser.add_argument(
"--isolation",
type=str,
default="per_fiber",
choices=["per_fiber", "per_call", "none"],
help="Concurrency control -- How to isolate programs.",
)
parser.add_argument(
"--log_level", type=str, default="error", choices=["info", "debug", "error"]
)
log_levels = {
"info": logging.INFO,
"debug": logging.DEBUG,
"error": logging.ERROR,
}

args = parser.parse_args(argv)

log_level = log_levels[args.log_level]
logger.setLevel(log_level)

global sysman
sysman = configure(args)

uvicorn.run(
app,
host=args.host,
Expand All @@ -202,9 +234,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):


if __name__ == "__main__":
from shortfin.support.logging_setup import configure_main_logger

logger = configure_main_logger("server")
main(
sys.argv[1:],
# Make logging defer to the default shortfin logging config.
Expand Down
Loading

0 comments on commit 5014b1f

Please sign in to comment.