Skip to content

Commit

Permalink
[libshortfin] Rollup of API changes and fixes to support the LLM serv…
Browse files Browse the repository at this point in the history
…er. (#198)

Core:

* Convert Queue to shared ownership.
* Allow creation of anonymous Queues.
* Adds asyncio loop support for `call_at`
* Fixes `VoidFuture::set_success` to also signal callbacks
* Adds `StaticProgramParameters` and ability to construct a
`ParameterProvider`
* Fixes the hacky `Account::OnSync()` to return a `VoidFuture` and use
it to route any wait error (instead of aborting)

Bindings:

* Adds `name` and `compute_dense_nd_size` bindings to `DType`
* Makes `Process` subclassing more natural by supporting split new/init
(allows arbitrary arguments to the subclass constructor)
* Disables nanobind leak detector in debug builds of CPython (it seems
to rely on immortalization of certain identifiers which are properly
cleaned up in debug builds but left allocated forever in release builds)
* Initialize logging from environment on load
* Adds repr to `ProgramInvocation`
* Reworks `Scope.devices` to support iteration and act like a built-in
* Adds `Queue.write_nodelay` and `Queue.closed`
* Adds `VoidFuture.set_success()`
* Reworks optional dep handling to be more verbose/precise
* Adds a module to perfom logging setup to the native logger

FastAPI interop:

* Replace `close_with_error` with `ensure_response` that can be used
universally in a finally block.
* Auto casts responses of `bytes` to a `Response`
  • Loading branch information
stellaraccident authored Sep 19, 2024
1 parent 5a48398 commit 89d5d52
Show file tree
Hide file tree
Showing 30 changed files with 659 additions and 180 deletions.
4 changes: 3 additions & 1 deletion libshortfin/examples/python/async/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ async def run(self):


async def main():
queue = lsys.create_queue("infeed")
queue = lsys.create_queue()
main_scope = lsys.create_scope()
# TODO: Also test named queues.
# queue = lsys.create_queue("infeed")
w1 = lsys.create_worker("w1")
w1_scope = lsys.create_scope(w1)
await asyncio.gather(
Expand Down
5 changes: 3 additions & 2 deletions libshortfin/examples/python/fastapi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ async def run(self):
)
await asyncio.sleep(0.01)
responder.stream_part(None)
except Exception as e:
responder.close_with_error()
except:
traceback.print_exc()
finally:
responder.ensure_response()


@asynccontextmanager
Expand Down
9 changes: 8 additions & 1 deletion libshortfin/python/_shortfin/asyncio_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def call_later(
w.delay_call(deadline, handle._sf_maybe_run)
return handle

def call_at(self, when, callback, *args, context=None) -> asyncio.TimerHandle:
w = self._worker
deadline = int(when * 1e9)
handle = _TimerHandle(when, callback, args, self, context)
w.delay_call(deadline, handle._sf_maybe_run)
return handle

def call_exception_handler(self, context) -> None:
# TODO: Should route this to the central exception handler. Should
# also play with ergonomics of how the errors get reported in
Expand All @@ -62,7 +69,7 @@ def call_exception_handler(self, context) -> None:

def _timer_handle_cancelled(self, handle):
# We don't do anything special: just skip it if it comes up.
pass
...


class _Handle(asyncio.Handle):
Expand Down
3 changes: 2 additions & 1 deletion libshortfin/python/array_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ void BindArray(py::module_ &m) {
auto refs = std::make_shared<Refs>();

py::class_<DType>(m, "DType")
.def_prop_ro("name", &DType::name)
.def_prop_ro("is_boolean", &DType::is_boolean)
.def_prop_ro("is_integer", &DType::is_integer)
.def_prop_ro("is_float", &DType::is_float)
Expand All @@ -203,6 +204,7 @@ void BindArray(py::module_ &m) {
.def_prop_ro("is_byte_aligned", &DType::is_byte_aligned)
.def_prop_ro("dense_byte_count", &DType::dense_byte_count)
.def("is_integer_bitwidth", &DType::is_integer_bitwidth)
.def("compute_dense_nd_size", &DType::compute_dense_nd_size)
.def(py::self == py::self)
.def("__repr__", &DType::name);

Expand Down Expand Up @@ -375,7 +377,6 @@ void BindArray(py::module_ &m) {
py::rv_policy::reference_internal)
.def_prop_ro("storage", &device_array::storage,
py::rv_policy::reference_internal)

.def(
"fill",
[](py::handle_t<device_array> self, py::handle buffer) {
Expand Down
206 changes: 163 additions & 43 deletions libshortfin/python/lib_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,18 @@ class PyWorkerExtension : public local::Worker::Extension {

class PyProcess : public local::detail::BaseProcess {
public:
PyProcess(std::shared_ptr<local::Scope> scope, std::shared_ptr<Refs> refs)
: BaseProcess(std::move(scope)), refs_(std::move(refs)) {}
PyProcess(std::shared_ptr<Refs> refs)
: BaseProcess(), refs_(std::move(refs)) {}
using BaseProcess::Initialize;
using BaseProcess::is_initialized;
using BaseProcess::Launch;

void AssertInitialized() {
if (!is_initialized()) {
throw std::logic_error("Process.__init__ not called in constructor");
}
}

void ScheduleOnWorker() override {
// This is tricky: We need to retain the object reference across the
// thread transition, but on the receiving side, the GIL will not be
Expand Down Expand Up @@ -314,6 +322,19 @@ py::object RunInForeground(std::shared_ptr<Refs> refs, local::System &self,
} // namespace

NB_MODULE(lib, m) {
// Tragically, debug builds of Python do the right thing and don't immortalize
// many identifiers and such. This makes the last chance leak checking that
// nanobind does somewhat unreliable since the reports it prints may be
// to identifiers that are no longer live (at a time in process shutdown
// where it is expected that everything left just gets dropped on the floor).
// This causes segfaults or ASAN violations in the leak checker on exit in
// certain scenarios where we have spurious "leaks" of global objects.
#if defined(Py_DEBUG)
py::set_leak_warnings(false);
#endif

logging::InitializeFromEnv();

py::register_exception_translator(
[](const std::exception_ptr &p, void * /*unused*/) {
try {
Expand Down Expand Up @@ -412,12 +433,15 @@ void BindLocal(py::module_ &m) {
py::rv_policy::reference_internal)
.def(
"create_queue",
[](local::System &self, std::string name) -> local::Queue & {
[](local::System &self,
std::optional<std::string> name) -> std::shared_ptr<local::Queue> {
local::Queue::Options options;
options.name = std::move(name);
if (name) {
options.name = std::move(*name);
}
return self.CreateQueue(std::move(options));
},
py::arg("name"), py::rv_policy::reference_internal)
py::arg("name") = py::none(), py::rv_policy::reference_internal)
.def("named_queue", &local::System::named_queue, py::arg("name"),
py::rv_policy::reference_internal)
.def(
Expand Down Expand Up @@ -512,7 +536,18 @@ void BindLocal(py::module_ &m) {
.def_prop_ro("exports", &local::ProgramModule::exports)
.def("__repr__", &local::ProgramModule::to_s)
.def_static("load", &local::ProgramModule::Load, py::arg("system"),
py::arg("path"), py::arg("mmap") = true);
py::arg("path"), py::arg("mmap") = true)
.def_static(
"parameter_provider",
[](local::System &system, py::args params) {
std::vector<local::BaseProgramParameters *> c_params;
c_params.reserve(params.size());
for (py::handle h : params) {
c_params.push_back(py::cast<local::BaseProgramParameters *>(h));
}
return local::ProgramModule::ParameterProvider(system, c_params);
},
py::arg("system"), py::arg("params"));
py::class_<local::ProgramInvocation::Ptr>(m, "ProgramInvocation")
.def("invoke",
[](local::ProgramInvocation::Ptr &self) {
Expand Down Expand Up @@ -559,11 +594,48 @@ void BindLocal(py::module_ &m) {
}
return PyRehydrateRef(self.get(), std::move(ref));
},
"Gets the i'th result");
"Gets the i'th result")
.def("__repr__", [](local::ProgramInvocation::Ptr &self) {
if (!self) return std::string("ProgramInvocation(INVALID)");
return self->to_s();
});

py::class_<local::BaseProgramParameters>(m, "BaseProgramParameters");
py::class_<local::StaticProgramParameters, local::BaseProgramParameters>(
m, "StaticProgramParameters")
.def(
py::init<local::System &, std::string_view, iree_host_size_t>(),
py::arg("system"), py::arg("parameter_scope"),
py::arg("max_concurrent_operations") =
IREE_IO_PARAMETER_INDEX_PROVIDER_DEFAULT_MAX_CONCURRENT_OPERATIONS)
.def(
"load",
[](local::StaticProgramParameters &self,
std::filesystem::path file_path, std::string_view format,
bool readable, bool writable, bool mmap) {
local::StaticProgramParameters::LoadOptions options;
options.format = format;
options.readable = readable;
options.writable = writable;
options.mmap = mmap;
self.Load(file_path, options);
},
py::arg("file_path"), py::arg("format") = std::string_view(),
py::arg("readable") = true, py::arg("writable") = false,
py::arg("mmap") = true);

struct DevicesSet {
DevicesSet(local::Scope &scope) : scope(scope) {}
local::Scope &scope;
DevicesSet(py::object scope_obj, std::optional<size_t> index = {})
: scope_obj(std::move(scope_obj)), index(index) {}
py::object KeepAlive(local::ScopedDevice device) {
py::object device_obj = py::cast(device);
py::detail::keep_alive(/*nurse=*/device_obj.ptr(),
/*patient=*/scope_obj.ptr());
return device_obj;
}
local::Scope &scope() { return py::cast<local::Scope &>(scope_obj); }
py::object scope_obj;
std::optional<size_t> index;
};
py::class_<local::Scope>(m, "Scope")
.def("__repr__", &local::Scope::to_s)
Expand All @@ -579,12 +651,22 @@ void BindLocal(py::module_ &m) {
return self.raw_device(name);
},
py::rv_policy::reference_internal)
.def_prop_ro(
"devices", [](local::Scope &self) { return DevicesSet(self); },
py::rv_policy::reference_internal)
.def_prop_ro("devices",
[](py::object self) { return DevicesSet(std::move(self)); })
.def_prop_ro("devices_dict",
[](py::handle self_obj) {
local::Scope &self = py::cast<local::Scope &>(self_obj);
py::dict d;
for (auto &it : self.raw_devices()) {
py::object scoped_device =
py::cast(self.device(it.second));
py::detail::keep_alive(/*nurse=*/scoped_device.ptr(),
/*patient=*/self_obj.ptr());
d[py::cast(it.first)] = scoped_device;
}
return d;
})
.def_prop_ro("device_names", &local::Scope::device_names)
.def_prop_ro("named_devices", &local::Scope::named_devices,
py::rv_policy::reference_internal)
.def(
"device",
[](local::Scope &self, py::args args) {
Expand All @@ -608,25 +690,35 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::ScopedDevice::to_s);

py::class_<DevicesSet>(m, "_ScopeDevicesSet")
.def("__iter__",
[](DevicesSet &self) { return DevicesSet(self.scope_obj, 0); })
.def("__next__",
[](DevicesSet &self) {
auto &scope = self.scope();
if (!self.index || *self.index >= scope.raw_devices().size()) {
// Blurgh: Exception as flow control is not cheap. There is a
// very obnoxious way to make this not be exception based but
// this is a minority path.
throw py::stop_iteration();
}
return self.KeepAlive(scope.device((*self.index)++));
})
.def("__len__",
[](DevicesSet &self) { return self.scope.raw_devices().size(); })
.def(
"__getitem__",
[](DevicesSet &self, int index) { return self.scope.device(index); },
py::rv_policy::reference_internal)
.def(
"__getitem__",
[](DevicesSet &self, std::string_view name) {
return self.scope.device(name);
},
py::rv_policy::reference_internal)
.def(
"__getattr__",
[](DevicesSet &self, std::string_view name) {
return self.scope.device(name);
},
py::rv_policy::reference_internal);
[](DevicesSet &self) { return self.scope().raw_devices().size(); })
.def("__getitem__",
[](DevicesSet &self, size_t index) {
return self.KeepAlive(self.scope().device(index));
})
.def("__getitem__",
[](DevicesSet &self, std::string_view name) {
return self.KeepAlive(self.scope().device(name));
})
.def("__getattr__",
[](DevicesSet &self, std::string_view name) -> py::object {
return self.KeepAlive(self.scope().device(name));
});

;
py::class_<local::Worker>(m, "Worker", py::is_weak_referenceable())
.def_prop_ro("loop",
[](local::Worker &self) {
Expand Down Expand Up @@ -688,30 +780,51 @@ void BindLocal(py::module_ &m) {
.def("__repr__", &local::Worker::to_s);

py::class_<PyProcess>(m, "Process")
.def("__init__", [](py::args, py::kwargs) {})
.def(
"__init__",
[](py::handle self_obj, std::shared_ptr<local::Scope> scope) {
PyProcess &self = py::cast<PyProcess &>(self_obj);
self.Initialize(std::move(scope));
},
py::kw_only(), py::arg("scope"))
.def_static(
"__new__",
[refs](py::handle py_type, py::args,
std::shared_ptr<local::Scope> scope, py::kwargs) {
return custom_new<PyProcess>(py_type, std::move(scope), refs);
[refs](py::handle py_type, py::args, py::kwargs) {
return custom_new<PyProcess>(py_type, refs);
},
py::arg("type"), py::arg("args"), py::arg("scope"), py::arg("kwargs"))
py::arg("type"), py::arg("args"), py::arg("kwargs"))
.def_prop_ro("pid", &PyProcess::pid)
.def_prop_ro("scope", &PyProcess::scope)
.def_prop_ro("scope",
[](PyProcess &self) -> std::shared_ptr<local::Scope> {
self.AssertInitialized();
return self.scope();
})
.def_prop_ro("system",
[](PyProcess &self) {
self.AssertInitialized();
return self.scope()->system().shared_ptr();
})
.def("launch",
[](py::object self_obj) {
PyProcess &self = py::cast<PyProcess &>(self_obj);
self.AssertInitialized();
self.Launch();
return self_obj;
})
.def("__await__",
[](PyProcess &self) {
self.AssertInitialized();
py::object future =
py::cast(local::CompletionEvent(self.OnTermination()),
py::rv_policy::move);
return future.attr("__await__")();
})
.def("__repr__", &PyProcess::to_s);
.def("__repr__", [](PyProcess &self) {
if (!self.is_initialized()) {
return std::string("Process(UNINITIALIZED)");
}
return self.to_s();
});

py::class_<local::CompletionEvent>(m, "CompletionEvent")
.def(py::init<>())
Expand Down Expand Up @@ -800,10 +913,15 @@ void BindLocal(py::module_ &m) {
py::type<local::QueueWriter>(),
/*keep_alive=*/self, /*queue=*/self);
})
.def("reader", [](local::Queue &self) {
return custom_new_keep_alive<local::QueueReader>(
py::type<local::QueueReader>(),
/*keep_alive=*/self, /*queue=*/self);
.def("reader",
[](local::Queue &self) {
return custom_new_keep_alive<local::QueueReader>(
py::type<local::QueueReader>(),
/*keep_alive=*/self, /*queue=*/self);
})
.def_prop_ro("closed", &local::Queue::is_closed)
.def("write_nodelay", [](local::Queue &self, local::Message &message) {
self.WriteNoDelay(local::Message::Ref(message));
});
py::class_<local::QueueWriter>(m, "QueueWriter")
.def("__call__",
Expand Down Expand Up @@ -864,7 +982,9 @@ void BindLocal(py::module_ &m) {
});
return iter_ret;
});
py::class_<local::VoidFuture, local::Future>(m, "VoidFuture");
py::class_<local::VoidFuture, local::Future>(m, "VoidFuture")
.def(py::init<>())
.def("set_success", [](local::VoidFuture &self) { self.set_success(); });
py::class_<local::ProgramInvocation::Future, local::Future>(
m, "ProgramInvocationFuture")
.def("result", [](local::ProgramInvocation::Future &self) {
Expand Down
1 change: 1 addition & 0 deletions libshortfin/python/lib_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <nanobind/operators.h>
#include <nanobind/stl/filesystem.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/string_view.h>
Expand Down
Loading

0 comments on commit 89d5d52

Please sign in to comment.