Initial implementation of async execution.
stellaraccident committed Jan 16, 2023
1 parent 8921e32 commit cd28a70
Showing 3 changed files with 94 additions and 7 deletions.
1 change: 1 addition & 0 deletions pjrt-plugin/iree/integrations/pjrt/common/BUILD
@@ -26,6 +26,7 @@ iree_pjrt_cc_library(
deps = [
":compiler",
":debugging",
"@iree_core//runtime/src/iree/base:tracing",
"@iree_core//runtime/src/iree/hal",
"@iree_core//runtime/src/iree/modules/hal",
"@iree_core//runtime/src/iree/vm",
72 changes: 68 additions & 4 deletions pjrt-plugin/iree/integrations/pjrt/common/api_impl.cc
@@ -9,6 +9,7 @@
#include <iostream>
#include <optional>

#include "iree/base/tracing.h"
#include "iree/hal/api.h"

namespace iree::pjrt {
@@ -352,11 +353,13 @@ iree_status_t BufferInstance::GetXlaShape(xla::Shape** out_shape) {
void BufferInstance::BindApi(PJRT_Api* api) {
api->PJRT_Buffer_Destroy =
+[](PJRT_Buffer_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_Destroy");
delete BufferInstance::Unwrap(args->buffer);
return nullptr;
};
api->PJRT_Buffer_OnDeviceTrimmedShape =
+[](PJRT_Buffer_OnDeviceTrimmedShape_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceTrimmedShape");
auto impl = [&]() -> iree_status_t {
// TODO: This function is terrible and not exposed properly to C.
// It is slated to be deleted...
@@ -382,6 +385,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Buffer_ToHostBuffer =
+[](PJRT_Buffer_ToHostBuffer_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_ToHostBuffer");
BufferInstance* buffer = BufferInstance::Unwrap(args->src);
if (!args->dst) {
// Size query.
@@ -395,10 +399,12 @@ void BufferInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Buffer_OnDeviceSizeInBytes =
+[](PJRT_Buffer_OnDeviceSizeInBytes_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_OnDeviceSizeInBytes");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_OnDeviceSizeInBytes"));
};
api->PJRT_Buffer_Delete = +[](PJRT_Buffer_Delete_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_Delete");
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_Delete"));
};
@@ -409,6 +415,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Buffer_CopyToDevice =
+[](PJRT_Buffer_CopyToDevice_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_CopyToDevice");
return MakeError(iree_make_status(IREE_STATUS_UNIMPLEMENTED,
"PJRT_Buffer_CopyToDevice"));
};
@@ -423,6 +430,7 @@ void BufferInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Buffer_ReadyEvent =
+[](PJRT_Buffer_ReadyEvent_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Buffer_ReadyEvent");
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Buffer_ReadyEvent"));
};
@@ -440,6 +448,12 @@ iree_status_t BufferInstance::GetHostSizeInBytes(iree_host_size_t* host_size) {

iree_status_t BufferInstance::CopyToHost(void* dst, iree_host_size_t dst_size,
EventInstance** done_event) {
// TODO: Make unconditional to have a ready fence.
if (ready_fence_) {
IREE_RETURN_IF_ERROR(
iree_hal_fence_wait(ready_fence_.get(), iree_infinite_timeout()));
}

// TODO: Do an async transfer on a transfer queue like a grown up.
iree_hal_device_t* hal_device;
IREE_RETURN_IF_ERROR(device_.GetHalDevice(&hal_device));
@@ -504,10 +518,13 @@ void DeviceInstance::BindApi(PJRT_Api* api) {

iree_status_t DeviceInstance::OpenDevice() {
if (device_) return iree_ok_status();
return iree_hal_driver_create_device_by_id(
IREE_RETURN_IF_ERROR(iree_hal_driver_create_device_by_id(
driver_, /*device_id=*/info_->device_id,
/*param_count=*/0, /*params=*/nullptr, client_.host_allocator(),
&device_);
&device_));
IREE_RETURN_IF_ERROR(
iree_hal_semaphore_create(device_.get(), 0ull, &main_timeline_));
return iree_ok_status();
}

iree_status_t DeviceInstance::HostBufferToDevice(
@@ -583,7 +600,7 @@ iree_status_t DeviceInstance::HostBufferToDevice(
*out_done_with_host_buffer_event = new EventInstance();

// Construct and return a BufferInstance.
*out_buffer = new BufferInstance(*this, buffer_view);
*out_buffer = new BufferInstance(*this, buffer_view, /*ready_fence=*/nullptr);

return iree_ok_status();
}
@@ -621,6 +638,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
// PJRT_Client_Create is polymorphic
api->PJRT_Client_Destroy =
+[](PJRT_Client_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Client_Destroy");
delete ClientInstance::Unwrap(args->client);
return nullptr;
};
@@ -674,6 +692,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Client_Compile =
+[](PJRT_Client_Compile_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Client_Compile");
// TODO: It is not great that we only get a client here vs a list of
// devices to consider (or something). The issue is that systems often
// have unrelated devices that will not actually be scheduled and those
@@ -698,6 +717,7 @@ void ClientInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Client_BufferFromHostBuffer =
+[](PJRT_Client_BufferFromHostBuffer_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Client_BufferFromHostBuffer");
auto status =
DeviceInstance::Unwrap(args->device)
->HostBufferToDevice(
@@ -799,6 +819,9 @@ PJRT_Error* ClientInstance::Compile(PJRT_Program* program,
if (!job->SetFlag("--iree-input-type=mhlo")) {
return MakeCompilerError();
}
if (!job->SetFlag("--iree-execution-model=async-external")) {
return MakeCompilerError();
}
if (!SetDefaultCompilerFlags(job.get())) {
return MakeCompilerError();
}
@@ -854,27 +877,39 @@ iree_status_t ClientInstance::PopulateVMModules(
return iree_ok_status();
}

std::tuple<uint64_t, uint64_t> ClientInstance::AdvanceTimeline() {
uint64_t current = execution_timeline_;
uint64_t next = current + 1;
execution_timeline_ = next;
return std::make_tuple(current, next);
}

//===----------------------------------------------------------------------===//
// EventInstance
//===----------------------------------------------------------------------===//

void EventInstance::BindApi(PJRT_Api* api) {
api->PJRT_Event_Destroy = +[](PJRT_Event_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Event_Destroy");
delete EventInstance::Unwrap(args->event);
return nullptr;
};
api->PJRT_Event_IsReady = +[](PJRT_Event_IsReady_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Event_IsReady");
args->is_ready = EventInstance::Unwrap(args->event)->is_ready();
return nullptr;
};
api->PJRT_Event_Error = +[](PJRT_Event_Error_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Event_Error");
return (PJRT_Error*)EventInstance::Unwrap(args->event)->error();
};
api->PJRT_Event_Await = +[](PJRT_Event_Await_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Event_Await");
return MakeError(
iree_make_status(IREE_STATUS_UNIMPLEMENTED, "PJRT_Event_Await"));
};
api->PJRT_Event_OnReady = +[](PJRT_Event_OnReady_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Event_OnReady");
return MakeError(EventInstance::Unwrap(args->event)
->OnReady(args->callback, args->user_arg));
};
@@ -894,6 +929,7 @@ iree_status_t EventInstance::OnReady(PJRT_Event_OnReadyCallback callback,
void ExecutableInstance::BindApi(PJRT_Api* api) {
api->PJRT_Executable_Destroy =
+[](PJRT_Executable_Destroy_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Executable_Destroy");
delete ExecutableInstance::Unwrap(args->executable);
return nullptr;
};
@@ -930,11 +966,13 @@ void ExecutableInstance::BindApi(PJRT_Api* api) {
};
api->PJRT_Executable_Execute =
+[](PJRT_Executable_Execute_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Executable_Execute");
return MakeError(
ExecutableInstance::Unwrap(args->executable)->BatchExecute(args));
};
api->PJRT_Executable_NumOutputs =
+[](PJRT_Executable_NumOutputs_Args* args) -> PJRT_Error* {
IREE_TRACE_SCOPE0("PJRT_Executable_NumOutputs");
auto* exec = ExecutableInstance::Unwrap(args->executable);
iree_host_size_t arg_count;
iree_host_size_t result_count;
@@ -967,6 +1005,7 @@ void ExecutableInstance::BindApi(PJRT_Api* api) {
}

iree_status_t ExecutableInstance::LoadAll() {
IREE_TRACE_SCOPE();
if (!loaded_executables_.empty()) return iree_ok_status();

std::vector<LoadedExecutable> new_list;
@@ -1052,13 +1091,17 @@ iree_status_t ExecutableInstance::BatchExecute(
// Make sure loaded.
IREE_RETURN_IF_ERROR(LoadAll());

// Timeline setup.
auto [wait_timepoint, signal_timepoint] = client_.AdvanceTimeline();

// Initialize invocations.
auto allocator = client_.host_allocator();
auto& loaded_execs = loaded_executables_;
struct Invocation {
LoadedExecutable* dev_exe;
iree::vm::ref<iree_vm_list_t> inputs;
iree::vm::ref<iree_vm_list_t> outputs;
iree::vm::ref<iree_hal_fence_t> signal_fence;
};
std::vector<Invocation> invs;
invs.resize(args->num_devices);
@@ -1079,6 +1122,26 @@
IREE_RETURN_IF_ERROR(
iree_vm_list_push_ref_move(inv.inputs.get(), &bv_ref));
}

// Add (wait, signal) fences as required by the async-external execution
// model.
// This is currently doing very simplistic in-execution-order scheduling.
// Alternatives would be to compute the wait timepoint based on some
// combination of max timepoints of input buffers, but this would put us
// in the position of defining scheduling semantics that need to be
// carefully specified (i.e. in multi device or memory constrained cases).
iree::vm::ref<iree_hal_fence_t> wait_fence;
iree_hal_semaphore_t* semaphore =
inv.dev_exe->device_instance->main_timeline();
IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
semaphore, wait_timepoint, client_.host_allocator(), &wait_fence));
IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(semaphore, signal_timepoint,
client_.host_allocator(),
&inv.signal_fence));

// TODO: Is this the right way to move into the list?
iree_vm_list_push_ref_move(inv.inputs.get(), wait_fence);
iree_vm_list_push_ref_retain(inv.inputs.get(), inv.signal_fence);
}

// Issue invocations.
@@ -1108,7 +1171,8 @@ iree_status_t ExecutableInstance::BatchExecute(
IREE_ASSERT_ARGUMENT(ret_buffer_view);
iree_hal_buffer_view_retain(ret_buffer_view);
args->output_lists[dev_index][i] =
*(new BufferInstance(*inv.dev_exe->device_instance, ret_buffer_view));
*(new BufferInstance(*inv.dev_exe->device_instance, ret_buffer_view,
/*ready_fence=*/inv.signal_fence));
}

if (args->device_complete_events) {
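
For orientation, here is a minimal, hypothetical sketch (not part of this commit) of the fence pattern that ExecutableInstance::BatchExecute sets up above: the wait and signal fences are single timepoints on the device's shared main timeline semaphore, so an invocation starts once the timeline reaches the wait timepoint and publishes its results by advancing the timeline to the signal timepoint, which downstream readers such as BufferInstance::CopyToHost block on before touching result buffers. The helper name and parameters below are illustrative.

#include "iree/base/api.h"
#include "iree/hal/api.h"

// Illustrative helper (not from the commit): derive the (wait, signal) fence
// pair for one invocation from the shared timeline semaphore.
static iree_status_t MakeInvocationFences(
    iree_hal_semaphore_t* main_timeline, uint64_t wait_timepoint,
    uint64_t signal_timepoint, iree_allocator_t host_allocator,
    iree_hal_fence_t** out_wait_fence, iree_hal_fence_t** out_signal_fence) {
  // The invocation may begin once |main_timeline| reaches |wait_timepoint|.
  IREE_RETURN_IF_ERROR(iree_hal_fence_create_at(
      main_timeline, wait_timepoint, host_allocator, out_wait_fence));
  // The invocation advances |main_timeline| to |signal_timepoint| when it
  // completes; waiting on this fence (as CopyToHost does) drains the result.
  return iree_hal_fence_create_at(main_timeline, signal_timepoint,
                                  host_allocator, out_signal_fence);
}
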
28 changes: 25 additions & 3 deletions pjrt-plugin/iree/integrations/pjrt/common/api_impl.h
@@ -68,8 +68,11 @@ inline PJRT_Error* MakeError(iree_status_t status) {

class BufferInstance {
public:
BufferInstance(DeviceInstance& device, iree_hal_buffer_view_t* buffer_view)
: device_(device), buffer_view_(buffer_view) {}
BufferInstance(DeviceInstance& device, iree_hal_buffer_view_t* buffer_view,
iree::vm::ref<iree_hal_fence_t> ready_fence)
: device_(device),
buffer_view_(buffer_view),
ready_fence_(std::move(ready_fence)) {}
~BufferInstance();
operator PJRT_Buffer*() { return reinterpret_cast<PJRT_Buffer*>(this); }
static BufferInstance* Unwrap(PJRT_Buffer* buffer) {
@@ -94,7 +97,8 @@

private:
DeviceInstance& device_;
iree::vm::ref<iree_hal_buffer_view_t> buffer_view_; // Owned.
iree::vm::ref<iree_hal_buffer_view_t> buffer_view_;
iree::vm::ref<iree_hal_fence_t> ready_fence_;
// Various things require XLA's idea of shapes, layouts, etc.
// We keep one around for such cases.
std::optional<xla::Shape> cached_shape_;
@@ -137,14 +141,19 @@ class DeviceInstance {
EventInstance** out_done_with_host_buffer_event,
BufferInstance** out_buffer);

// TODO(laurenzo): Eagerly set up device to allow simple access.
iree_status_t GetHalDevice(iree_hal_device_t** out_device);

// Only valid once device opened.
iree_hal_semaphore_t* main_timeline() { return main_timeline_.get(); }

private:
iree_status_t OpenDevice();
int client_id_;
ClientInstance& client_;
iree_hal_driver_t* driver_; // Owned by client.
iree::vm::ref<iree_hal_device_t> device_;
iree::vm::ref<iree_hal_semaphore_t> main_timeline_;
iree_hal_device_info_t* info_;
};

@@ -286,6 +295,9 @@ struct ClientInstance {
// Returns false on failure (and sets error information on the compiler_job).
virtual bool SetDefaultCompilerFlags(CompilerJob* compiler_job) = 0;

// Advances the timeline, returning (current, next) time point values.
std::tuple<uint64_t, uint64_t> AdvanceTimeline();

protected:
iree_allocator_t host_allocator_;
std::string cached_platform_name_;
@@ -307,6 +319,16 @@

// VM.
iree::vm::ref<iree_vm_instance_t> vm_instance_;

// Synchronization.
// We keep one global execution timeline across all devices. The management
// of this is currently somewhat primitive: we increment it by one for each
// invocation. Batch invocations (i.e. across multiple devices), only
// increment by one. In the future, additional parallelism could be plumbed
// up to the framework to allow different kinds of timeline management.
// Waiting on the current value of |execution_timeline_| will drain all
// scheduled work to date.
uint64_t execution_timeline_ = 0ull;
};

//===----------------------------------------------------------------------===//
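
To make the timeline comment above concrete: every invocation signals the device's main timeline semaphore at a monotonically increasing value handed out by ClientInstance::AdvanceTimeline, so blocking until that semaphore reaches the current execution_timeline_ value drains all work scheduled to date. A hedged sketch, assuming IREE's iree_hal_semaphore_wait API; the helper name is illustrative.

#include "iree/base/api.h"
#include "iree/hal/api.h"

// Illustrative only: block until every invocation scheduled so far has
// signaled, by waiting for the shared timeline to reach the current value.
static iree_status_t DrainScheduledWork(iree_hal_semaphore_t* main_timeline,
                                        uint64_t current_timeline_value) {
  return iree_hal_semaphore_wait(main_timeline, current_timeline_value,
                                 iree_infinite_timeout());
}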
