Skip to content

Commit

Permalink
Merge branch 'main' into sfsd-staging
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored Oct 28, 2024
2 parents 8d4475b + 187c45e commit 5dacf6b
Show file tree
Hide file tree
Showing 6 changed files with 357 additions and 110 deletions.
38 changes: 27 additions & 11 deletions shortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,9 @@ void PyAddProgramInvocationArg(py::capsule &inv_capsule, py::handle arg) {
}

local::ProgramInvocation::Future PyFunctionCall(local::ProgramFunction &self,
py::args args) {
auto inv = self.CreateInvocation();
py::args args,
local::Fiber &fiber) {
auto inv = self.CreateInvocation(fiber.shared_from_this());
py::capsule inv_capsule(inv.get());
for (py::handle arg : args) {
PyAddProgramInvocationArg(inv_capsule, arg);
Expand Down Expand Up @@ -592,13 +593,14 @@ void BindLocal(py::module_ &m) {

py::class_<local::Program>(m, "Program")
.def(py::new_([](std::span<const local::ProgramModule> modules,
local::Fiber &fiber, bool trace_execution) {
std::vector<const local::Device *> devices,
bool trace_execution) {
local::Program::Options options;
options.devices = devices;
options.trace_execution = trace_execution;
return local::Program::Load(fiber.shared_from_this(), modules,
std::move(options));
return local::Program::Load(modules, std::move(options));
}),
py::arg("modules"), py::arg("fiber"), py::kw_only(),
py::arg("modules"), py::kw_only(), py::arg("devices"),
py::arg("trace_execution") = false)
.def_prop_ro("exports", &local::Program::exports)
.def("lookup_function", &local::Program::LookupRequiredFunction)
Expand All @@ -607,9 +609,14 @@ void BindLocal(py::module_ &m) {
.def_prop_ro("name", &local::ProgramFunction::name)
.def_prop_ro("calling_convention",
&local::ProgramFunction::calling_convention)
.def("invocation", &local::ProgramFunction::CreateInvocation,
DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
.def("__call__", PyFunctionCall, py::arg("args"))
.def(
"invocation",
[](local::ProgramFunction &self, local::Fiber &fiber) {
return self.CreateInvocation(fiber.shared_from_this());
},
DOCSTRING_PROGRAM_FUNCTION_INVOCATION)
.def("__call__", PyFunctionCall, py::arg("args"), py::kw_only(),
py::arg("fiber"))
.def("__repr__", &local::ProgramFunction::to_s);
py::class_<local::ProgramModule>(m, "ProgramModule")
.def_prop_ro("exports", &local::ProgramModule::exports)
Expand Down Expand Up @@ -718,8 +725,17 @@ void BindLocal(py::module_ &m) {
};
py::class_<local::Fiber>(m, "Fiber")
.def("__repr__", &local::Fiber::to_s)
.def_prop_ro("raw_devices", &local::Fiber::raw_devices,
py::rv_policy::reference_internal)
.def_prop_ro(
"raw_devices",
[](local::Fiber &self) {
std::vector<local::Device *> devices;
devices.reserve(self.raw_devices().size());
for (auto it : self.raw_devices()) {
devices.push_back(it.second);
}
return devices;
},
py::rv_policy::reference_internal)
.def(
"raw_device",
[](local::Fiber &self, int index) { return self.raw_device(index); },
Expand Down
17 changes: 17 additions & 0 deletions shortfin/src/shortfin/local/fiber.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,23 @@ class SHORTFIN_API Fiber : public std::enable_shared_from_this<Fiber> {
std::unordered_map<std::string_view, int> device_class_count_;
// Ordered devices named as `<device_class><index>`.
std::vector<std::pair<std::string_view, Device *>> devices_;

// Program isolation control.
// This data structure is manipulated by APIs on the Program class hierarchy.
// It maps a parent context pointer to an isolate accounting struct. This
// struct contains a strong reference to the parent_context and a vector
// of fork contexts. For PER_FIBER invocations, there will only ever be either
// zero or one fork_contexts: when no calls have been issued there will be one
// and if a call is outstanding, there will be zero. This is used to guard
// concurrent access. For PER_CALL invocations, there will be as many
// fork_contexts as are needed to satisfy the peak number of calls in flight
// at any time.
// The program_isolate_mu_ must be held to manipulate the accounting structs.
iree::slim_mutex program_isolate_mu_;
std::unordered_map<iree_vm_context_t *,
std::unique_ptr<detail::ProgramIsolate>>
program_isolates_;
friend struct detail::ProgramIsolate;
};

} // namespace shortfin::local
Expand Down
Loading

0 comments on commit 5dacf6b

Please sign in to comment.