From 5014b1fd6ef560b8b37faf7002e771a79567294d Mon Sep 17 00:00:00 2001
From: Ean Garvey
Date: Tue, 29 Oct 2024 15:46:19 -0500
Subject: [PATCH] Adapt for program isolation and fiber flexibility

---
 .../shortfin_apps/sd/components/generate.py |   1 -
 .../shortfin_apps/sd/components/service.py  | 134 ++++++++++--------
 shortfin/python/shortfin_apps/sd/server.py  |  41 +++++-
 shortfin/tests/apps/sd/e2e_test.py          | 102 +++++++++++--
 4 files changed, 201 insertions(+), 77 deletions(-)

diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py
index ca4f9799d..07bb8e2eb 100644
--- a/shortfin/python/shortfin_apps/sd/components/generate.py
+++ b/shortfin/python/shortfin_apps/sd/components/generate.py
@@ -91,7 +91,6 @@ async def run(self):
             gen_process = GenerateImageProcess(self, self.gen_req, index)
             gen_processes.append(gen_process)
             gen_process.launch()
-        await asyncio.gather(*gen_processes)
 
         # TODO: stream image outputs

diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py
index d6cf71e48..26359a6e9 100644
--- a/shortfin/python/shortfin_apps/sd/components/service.py
+++ b/shortfin/python/shortfin_apps/sd/components/service.py
@@ -24,6 +24,12 @@
 
 logger = logging.getLogger(__name__)
 
+prog_isolations = {
+    "none": sf.ProgramIsolation.NONE,
+    "per_fiber": sf.ProgramIsolation.PER_FIBER,
+    "per_call": sf.ProgramIsolation.PER_CALL,
+}
+
 
 class GenerateService:
     """Top level service interface for image generation."""
@@ -39,6 +45,8 @@ def __init__(
         sysman: SystemManager,
         tokenizers: list[Tokenizer],
         model_params: ModelParams,
+        fibers_per_device: int,
+        prog_isolation: str = "per_fiber",
     ):
         self.name = name
@@ -50,17 +58,19 @@ def __init__(
         self.inference_modules: dict[str, sf.ProgramModule] = {}
         self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
         self.inference_programs: dict[str, sf.Program] = {}
-        self.procs_per_device = 1
+        self.trace_execution = False
+        self.fibers_per_device = fibers_per_device
+        self.prog_isolation = prog_isolations[prog_isolation]
         self.workers = []
         self.fibers = []
-        self.locks = []
+        self.fiber_status = []
         for idx, device in enumerate(self.sysman.ls.devices):
-            for i in range(self.procs_per_device):
+            for i in range(self.fibers_per_device):
                 worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
                 fiber = sysman.ls.create_fiber(worker, devices=[device])
                 self.workers.append(worker)
                 self.fibers.append(fiber)
-                self.locks.append(asyncio.Lock())
+                self.fiber_status.append(0)
 
         # Scope dependent objects.
         self.batcher = BatcherProcess(self)
@@ -99,7 +109,8 @@ def start(self):
             self.inference_programs[component] = sf.Program(
                 modules=component_modules,
                 devices=fiber.raw_devices,
-                trace_execution=False,
+                isolation=self.prog_isolation,
+                trace_execution=self.trace_execution,
             )
 
         # TODO: export vmfbs with multiple batch size entrypoints
@@ -169,6 +180,7 @@ def __init__(self, service: GenerateService):
         self.strobe_enabled = True
         self.strobes: int = 0
         self.ideal_batch_size: int = max(service.model_params.max_batch_size)
+        self.num_fibers = len(service.fibers)
 
     def shutdown(self):
         self.batcher_infeed.close()
@@ -199,6 +211,7 @@ async def run(self):
                 logger.error("Illegal message received by batcher: %r", item)
 
             self.board_flights()
+            self.strobe_enabled = True
 
         await strober_task
@@ -210,28 +223,40 @@ def board_flights(self):
             logger.info("Waiting a bit longer to fill flight")
             return
         self.strobes = 0
+        batches = self.sort_batches()
+        for idx, batch in batches.items():
+            for fidx, status in enumerate(self.service.fiber_status):
+                if (
+                    status == 0
+                    or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
+                ):
+                    self.board(batch["reqs"], index=fidx)
+                    break
-        batches = self.sort_pending()
-        for idx in batches.keys():
-            self.board(batches[idx]["reqs"], index=idx)
 
-    def sort_pending(self):
-        """Returns pending requests as sorted batches suitable for program invocations."""
+    def sort_batches(self):
+        """Sorts pending requests into batches suitable for program invocations."""
+        reqs = self.pending_requests
+        next_key = 0
         batches = {}
-        for req in self.pending_requests:
+        for req in reqs:
             is_sorted = False
             req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()]
-            next_key = 0
+
             for idx_key, data in batches.items():
                 if not isinstance(data, dict):
                     logger.error(
                         "Expected to find a dictionary containing a list of requests and their shared metadata."
                     )
-                if data["meta"] == req_metas:
-                    batches[idx_key]["reqs"].append(req)
+                if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size:
+                    # Batch is full
+                    next_key = idx_key + 1
+                    continue
+                elif data["meta"] == req_metas:
+                    batches[idx_key]["reqs"].extend([req])
                     is_sorted = True
                     break
-                next_key = idx_key + 1
+                else:
+                    next_key = idx_key + 1
             if not is_sorted:
                 batches[next_key] = {
                     "reqs": [req],
@@ -251,7 +276,8 @@ def board(self, request_bundle, index):
         if exec_process.exec_requests:
             for flighted_request in exec_process.exec_requests:
                 self.pending_requests.remove(flighted_request)
-            print(f"launching exec process for {exec_process.exec_requests}")
+            if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
+                self.service.fiber_status[index] = 1
             exec_process.launch()
@@ -284,22 +310,22 @@ async def run(self):
             phases = self.exec_requests[0].phases
             req_count = len(self.exec_requests)
-            async with self.service.locks[self.worker_index]:
-                device0 = self.fiber.device(0)
-                if phases[InferencePhase.PREPARE]["required"]:
-                    await self._prepare(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.ENCODE]["required"]:
-                    await self._encode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DENOISE]["required"]:
-                    await self._denoise(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DECODE]["required"]:
-                    await self._decode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.POSTPROCESS]["required"]:
-                    await self._postprocess(device=device0, requests=self.exec_requests)
+            device0 = self.service.fibers[self.worker_index].device(0)
+            if phases[InferencePhase.PREPARE]["required"]:
+                await self._prepare(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.ENCODE]["required"]:
+                await self._encode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DENOISE]["required"]:
+                await self._denoise(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DECODE]["required"]:
+                await self._decode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.POSTPROCESS]["required"]:
+                await self._postprocess(device=device0, requests=self.exec_requests)
             for i in range(req_count):
                 req = self.exec_requests[i]
                 req.done.set_success()
+            self.service.fiber_status[self.worker_index] = 0
 
         except Exception:
             logger.exception("Fatal error in image generation")
@@ -345,7 +371,6 @@ async def _prepare(self, device, requests):
             sfnp.fill_randn(sample_host, generator=generator)
 
             request.sample.copy_from(sample_host)
-        await device
         return
 
     async def _encode(self, device, requests):
@@ -385,15 +410,13 @@
             clip_inputs[idx].copy_from(host_arrs[idx])
 
         # Encode tokenized inputs.
-        logger.info(
+        logger.debug(
            "INVOKE %r: %s",
            fn,
            "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
        )
-        await device
         pe, te = await fn(*clip_inputs, fiber=self.fiber)
-        await device
 
         for i in range(req_bs):
             cfg_mult = 2
             requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
@@ -477,20 +500,23 @@ async def _denoise(self, device, requests):
         ns_host.items = [step_count]
         num_steps.copy_from(ns_host)
-        await device
+        init_inputs = [
+            denoise_inputs["sample"],
+            num_steps,
+        ]
+
         # Initialize scheduler.
-        logger.info(
-            "INVOKE %r: %s",
+        logger.debug(
+            "INVOKE %r",
             fns["init"],
-            "".join([f"\n 0: {latents_shape}"]),
         )
         (latents, time_ids, timesteps, sigmas) = await fns["init"](
-            denoise_inputs["sample"], num_steps, fiber=self.fiber
+            *init_inputs, fiber=self.fiber
         )
-
-        await device
         for i, t in tqdm(
             enumerate(range(step_count)),
+            disable=False,
+            desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
         ):
             step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
             s_host = step.for_transfer()
@@ -498,14 +524,10 @@ async def _denoise(self, device, requests):
                 s_host.items = [i]
             step.copy_from(s_host)
             scale_inputs = [latents, step, timesteps, sigmas]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["scale"],
-                "".join(
-                    [f"\n {i}: {ary.shape}" for i, ary in enumerate(scale_inputs)]
-                ),
             )
-            await device
             latent_model_input, t, sigma, next_sigma = await fns["scale"](
                 *scale_inputs, fiber=self.fiber
             )
@@ -519,32 +541,25 @@
                 time_ids,
                 denoise_inputs["guidance_scale"],
             ]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["unet"],
-                "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]),
             )
-            await device
             (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
-            await device
 
             step_inputs = [noise_pred, latents, sigma, next_sigma]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["step"],
-                "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]),
             )
-            await device
             (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
             latents.copy_from(latent_model_output)
-            await device
 
         for idx, req in enumerate(requests):
             req.denoised_latents = sfnp.device_array.for_device(
                 device, latents_shape, self.service.model_params.vae_dtype
             )
             req.denoised_latents.copy_from(latents.view(idx))
-        await device
         return
 
     async def _decode(self, device, requests):
@@ -569,6 +584,11 @@
         await device
 
         # Decode the denoised latents.
+        logger.debug(
+            "INVOKE %r: %s",
+            fn,
+            "".join([f"\n 0: {latents.shape}"]),
+        )
         (image,) = await fn(latents, fiber=self.fiber)
         await device

diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py
index 5e7abd1fc..04da9bd36 100644
--- a/shortfin/python/shortfin_apps/sd/server.py
+++ b/shortfin/python/shortfin_apps/sd/server.py
@@ -31,7 +31,9 @@
 from .components.tokenizer import Tokenizer
 
-logger = logging.getLogger(__name__)
+from shortfin.support.logging_setup import configure_main_logger
+
+logger = configure_main_logger("server")
 
 
 @asynccontextmanager
@@ -87,7 +89,12 @@ def configure(args) -> SystemManager:
     model_params = ModelParams.load_json(args.model_config)
 
     sm = GenerateService(
-        name="sd", sysman=sysman, tokenizers=tokenizers, model_params=model_params
+        name="sd",
+        sysman=sysman,
+        tokenizers=tokenizers,
+        model_params=model_params,
+        fibers_per_device=args.fibers_per_device,
+        prog_isolation=args.isolation,
     )
     sm.load_inference_module(args.clip_vmfb, component="clip")
     sm.load_inference_module(args.unet_vmfb, component="unet")
@@ -188,10 +195,35 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         nargs="*",
         help="Parameter archives to load",
     )
+    parser.add_argument(
+        "--fibers_per_device",
+        type=int,
+        default=1,
+        help="Concurrency control -- how many fibers are created per device to run inference.",
+    )
+    parser.add_argument(
+        "--isolation",
+        type=str,
+        default="per_fiber",
+        choices=["per_fiber", "per_call", "none"],
+        help="Concurrency control -- how to isolate programs.",
+    )
+    parser.add_argument(
+        "--log_level", type=str, default="error", choices=["info", "debug", "error"]
+    )
+    log_levels = {
+        "info": logging.INFO,
+        "debug": logging.DEBUG,
+        "error": logging.ERROR,
+    }
+
     args = parser.parse_args(argv)
+
+    log_level = log_levels[args.log_level]
+    logger.setLevel(log_level)
+
     global sysman
     sysman = configure(args)
-
     uvicorn.run(
         app,
         host=args.host,
@@ -202,9 +234,6 @@
 
 
 if __name__ == "__main__":
-    from shortfin.support.logging_setup import configure_main_logger
-
-    logger = configure_main_logger("server")
     main(
         sys.argv[1:],
         # Make logging defer to the default shortfin logging config.
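
Taken together, the new flags work as follows: --fibers_per_device=N makes GenerateService.__init__ create N worker/fiber pairs per device, and --isolation selects the sf.ProgramIsolation mode applied to every sf.Program built in start(). A minimal sketch of the equivalent programmatic wiring, mirroring configure() in server.py (it assumes sysman, tokenizers, and model_params are already built as in that function; the two values shown are illustrative, not defaults):

    # Illustrative sketch, mirroring configure(args) in server.py above.
    # Assumes sysman, tokenizers, and model_params exist as in that function.
    sm = GenerateService(
        name="sd",
        sysman=sysman,
        tokenizers=tokenizers,
        model_params=model_params,
        fibers_per_device=2,        # two worker/fiber pairs per device
        prog_isolation="per_call",  # via prog_isolations -> sf.ProgramIsolation.PER_CALL
    )

From the command line, the same setup would be (module path assumed from the file layout above): python -m shortfin_apps.sd.server --fibers_per_device=2 --isolation=per_call, plus the usual artifact and config flags.
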
diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py
index d76b39417..8bba60f34 100644
--- a/shortfin/tests/apps/sd/e2e_test.py
+++ b/shortfin/tests/apps/sd/e2e_test.py
@@ -7,6 +7,7 @@
 import os
 import socket
 import sys
+import copy
 
 from contextlib import closing
 from datetime import datetime as dt
@@ -45,12 +46,7 @@ def sd_artifacts(target: str = "gfx942"):
 cache = os.path.abspath("./tmp/sharktank/sd/")
 
 
-@pytest.fixture(scope="module")
-def sd_server():
-    # Create necessary directories
-
-    os.makedirs(cache, exist_ok=True)
-
+def start_server(fibers_per_device=1, isolation="per_fiber"):
     # Download model if it doesn't exist
     vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/"
     weights_bucket = (
@@ -82,9 +78,67 @@ def sd_server():
     for arg in sd_artifacts().keys():
         artifact_arg = f"--{arg}={cache}/{sd_artifacts()[arg]}"
         srv_args.extend([artifact_arg])
+    srv_args.extend(
+        [
+            f"--fibers_per_device={fibers_per_device}",
+            f"--isolation={isolation}",
+        ]
+    )
     runner = ServerRunner(srv_args)
     # Wait for server to start
-    time.sleep(5)
+    time.sleep(3)
+    return runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd1():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=1)
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd1_per_call():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=1, isolation="per_call")
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd2():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=2)
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd8():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=8)
 
     yield runner
 
     # Teardown: kill the server
     del runner
 
 
@@ -93,15 +147,36 @@
 @pytest.mark.system("amdgpu")
-def test_sd_server(sd_server):
-    imgs, status_code = send_json_file(sd_server.url)
+def test_sd_server(sd_server_fpd1):
+    imgs, status_code = send_json_file(sd_server_fpd1.url)
     assert len(imgs) == 1
     assert status_code == 200
 
 
 @pytest.mark.system("amdgpu")
-def test_sd_server_bs8_dense(sd_server):
-    imgs, status_code = send_json_file(sd_server.url, num_copies=8)
+def test_sd_server_bs4_dense(sd_server_fpd1):
+    imgs, status_code = send_json_file(sd_server_fpd1.url, num_copies=4)
+    assert len(imgs) == 4
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs8_percall(sd_server_fpd1_per_call):
+    imgs, status_code = send_json_file(sd_server_fpd1_per_call.url, num_copies=8)
+    assert len(imgs) == 8
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs4_dense_fpd2(sd_server_fpd2):
+    imgs, status_code = send_json_file(sd_server_fpd2.url, num_copies=4)
+    assert len(imgs) == 4
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs8_dense_fpd8(sd_server_fpd8):
+    imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=8)
     assert len(imgs) == 8
     assert status_code == 200
@@ -112,7 +187,6 @@ def __init__(self, args):
         self.url = "http://0.0.0.0:" + port
         env = os.environ.copy()
         env["PYTHONUNBUFFERED"] = "1"
-        env["HIP_VISIBLE_DEVICES"] = "0"
         self.process = subprocess.Popen(
             [
                 *args,
@@ -156,12 +230,14 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024):
     image = Image.frombytes(
         mode="RGB", size=(width, height), data=base64.b64decode(bytes)
     )
+    if os.environ.get("SF_SAVE_TEST_IMAGES") == "1":
+        image.save(f"shortfin_test_output_{timestamp}.png", format="PNG")
     return image
 
 
 def send_json_file(url="http://0.0.0.0:8000", num_copies=1):
     # Read the JSON file
-    data = sample_request
+    data = copy.deepcopy(sample_request)
     imgs = []
     # Send the data to the /generate endpoint
     data["prompt"] = (