From 187c45e03b44c92b37caff3f2d1ba615538f0e59 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 28 Oct 2024 15:26:19 -0700 Subject: [PATCH] [shortfin] Implement fiber local program forking. (#332) API changes: * sf.Program(..., fiber=) no longer takes a fiber. The fiber is now moved to the invocation (i.e. function(..., fiber=)). * sf.Program now requires an explicit devices= (as opposed to implicitly inferring this from the fiber). The prior behavior can be achieved by sf.Program(..., devices=fiber.raw_devices). Semantic changes: * Programs are now created with an isolation level (NONE, PER_FIBER, PER_CALL), defaulting to PER_FIBER. The prior behavior was like NONE. * In PER_FIBER mode, only one in-flight call to a program can happen at a time on a single fiber. When a program is used across fibers, each fiber will get its own context fork (which is a light-weight clone of the context that shares immutable data). * In PER_CALL mode, backing contexts are pooled on each calling fiber and each call gets its own context. This supports multiple concurrent invocations of a program at the expense of each call having its own mutable state (i.e. this is only suitable for immutable programs). 
--- shortfin/python/lib_ext.cc | 38 ++-- shortfin/src/shortfin/local/fiber.h | 17 ++ shortfin/src/shortfin/local/program.cc | 184 ++++++++++++++---- shortfin/src/shortfin/local/program.h | 97 +++++++-- shortfin/tests/invocation/conftest.py | 34 ++-- .../invocation/mobilenet_program_test.py | 97 ++++++--- 6 files changed, 357 insertions(+), 110 deletions(-) diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index dca171d67..c73bf5a93 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -248,8 +248,9 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) { } local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self, - py::args args) { - auto inv = self.CreateInvocation(); + py::args args, + local::Fiber &fiber) { + auto inv = self.CreateInvocation(fiber.shared_from_this()); py::capsule inv_capsule(inv.get()); for (py::handle arg : args) { PyAddProgramInvocationArg(inv_capsule, arg); @@ -592,13 +593,14 @@ void BindLocal(py::module_ &m) { py::class_(m, "Program") .def(py::new_([](std::span modules, - local::Fiber &fiber, bool trace_execution) { + std::vector devices, + bool trace_execution) { local::Program::Options options; + options.devices = devices; options.trace_execution = trace_execution; - return local::Program::Load(fiber.shared_from_this(), modules, - std::move(options)); + return local::Program::Load(modules, std::move(options)); }), - py::arg("modules"), py::arg("fiber"), py::kw_only(), + py::arg("modules"), py::kw_only(), py::arg("devices"), py::arg("trace_execution") = false) .def_prop_ro("exports", &local::Program::exports) .def("lookup_function", &local::Program::LookupRequiredFunction) @@ -607,9 +609,14 @@ void BindLocal(py::module_ &m) { .def_prop_ro("name", &local::ProgramFunction::name) .def_prop_ro("calling_convention", &local::ProgramFunction::calling_convention) - .def("invocation", &local::ProgramFunction::CreateInvocation, - DOCSTRING_PROGRAM_FUNCTION_INVOCATION) - 
.def("__call__", PyFunctionCall, py::arg("args")) + .def( + "invocation", + [](local::ProgramFunction &self, local::Fiber &fiber) { + return self.CreateInvocation(fiber.shared_from_this()); + }, + DOCSTRING_PROGRAM_FUNCTION_INVOCATION) + .def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(), + py::arg("fiber")) .def("__repr__", &local::ProgramFunction::to_s); py::class_(m, "ProgramModule") .def_prop_ro("exports", &local::ProgramModule::exports) @@ -718,8 +725,17 @@ void BindLocal(py::module_ &m) { }; py::class_(m, "Fiber") .def("__repr__", &local::Fiber::to_s) - .def_prop_ro("raw_devices", &local::Fiber::raw_devices, - py::rv_policy::reference_internal) + .def_prop_ro( + "raw_devices", + [](local::Fiber &self) { + std::vector devices; + devices.reserve(self.raw_devices().size()); + for (auto it : self.raw_devices()) { + devices.push_back(it.second); + } + return devices; + }, + py::rv_policy::reference_internal) .def( "raw_device", [](local::Fiber &self, int index) { return self.raw_device(index); }, diff --git a/shortfin/src/shortfin/local/fiber.h b/shortfin/src/shortfin/local/fiber.h index dd63b30f4..afd65b346 100644 --- a/shortfin/src/shortfin/local/fiber.h +++ b/shortfin/src/shortfin/local/fiber.h @@ -146,6 +146,23 @@ class SHORTFIN_API Fiber : public std::enable_shared_from_this { std::unordered_map device_class_count_; // Ordered devices named as ``. std::vector> devices_; + + // Program isolation control. + // This data structure is manipulated by APIs on the Program class hierarchy. + // It maps a parent context pointer to an isolate accounting struct. This + // struct contains a strong reference to the parent_context and a vector + // of fork contexts. For PER_FIBER invocations, there will only ever be either + // zero or one fork_contexts: when no calls have been issued there will be one + // and if a call is outstanding, there will be zero. This is used to guard + // concurrent access. 
For PER_CALL invocations, there will be as many + // fork_contexts as are needed to satisfy the peak number of calls in flight + // at any time. + // The program_isolate_mu_ must be held to manipulate the accounting structs. + iree::slim_mutex program_isolate_mu_; + std::unordered_map> + program_isolates_; + friend struct detail::ProgramIsolate; }; } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 2f2d95ab4..038cd106a 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -36,12 +36,12 @@ void GetVmModuleExports(iree_vm_module_t *vm_module, // -------------------------------------------------------------------------- // ProgramFunction::ProgramFunction( - std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + iree::vm_context_ptr vm_context, iree_vm_function_t vm_function, + ProgramIsolation isolation, std::optional invocation_model) - : fiber_(std::move(fiber)), - vm_context_(std::move(vm_context)), + : vm_context_(std::move(vm_context)), vm_function_(vm_function), + isolation_(isolation), invocation_model_(invocation_model ? *invocation_model : GetInvocationModelFromFunction(vm_function)) {} @@ -73,9 +73,19 @@ std::string_view ProgramFunction::calling_convention() const { iree_vm_function_signature(&vm_function_).calling_convention); } -ProgramInvocation::Ptr ProgramFunction::CreateInvocation() { - return ProgramInvocation::New(fiber_, vm_context_, vm_function_, - invocation_model_); +ProgramInvocation::Ptr ProgramFunction::CreateInvocation( + std::shared_ptr fiber) { + // Low-overhead NONE isolation handling (saves some ref-count twiddling). + if (isolation_ == ProgramIsolation::NONE) { + return ProgramInvocation::New(std::move(fiber), vm_context_, vm_function_, + invocation_model_, /*isolate=*/nullptr); + } + + // Create an isolated invocation. 
+ auto [isolated_context, isolate] = + detail::ProgramIsolate::AcquireIsolate(*fiber, vm_context_, isolation_); + return ProgramInvocation::New(std::move(fiber), std::move(isolated_context), + vm_function_, invocation_model_, isolate); } std::string ProgramFunction::to_s() const { @@ -106,7 +116,7 @@ ProgramModule ProgramModule::Load(System &system, system.vm_instance(), contents.const_buffer(), contents.deallocator(), system.host_allocator(), module.for_output())); contents.release(); // Must be invoked on success path only. - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } ProgramModule ProgramModule::ParameterProvider( @@ -126,7 +136,7 @@ ProgramModule ProgramModule::ParameterProvider( SHORTFIN_THROW_IF_ERROR(iree_io_parameters_module_create( system.vm_instance(), providers.size(), providers.data(), system.host_allocator(), module.for_output())); - return ProgramModule(std::move(module)); + return ProgramModule(system.shared_from_this(), std::move(module)); } std::string_view ProgramModule::name() const { @@ -158,14 +168,27 @@ std::vector ProgramModule::exports() const { // Program // -------------------------------------------------------------------------- // -Program Program::Load(std::shared_ptr fiber, - std::span modules, Options options) { +Program Program::Load(std::span modules, + Options &&options) { std::vector all_modules; std::vector raw_devices; + System *system = nullptr; // By default, bind all devices in the fiber in order to the program. 
- for (auto &it : fiber->raw_devices()) { - raw_devices.push_back(it.second->hal_device()); + for (auto &it : options.devices) { + raw_devices.push_back(it->hal_device()); + } + + for (auto &mod : modules) { + if (system && &mod.system() != system) { + throw std::invalid_argument( + "Cannot create Program from modules loaded from multiple system " + "instances"); + } + system = &mod.system(); + } + if (!system) { + throw std::invalid_argument("Cannot create Program with no modules"); } // Add a HAL module. @@ -177,12 +200,11 @@ Program Program::Load(std::shared_ptr fiber, // functionality (or module versions; iree_vm_module_dependency_t has the // minimum version required so you can switch between them, and whether they // are optional/required). - auto &system = fiber->system(); iree::vm_module_ptr hal_module; - SHORTFIN_THROW_IF_ERROR( - iree_hal_module_create(system.vm_instance(), raw_devices.size(), - raw_devices.data(), IREE_HAL_MODULE_FLAG_NONE, - system.host_allocator(), hal_module.for_output())); + SHORTFIN_THROW_IF_ERROR(iree_hal_module_create( + system->vm_instance(), raw_devices.size(), raw_devices.data(), + IREE_HAL_MODULE_FLAG_NONE, system->host_allocator(), + hal_module.for_output())); all_modules.push_back(hal_module); // Add explicit modules. 
@@ -195,10 +217,10 @@ Program Program::Load(std::shared_ptr fiber, iree_vm_context_flags_t flags = IREE_VM_CONTEXT_FLAG_CONCURRENT; if (options.trace_execution) flags |= IREE_VM_CONTEXT_FLAG_TRACE_EXECUTION; SHORTFIN_THROW_IF_ERROR(iree_vm_context_create_with_modules( - system.vm_instance(), flags, all_modules.size(), all_modules.data(), - system.host_allocator(), context.for_output())); + system->vm_instance(), flags, all_modules.size(), all_modules.data(), + system->host_allocator(), context.for_output())); - return Program(std::move(fiber), std::move(context)); + return Program(std::move(context), options.isolation); } std::optional Program::LookupFunction(std::string_view name) { @@ -217,7 +239,7 @@ std::optional Program::LookupFunction(std::string_view name) { // TODO: Torch import is not setting the coarse-fences abi.model on // its functions. Get it from there instead of just assuming based on // name. - return ProgramFunction(fiber_, vm_context_, f, + return ProgramFunction(vm_context_, f, isolation_, ProgramInvocationModel::COARSE_FENCES); } else if (!iree_status_is_not_found(status)) { SHORTFIN_THROW_IF_ERROR(status); @@ -229,7 +251,7 @@ std::optional Program::LookupFunction(std::string_view name) { vm_context_, to_iree_string_view(name), &f); if (iree_status_is_not_found(status)) return {}; SHORTFIN_THROW_IF_ERROR(status); - return ProgramFunction(fiber_, vm_context_, f); + return ProgramFunction(vm_context_, f, isolation_); } ProgramFunction Program::LookupRequiredFunction(std::string_view name) { @@ -260,6 +282,15 @@ std::vector Program::exports() const { return results; } +void Program::PrepareIsolate(Fiber &fiber) { + if (isolation_ == ProgramIsolation::NONE) return; + auto [context, isolate] = + detail::ProgramIsolate::AcquireIsolate(fiber, vm_context_, isolation_); + if (isolate) { + detail::ProgramIsolate::ReleaseIsolate(fiber, std::move(context), isolate); + } +} + // -------------------------------------------------------------------------- // 
// ProgramInvocation // -------------------------------------------------------------------------- // @@ -287,18 +318,23 @@ void ProgramInvocation::Deleter::operator()(ProgramInvocation *inst) { } ProgramInvocation::ProgramInvocation() = default; -ProgramInvocation::~ProgramInvocation() { - if (!scheduled()) { - // This instance was dropped on the floor before scheduling. - // Clean up the initialization parameters. - iree::vm_context_ptr drop = - iree::vm_context_ptr::steal_reference(state.params.context); +ProgramInvocation::~ProgramInvocation() { ReleaseContext(); } + +void ProgramInvocation::ReleaseContext() { + if (vm_context_) { + if (isolate_) { + detail::ProgramIsolate::ReleaseIsolate(*fiber_, std::move(vm_context_), + isolate_); + } else { + vm_context_.reset(); + } } } ProgramInvocation::Ptr ProgramInvocation::New( std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model) { + iree_vm_function_t &vm_function, ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate) { auto sig = iree_vm_function_signature(&vm_function); iree_host_size_t arg_count; iree_host_size_t result_count; @@ -337,8 +373,8 @@ ProgramInvocation::Ptr ProgramInvocation::New( static_cast(inst_storage.release())), Deleter()); inst->fiber_ = std::move(fiber); - inst->state.params.context = - vm_context.release(); // Ref transfer to ProgramInvocation. 
+ inst->vm_context_ = std::move(vm_context); + inst->isolate_ = isolate; inst->state.params.function = vm_function; inst->state.params.invocation_model = invocation_model; inst->result_list_ = result_list; @@ -421,7 +457,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( Params params = invocation->state.params; auto schedule = [](ProgramInvocation *raw_invocation, Worker *worker, - iree_vm_context_t *owned_context, iree_vm_function_t function, ProgramInvocationModel invocation_model, std::optional failure_future) { @@ -440,6 +475,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation( static_cast(user_data)); ProgramInvocation *raw_invocation = invocation.get(); + raw_invocation->ReleaseContext(); if (iree_status_is_ok(status)) { raw_invocation->future_->set_result(std::move(invocation)); } else { @@ -469,7 +505,7 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (iree_status_is_ok(status)) { status = iree_vm_async_invoke(worker->loop(), &invocation->state.async_invoke_state, - owned_context, function, + invocation->vm_context_.get(), function, /*flags=*/IREE_VM_INVOCATION_FLAG_NONE, /*policy=*/nullptr, /*inputs=*/invocation->arg_list(), @@ -478,10 +514,6 @@ ProgramInvocation::Future ProgramInvocation::Invoke( /*user_data=*/invocation.get()); } - // Regardless of status, the context reference we were holding is no - // longer needed. Drop it on the floor. - iree::vm_context_ptr::steal_reference(owned_context); - // On success, then the complete callback takes ownership of the // invocation, so we release it here and return. We have to treat // the invocation as possibly deallocated at this point, since the @@ -490,9 +522,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke( invocation.release(); } else if (failure_future) { // Requested to set any failure on the future. + invocation->ReleaseContext(); failure_future->set_failure(status); } else { // Synchronous: just throw. 
+ invocation->ReleaseContext(); SHORTFIN_THROW_IF_ERROR(status); } }; @@ -504,14 +538,13 @@ ProgramInvocation::Future ProgramInvocation::Invoke( if (&worker == Worker::GetCurrent()) { // On the same worker: fast-path directly to the loop. - schedule(invocation.release(), &worker, params.context, params.function, + schedule(invocation.release(), &worker, params.function, params.invocation_model, /*failure_future=*/{}); } else { // Cross worker coordination: submit an external task to bootstrap. - auto bound_schedule = - std::bind(schedule, invocation.release(), &worker, params.context, - params.function, params.invocation_model, - /*failure_future=*/fork_future); + auto bound_schedule = std::bind(schedule, invocation.release(), &worker, + params.function, params.invocation_model, + /*failure_future=*/fork_future); worker.CallThreadsafe(bound_schedule); } @@ -623,4 +656,69 @@ void StaticProgramParameters::Load(std::filesystem::path file_path, to_iree_string_view(options.format), file_handle.get(), index_.get())); } +// -------------------------------------------------------------------------- // +// ProgramIsolate +// -------------------------------------------------------------------------- // + +std::pair +detail::ProgramIsolate::AcquireIsolate(Fiber &fiber, + iree::vm_context_ptr root_context, + ProgramIsolation isolation) { + assert(isolation != ProgramIsolation::NONE && + "cannot AcquireIsolate when isolation == NONE"); + // Some isolation required. + detail::ProgramIsolate *isolate = nullptr; + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + auto found_it = fiber.program_isolates_.find(root_context.get()); + if (found_it != fiber.program_isolates_.end()) { + isolate = found_it->second.get(); + } + if (isolate && !isolate->fork_contexts.empty()) { + // Fast path: there is an existing isolate and a context avaialable. 
+ auto isolated_context = std::move(isolate->fork_contexts.back()); + isolate->fork_contexts.pop_back(); + return std::make_pair(std::move(isolated_context), isolate); + } else if (!isolate) { + // Initialize a new isolate accounting struct while in the lock. + // Note that this can cause a fault for PER_FIBER mode if the call + // to fork fails below as it will leave the isolate with no available + // context and every future call will raise an exception indicating that + // the context is busy (vs trying to create a new one). This is deemed + // an acceptable situation for a system fault (which is the only reason + // a fork will fail). + auto [inserted_it, inserted] = + fiber.program_isolates_.insert(std::make_pair( + root_context.get(), + std::make_unique(root_context))); + isolate = inserted_it->second.get(); + } else if (isolation == ProgramIsolation::PER_FIBER) { + throw std::logic_error( + "Cannot make concurrent invocations of a PER_FIBER program from " + "the same Fiber. This typically means that two invocations were " + "attempted on the same program on the same fiber without an " + "await. Consider fixing adding appropriate sequencing or switching " + "to either PER_CALL or NONE isolation if appropriate for the use " + "case. This exception can also occur if the first invocation to this " + "Program failed, leaving no initialized Program for this fiber."); + } + } + + // Slow-path: fork needed (and possibly new isolate registration needed). 
+ iree::vm_context_ptr new_context; + SHORTFIN_THROW_IF_ERROR(iree_vm_context_fork( + root_context.get(), fiber.host_allocator(), new_context.for_output())); + return std::make_pair(std::move(new_context), isolate); +} + +void detail::ProgramIsolate::ReleaseIsolate(Fiber &fiber, + iree::vm_context_ptr context, + detail::ProgramIsolate *isolate) { + assert(isolate && "attempt to release null isolate"); + { + iree::slim_mutex_lock_guard lock(fiber.program_isolate_mu_); + isolate->fork_contexts.push_back(std::move(context)); + } +} + } // namespace shortfin::local diff --git a/shortfin/src/shortfin/local/program.h b/shortfin/src/shortfin/local/program.h index bc5ae05dc..ea4f0cc3f 100644 --- a/shortfin/src/shortfin/local/program.h +++ b/shortfin/src/shortfin/local/program.h @@ -26,6 +26,10 @@ class BaseProgramParameters; class Fiber; class System; +namespace detail { +struct ProgramIsolate; +} // namespace detail + enum class ProgramInvocationModel { // Uses the coarse-fences invocation model. In this model, the last two // arguments are a wait and signal fence, which are used for function-level @@ -37,6 +41,24 @@ enum class ProgramInvocationModel { UNKNOWN, }; +// The level of isolation that a program has with respect to concurrent use. +enum class ProgramIsolation { + // There is no isolation: Callers are completely on their own to only issue + // concurrent invocations if supported. + NONE = 0, + + // Each fiber in the system that makes calls into the program will have its + // own shallow fork of the module. This is done on-demand and the root + // program is retained for the lifetime of any referencing fibers. + // Concurrent calls on the same fiber are considered programming errors and + // will be flagged as such at an appropriate debug level. + PER_FIBER = 1, + + // Each call triggers a shallow fork of the module. This is the most expensive + // but safest way to ensure complete isolation of stateless invocations. 
+ PER_CALL = 2, +}; + // State related to making an invocation of a function on a program. // // Since ownership of this object is transferred to the loop/callback and @@ -67,7 +89,8 @@ class SHORTFIN_API ProgramInvocation { static Ptr New(std::shared_ptr fiber, iree::vm_context_ptr vm_context, iree_vm_function_t &vm_function, - ProgramInvocationModel invocation_model); + ProgramInvocationModel invocation_model, + detail::ProgramIsolate *isolate); ProgramInvocation(const ProgramInvocation &) = delete; ProgramInvocation &operator=(const ProgramInvocation &) = delete; ProgramInvocation &operator=(ProgramInvocation &&) = delete; @@ -133,6 +156,11 @@ class SHORTFIN_API ProgramInvocation { private: ProgramInvocation(); void CheckNotScheduled(); + // Eagerly releases context when it is known that no further use of it can + // be made (allowing it to be returned to a pool prior to the invocation + // actually being recycled). Object destruction also does this, but possibly + // extending the context lifetime. + void ReleaseContext(); // Returns a pointer to the trailing arg list. iree_vm_list_t *arg_list(); @@ -156,8 +184,6 @@ class SHORTFIN_API ProgramInvocation { // This must not contain entities that require destruction or cannot be // trivially copied. struct Params { - // Context is retained upon construction and released when scheduled. 
- iree_vm_context_t *context; iree_vm_function_t function; ProgramInvocationModel invocation_model; }; @@ -169,6 +195,8 @@ class SHORTFIN_API ProgramInvocation { } state; std::shared_ptr fiber_; + iree::vm_context_ptr vm_context_; + detail::ProgramIsolate *isolate_; iree_vm_list_t *result_list_ = nullptr; std::optional future_; iree::hal_fence_ptr wait_fence_; @@ -187,7 +215,7 @@ class SHORTFIN_API ProgramFunction { std::string_view calling_convention() const; ProgramInvocationModel invocation_model() const { return invocation_model_; } - ProgramInvocation::Ptr CreateInvocation(); + ProgramInvocation::Ptr CreateInvocation(std::shared_ptr fiber); std::string to_s() const; @@ -195,17 +223,16 @@ class SHORTFIN_API ProgramFunction { operator iree_vm_function_t &() { return vm_function_; } private: - ProgramFunction(std::shared_ptr fiber, iree::vm_context_ptr vm_context, - iree_vm_function_t vm_function, + ProgramFunction(iree::vm_context_ptr vm_context, + iree_vm_function_t vm_function, ProgramIsolation isolation, std::optional invocation_model = {}); static ProgramInvocationModel GetInvocationModelFromFunction( iree_vm_function_t &f); - // The context that this function was resolved against. - std::shared_ptr fiber_; iree::vm_context_ptr vm_context_; iree_vm_function_t vm_function_; + ProgramIsolation isolation_; ProgramInvocationModel invocation_model_; friend class Program; }; @@ -231,6 +258,7 @@ class SHORTFIN_API ProgramModule { std::string to_s() const; iree_vm_module_t *vm_module() const { return vm_module_; } std::string_view name() const; + System &system() const { return *system_; } // Loads a dynamic bytecode module (VMFB) from a path on the file system. 
static ProgramModule Load(System &system, const std::filesystem::path &path, @@ -246,10 +274,12 @@ class SHORTFIN_API ProgramModule { std::vector exports() const; protected: - explicit ProgramModule(iree::vm_module_ptr vm_module) - : vm_module_(std::move(vm_module)) {} + explicit ProgramModule(std::shared_ptr system, + iree::vm_module_ptr vm_module) + : system_(std::move(system)), vm_module_(std::move(vm_module)) {} private: + std::shared_ptr system_; iree::vm_module_ptr vm_module_; }; @@ -269,15 +299,19 @@ class SHORTFIN_API Program { struct Options { Options() {} + // Ordered list of devices to bind this program to. + std::span devices; + + // The isolation level to apply to program invocation. + ProgramIsolation isolation = ProgramIsolation::PER_FIBER; + // Enables program-wide execution tracing (to stderr). bool trace_execution = false; }; - // Loads a program attached to a fiber with a list of user provided modules - // and options. - static Program Load(std::shared_ptr fiber, - std::span modules, - Options options = {}); + // Load a program from a list of modules and options. + static Program Load(std::span modules, + Options &&options); // Looks up a public function by fully qualified name (i.e. module.function). // Returns nothing if not found. @@ -290,12 +324,16 @@ class SHORTFIN_API Program { // Gets the name of all exported functions. std::vector exports() const; + // Eagerly does any per-fiber isolation preparation for the program at a + // convenient point (usually init time) to avoid first-invocation overhead. 
+ void PrepareIsolate(Fiber &fiber); + private: - explicit Program(std::shared_ptr fiber, - iree::vm_context_ptr vm_context) - : fiber_(std::move(fiber)), vm_context_(std::move(vm_context)) {} - std::shared_ptr fiber_; + explicit Program(iree::vm_context_ptr vm_context, ProgramIsolation isolation) + : vm_context_(std::move(vm_context)), isolation_(isolation) {} + iree::vm_context_ptr vm_context_; + ProgramIsolation isolation_; friend class Fiber; }; @@ -354,6 +392,27 @@ class SHORTFIN_API StaticProgramParameters : public BaseProgramParameters { iree::io_parameter_index_ptr index_; }; +namespace detail { +// See Fiber::program_isolates_. +struct ProgramIsolate { + ProgramIsolate(iree::vm_context_ptr parent_context) + : parent_context(std::move(parent_context)) {} + iree::vm_context_ptr parent_context; + std::vector fork_contexts; + + // Acquires an isolate for the given fiber. This will return a context which + // may be the original program context or may be a forked child that is + // available for use. It is only valid to call this when isolation != NONE. + static std::pair + AcquireIsolate(Fiber &fiber, iree::vm_context_ptr root_context, + ProgramIsolation isolation); + + // Releases an isolate obtained from a fiber in AcquireIsolate. 
+ static void ReleaseIsolate(Fiber &fiber, iree::vm_context_ptr context, + ProgramIsolate *isolate); +}; +}; // namespace detail + } // namespace shortfin::local #endif // SHORTFIN_LOCAL_PROGRAM_H diff --git a/shortfin/tests/invocation/conftest.py b/shortfin/tests/invocation/conftest.py index c366c7f82..148ae064d 100644 --- a/shortfin/tests/invocation/conftest.py +++ b/shortfin/tests/invocation/conftest.py @@ -22,15 +22,16 @@ def mobilenet_onnx_path(tmp_path_factory): import onnx except ModuleNotFoundError: raise pytest.skip("onnx python package not available") - print("Downloading mobilenet.onnx") parent_dir = tmp_path_factory.mktemp("mobilenet_onnx") orig_onnx_path = parent_dir / "mobilenet_orig.onnx" - urllib.request.urlretrieve( - "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", - orig_onnx_path, - ) upgraded_onnx_path = parent_dir / "mobilenet.onnx" - upgrade_onnx(orig_onnx_path, upgraded_onnx_path) + if not upgraded_onnx_path.exists(): + print("Downloading mobilenet.onnx") + urllib.request.urlretrieve( + "https://github.com/onnx/models/raw/main/validated/vision/classification/mobilenet/model/mobilenetv2-12.onnx", + orig_onnx_path, + ) + upgrade_onnx(orig_onnx_path, upgraded_onnx_path) return upgraded_onnx_path @@ -41,15 +42,18 @@ def mobilenet_compiled_cpu_path(mobilenet_onnx_path): import iree.compiler.tools.import_onnx.__main__ as import_onnx except ModuleNotFoundError: raise pytest.skip("iree.compiler packages not available") - print("Compiling mobilenet") mlir_path = mobilenet_onnx_path.parent / "mobilenet.mlir" vmfb_path = mobilenet_onnx_path.parent / "mobilenet_cpu.vmfb" - args = import_onnx.parse_arguments(["-o", str(mlir_path), str(mobilenet_onnx_path)]) - import_onnx.main(args) - tools.compile_file( - str(mlir_path), - output_file=str(vmfb_path), - target_backends=["llvm-cpu"], - input_type="onnx", - ) + if not vmfb_path.exists(): + print("Compiling mobilenet") + args = 
import_onnx.parse_arguments( + ["-o", str(mlir_path), str(mobilenet_onnx_path)] + ) + import_onnx.main(args) + tools.compile_file( + str(mlir_path), + output_file=str(vmfb_path), + target_backends=["llvm-cpu"], + input_type="onnx", + ) return vmfb_path diff --git a/shortfin/tests/invocation/mobilenet_program_test.py b/shortfin/tests/invocation/mobilenet_program_test.py index 4275fe9e2..84903fb8f 100644 --- a/shortfin/tests/invocation/mobilenet_program_test.py +++ b/shortfin/tests/invocation/mobilenet_program_test.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import array +import asyncio +import time import functools import pytest @@ -21,38 +23,89 @@ def lsys(): @pytest.fixture -def fiber(lsys): +def fiber0(lsys): return lsys.create_fiber() @pytest.fixture -def device(fiber): - return fiber.device(0) +def device(fiber0): + return fiber0.device(0) -def test_invoke_mobilenet(lsys, fiber, mobilenet_compiled_cpu_path): - device = fiber.device(0) +@pytest.fixture +def mobilenet_program_function( + lsys, mobilenet_compiled_cpu_path +) -> tuple[sf.ProgramFunction]: + program_module = lsys.load_module(mobilenet_compiled_cpu_path) + program = sf.Program([program_module], devices=lsys.devices) + main_function = program["module.torch-jit-export"] + return main_function + + +def get_mobilenet_ref_input(device) -> sfnp.device_array: dummy_data = array.array( "f", ([0.2] * (224 * 224)) + ([0.4] * (224 * 224)) + ([-0.2] * (224 * 224)) ) - program_module = lsys.load_module(mobilenet_compiled_cpu_path) - program = sf.Program([program_module], fiber=fiber) - main_function = program["module.torch-jit-export"] + device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) + staging_input = device_input.for_transfer() + with staging_input.map(discard=True) as m: + m.fill(dummy_data) + device_input.copy_from(staging_input) + return device_input + + +async def assert_mobilenet_ref_output(device, device_output): + host_output = 
device_output.for_transfer() + host_output.copy_from(device_output) + await device + flat_output = host_output.items + absmean = functools.reduce( + lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 + ) + print("RESULT:", absmean) + assert absmean == pytest.approx(5.01964943873882) + + +def test_invoke_mobilenet(lsys, fiber0, mobilenet_program_function): + device = fiber0.device(0) async def main(): - device_input = sfnp.device_array(device, [1, 3, 224, 224], sfnp.float32) - staging_input = device_input.for_transfer() - with staging_input.map(discard=True) as m: - m.fill(dummy_data) - device_input.copy_from(staging_input) - (device_output,) = await main_function(device_input) - host_output = device_output.for_transfer() - host_output.copy_from(device_output) - await device - flat_output = host_output.items - absmean = functools.reduce( - lambda x, y: x + abs(y) / len(flat_output), flat_output, 0.0 - ) - assert absmean == pytest.approx(5.01964943873882) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function(device_input, fiber=fiber0) + await assert_mobilenet_ref_output(device, device_output) + + lsys.run(main()) + + +def test_invoke_mobilenet_multi_fiber(lsys, mobilenet_program_function): + class InferProcess(sf.Process): + async def run(self): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + print(f"{self}: Start") + device = self.fiber.device(0) + device_input = get_mobilenet_ref_input(device) + (device_output,) = await mobilenet_program_function( + device_input, fiber=self.fiber + ) + print(f"{self}: Program complete (+{duration()}ms)") + await assert_mobilenet_ref_output(device, device_output) + print(f"{self} End (+{duration()}ms)") + + async def main(): + start_time = time.time() + + def duration(): + return round((time.time() - start_time) * 1000.0) + + fibers = [lsys.create_fiber() for _ in range(5)] + print("Fibers:", fibers) + processes = 
[InferProcess(fiber=f).launch() for f in fibers] + print("Waiting for processes:", processes) + await asyncio.gather(*processes) + print(f"All processes complete: (+{duration()}ms)") lsys.run(main())