From cd28a70bf30f4e68e422905eeaae1cef58a7dcb8 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 15 Jan 2023 18:58:37 -0800 Subject: [PATCH] Initial implementation of async execution. --- .../iree/integrations/pjrt/common/BUILD | 1 + .../iree/integrations/pjrt/common/api_impl.cc | 72 +++++++++++++++++-- .../iree/integrations/pjrt/common/api_impl.h | 28 +++++++- 3 files changed, 94 insertions(+), 7 deletions(-) diff --git a/pjrt-plugin/iree/integrations/pjrt/common/BUILD b/pjrt-plugin/iree/integrations/pjrt/common/BUILD index 340c8133..21c8ef2d 100644 --- a/pjrt-plugin/iree/integrations/pjrt/common/BUILD +++ b/pjrt-plugin/iree/integrations/pjrt/common/BUILD @@ -26,6 +26,7 @@ iree_pjrt_cc_library( deps = [ ":compiler", ":debugging", + "@iree_core//runtime/src/iree/base:tracing", "@iree_core//runtime/src/iree/hal", "@iree_core//runtime/src/iree/modules/hal", "@iree_core//runtime/src/iree/vm", diff --git a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc index be45e9af..1a5fd7bf 100644 --- a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc +++ b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc @@ -9,6 +9,7 @@ #include #include +#include "iree/base/tracing.h" #include "iree/hal/api.h" namespace iree::pjrt { @@ -352,11 +353,13 @@ iree_status_t BufferInstance::GetXlaShape(xla::Shape** out_shape) { void BufferInstance::BindApi(PJRT_Api* api) { api->PJRT_Buffer_Destroy = +[](PJRT_Buffer_Destroy_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_Destroy"); delete BufferInstance::Unwrap(args->buffer); return nullptr; }; api->PJRT_Buffer_OnDeviceTrimmedShape = +[](PJRT_Buffer_OnDeviceTrimmedShape_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceTrimmedShape"); auto impl = [&]() -> iree_status_t { // TODO: This function is terrible and not exposed properly to C. // It is slated to be deleted... @@ -382,6 +385,7 @@ void BufferInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Buffer_ToHostBuffer = +[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_ToHostBuffer"); BufferInstance* buffer = BufferInstance::Unwrap(args->src); if (!args->dst) { // Size query. @@ -395,10 +399,12 @@ void BufferInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Buffer_OnDeviceSizeInBytes = +[](PJRT_Buffer_OnDeviceSizeInBytes_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceSizeInBytes"); return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_OnDeviceSizeInBytes")); }; api->PJRT_Buffer_Delete = +[](PJRT_Buffer_Delete_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_Delete"); return MakeError( iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Delete")); }; @@ -409,6 +415,7 @@ void BufferInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Buffer_CopyToDevice = +[](PJRT_Buffer_CopyToDevice_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_CopyToDevice"); return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_CopyToDevice")); }; @@ -423,6 +430,7 @@ void BufferInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Buffer_ReadyEvent = +[](PJRT_Buffer_ReadyEvent_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Buffer_ReadyEvent"); return MakeError( iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_ReadyEvent")); }; @@ -440,6 +448,12 @@ iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) { iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size, EventInstance** done_event) { + // TODO: Make unconditional to have a ready fence. + if (ready_fence_) { + IREE_RETURN_IF_ERROR( + iree_hal_fence_wait(ready_fence_.get(), iree_infinite_timeout())); + } + // TODO: Do an async transfer on a transfer queue like a grown up. iree_hal_device_t* hal_device; IREE_RETURN_IF_ERROR(device_.GetHalDevice(&hal_device)); @@ -504,10 +518,13 @@ void DeviceInstance::BindApi(PJRT_Api* api) { iree_status_t DeviceInstance::OpenDevice() { if (device_) return iree_ok_status(); - return iree_hal_driver_create_device_by_id( + IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id( driver_, /*device_id=*/info_->device_id, /*param_count=*/0, /*params=*/nullptr, client_.host_allocator(), - &device_); + &device_)); + IREE_RETURN_IF_ERROR( + iree_hal_semaphore_create(device_.get(), 0ull, &main_timeline_)); + return iree_ok_status(); } iree_status_t DeviceInstance::HostBufferToDevice( @@ -583,7 +600,7 @@ iree_status_t DeviceInstance::HostBufferToDevice( *out_done_with_host_buffer_event = new EventInstance(); // Construct and return a BufferInstance. - *out_buffer = new BufferInstance(*this, buffer_view); + *out_buffer = new BufferInstance(*this, buffer_view, /*ready_fence=*/nullptr); return iree_ok_status(); } @@ -621,6 +638,7 @@ void ClientInstance::BindApi(PJRT_Api* api) { // PJRT_Client_Create is polymorphic api->PJRT_Client_Destroy = +[](PJRT_Client_Destroy_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Client_Destroy"); delete ClientInstance::Unwrap(args->client); return nullptr; }; @@ -674,6 +692,7 @@ void ClientInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Client_Compile = +[](PJRT_Client_Compile_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Client_Compile"); // TODO: It is not great that we only get a client here vs a list of // devices to consider (or something). The issue is that systems often // have unrelated devices that will not actually be scheduled and those @@ -698,6 +717,7 @@ void ClientInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Client_BufferFromHostBuffer = +[](PJRT_Client_BufferFromHostBuffer_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Client_BufferFromHostBuffer"); auto status = DeviceInstance::Unwrap(args->device) ->HostBufferToDevice( @@ -799,6 +819,9 @@ PJRT_Error* ClientInstance::Compile(PJRT_Program* program, if (!job->SetFlag("--iree-input-type=mhlo")) { return MakeCompilerError(); } + if (!job->SetFlag("--iree-execution-model=async-external")) { + return MakeCompilerError(); + } if (!SetDefaultCompilerFlags(job.get())) { return MakeCompilerError(); } @@ -854,27 +877,39 @@ iree_status_t ClientInstance::PopulateVMModules( return iree_ok_status(); } +std::tuple ClientInstance::AdvanceTimeline() { + uint64_t current = execution_timeline_; + uint64_t next = current + 1; + execution_timeline_ = next; + return std::make_tuple(current, next); +} + //===----------------------------------------------------------------------===// // EventInstance //===----------------------------------------------------------------------===// void EventInstance::BindApi(PJRT_Api* api) { api->PJRT_Event_Destroy = +[](PJRT_Event_Destroy_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Event_Destroy"); delete EventInstance::Unwrap(args->event); return nullptr; }; api->PJRT_Event_IsReady = +[](PJRT_Event_IsReady_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Event_IsReady"); args->is_ready = EventInstance::Unwrap(args->event)->is_ready(); return nullptr; }; api->PJRT_Event_Error = +[](PJRT_Event_Error_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Event_Error"); return (PJRT_Error*)EventInstance::Unwrap(args->event)->error(); }; api->PJRT_Event_Await = +[](PJRT_Event_Await_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Event_Await"); return MakeError( iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Event_Await")); }; api->PJRT_Event_OnReady = +[](PJRT_Event_OnReady_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Event_OnReady"); return MakeError(EventInstance::Unwrap(args->event) ->OnReady(args->callback, args->user_arg)); }; @@ -894,6 +929,7 @@ iree_status_t EventInstance::OnReady(PJRT_Event_OnReadyCallback callback, void ExecutableInstance::BindApi(PJRT_Api* api) { api->PJRT_Executable_Destroy = +[](PJRT_Executable_Destroy_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Executable_Destroy"); delete ExecutableInstance::Unwrap(args->executable); return nullptr; }; @@ -930,11 +966,13 @@ void ExecutableInstance::BindApi(PJRT_Api* api) { }; api->PJRT_Executable_Execute = +[](PJRT_Executable_Execute_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Executable_Execute"); return MakeError( ExecutableInstance::Unwrap(args->executable)->BatchExecute(args)); }; api->PJRT_Executable_NumOutputs = +[](PJRT_Executable_NumOutputs_Args* args) -> PJRT_Error* { + IREE_TRACE_SCOPE0("PJRT_Executable_NumOutputs"); auto* exec = ExecutableInstance::Unwrap(args->executable); iree_host_size_t arg_count; iree_host_size_t result_count; @@ -967,6 +1005,7 @@ void ExecutableInstance::BindApi(PJRT_Api* api) { } iree_status_t ExecutableInstance::LoadAll() { + IREE_TRACE_SCOPE(); if (!loaded_executables_.empty()) return iree_ok_status(); std::vector new_list; @@ -1052,6 +1091,9 @@ iree_status_t ExecutableInstance::BatchExecute( // Make sure loaded. IREE_RETURN_IF_ERROR(LoadAll()); + // Timeline setup. + auto [wait_timepoint, signal_timepoint] = client_.AdvanceTimeline(); + // Initialize invocations. auto allocator = client_.host_allocator(); auto& loaded_execs = loaded_executables_; @@ -1059,6 +1101,7 @@ iree_status_t ExecutableInstance::BatchExecute( LoadedExecutable* dev_exe; iree::vm::ref inputs; iree::vm::ref outputs; + iree::vm::ref signal_fence; }; std::vector invs; invs.resize(args->num_devices); @@ -1079,6 +1122,26 @@ iree_status_t ExecutableInstance::BatchExecute( IREE_RETURN_IF_ERROR( iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref)); } + + // Add (wait, signal) fences as required by the async-external execution + // model. + // This is currently doing very simplistic in-execution-order scheduling. + // Alternatives would be to compute the wait timepoint based on some + // combination of max timepoints of input buffers, but this would put us + // in the position of defining scheduling semantics that need to be + // carefully specified (i.e. in multi device or memory constrained cases). + iree::vm::ref wait_fence; + iree_hal_semaphore_t* semaphore = + inv.dev_exe->device_instance->main_timeline(); + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at( + semaphore, wait_timepoint, client_.host_allocator(), &wait_fence)); + IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(semaphore, signal_timepoint, + client_.host_allocator(), + &inv.signal_fence)); + + // TODO: Is this the right way to move into the list? + iree_vm_list_push_ref_move(inv.inputs.get(), wait_fence); + iree_vm_list_push_ref_retain(inv.inputs.get(), inv.signal_fence); } // Issue invocations. @@ -1108,7 +1171,8 @@ iree_status_t ExecutableInstance::BatchExecute( IREE_ASSERT_ARGUMENT(ret_buffer_view); iree_hal_buffer_view_retain(ret_buffer_view); args->output_lists[dev_index][i] = - *(new BufferInstance(*inv.dev_exe->device_instance, ret_buffer_view)); + *(new BufferInstance(*inv.dev_exe->device_instance, ret_buffer_view, + /*ready_fence=*/inv.signal_fence)); } if (args->device_complete_events) { diff --git a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h index 102f1280..168275c7 100644 --- a/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h +++ b/pjrt-plugin/iree/integrations/pjrt/common/api_impl.h @@ -68,8 +68,11 @@ inline PJRT_Error* MakeError(iree_status_t status) { class BufferInstance { public: - BufferInstance(DeviceInstance& device, iree_hal_buffer_view_t* buffer_view) - : device_(device), buffer_view_(buffer_view) {} + BufferInstance(DeviceInstance& device, iree_hal_buffer_view_t* buffer_view, + iree::vm::ref ready_fence) + : device_(device), + buffer_view_(buffer_view), + ready_fence_(std::move(ready_fence)) {} ~BufferInstance(); operator PJRT_Buffer*() { return reinterpret_cast(this); } static BufferInstance* Unwrap(PJRT_Buffer* buffer) { @@ -94,7 +97,8 @@ class BufferInstance { private: DeviceInstance& device_; - iree::vm::ref buffer_view_; // Owned. + iree::vm::ref buffer_view_; + iree::vm::ref ready_fence_; // Various things require XLA's idea of shapes, layouts, etc. // We keep one around for such cases. std::optional cached_shape_; @@ -137,14 +141,19 @@ class DeviceInstance { EventInstance** out_done_with_host_buffer_event, BufferInstance** out_buffer); + // TODO(laurenzo): Eagerly set up device to allow simple access. iree_status_t GetHalDevice(iree_hal_device_t** out_device); + // Only valid once device opened. + iree_hal_semaphore_t* main_timeline() { return main_timeline_.get(); } + private: iree_status_t OpenDevice(); int client_id_; ClientInstance& client_; iree_hal_driver_t* driver_; // Owned by client. iree::vm::ref device_; + iree::vm::ref main_timeline_; iree_hal_device_info_t* info_; }; @@ -286,6 +295,9 @@ struct ClientInstance { // Returns false on failure (and sets error information on the compiler_job). virtual bool SetDefaultCompilerFlags(CompilerJob* compiler_job) = 0; + // Advances the timeline, returning (current, next) time point values. + std::tuple AdvanceTimeline(); + protected: iree_allocator_t host_allocator_; std::string cached_platform_name_; @@ -307,6 +319,16 @@ struct ClientInstance { // VM. iree::vm::ref vm_instance_; + + // Synchronization. + // We keep one global execution timeline across all devices. The management + // of this is currently somewhat primitive: we increment it by one for each + // invocation. Batch invocations (i.e. across multiple devices), only + // increment by one. In the future, additional parallelism could be plumbed + // up to the framework to allow different kinds of timeline management. + // Waiting on the current value of |execution_timeline_| will drain all + // scheduled work to date. + uint64_t execution_timeline_ = 0ull; }; //===----------------------------------------------------------------------===//