diff --git a/get_deps.sh b/get_deps.sh index a5d0f07e1..11ac25ddd 100755 --- a/get_deps.sh +++ b/get_deps.sh @@ -134,7 +134,8 @@ if [[ $WITH_TF != 0 ]]; then mkdir $LIBTENSORFLOW.x tar xf $LIBTF_ARCHIVE --no-same-owner -C $LIBTENSORFLOW.x mv $LIBTENSORFLOW.x $LIBTENSORFLOW - + chmod u+w $LIBTENSORFLOW/include/tensorflow/c/eager/c_api_experimental.h + cp ../../opt/patches/c_api_experimental.h $LIBTENSORFLOW/include/tensorflow/c/eager echo "Done." else echo "TensorFlow is in place." diff --git a/opt/patches/c_api_experimental.h b/opt/patches/c_api_experimental.h new file mode 100755 index 000000000..f971f7d36 --- /dev/null +++ b/opt/patches/c_api_experimental.h @@ -0,0 +1,568 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This +// is for performance optimization by reusing an exiting unused op rather than +// creating a new op every time. If `raw_device_name` is `NULL` or empty, it +// does not set the device name. If it's not `NULL`, then it attempts to parse +// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster +// than separately calling it because if the existing op has the same +// `raw_device_name`, it skips parsing and just leave as it is. +TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset, + const char* op_or_function_name, + const char* raw_device_name, + TF_Status* status); + +// Enables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx); + +// Disables only graph collection in RunMetadata on the functions executed from +// this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx); + +// TODO(fishx): Move these monitoring APIs into a separate file. +// ----------------------------------------------------------------------------- +// Monitoring Counter APIs. +// These APIs de-templated monitoring Counter for swig. + +typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell; + +// Atomically increments the value of the cell. The value must be non-negative. +TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy( + TFE_MonitoringCounterCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue( + TFE_MonitoringCounterCell* cell); + +// APIs for Counter without label. +typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0; +// Returns a new Counter metric object. The caller should manage lifetime of +// the object. Using duplicate metric name will crash the program with fatal +// error. 
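+// A minimal usage sketch (the metric name and description below are
+// illustrative only):
+//   TF_Status* s = TF_NewStatus();
+//   TFE_MonitoringCounter0* c =
+//       TFE_MonitoringNewCounter0("/redisai/model_runs", s, "Model run count");
+//   TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringGetCellCounter0(c), 1);
+//   TFE_MonitoringDeleteCounter0(c);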
+TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0( + const char* name, TF_Status* status, const char* description); +// Deletes the Counter object. +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0( + TFE_MonitoringCounter0* counter); +// Retrieves the cell from the Counter object. The Counter object will manage +// lifetime of the cell. +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( + TFE_MonitoringCounter0* counter); + +// APIs for Counter with 1 label. +typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1; +TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1( + const char* name, TF_Status* status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1( + TFE_MonitoringCounter1* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( + TFE_MonitoringCounter1* counter, const char* label1); + +// APIs for Counter with 2 labels. +typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2; +TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2( + const char* name, TF_Status* status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2( + TFE_MonitoringCounter2* counter); +TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( + TFE_MonitoringCounter2* counter, const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Gauge APIs. +// These APIs de-templated monitoring Gauge for swig. + +typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell; + +// Atomically set the value of the cell. +TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet( + TFE_MonitoringIntGaugeCell* cell, int64_t value); + +// Retrieves the current value of the cell. +TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue( + TFE_MonitoringIntGaugeCell* cell); + +// APIs for Int Gauge without label. +typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0( + TFE_MonitoringIntGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge); + +// APIs for Int Gauge with 1 label. +typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1( + TFE_MonitoringIntGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge, + const char* label1); + +// APIs for Int Gauge with 2 label. 
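+// Label values are bound when fetching a cell, e.g. (sketch; metric and label
+// names are illustrative):
+//   TFE_MonitoringIntGauge2* g = TFE_MonitoringNewIntGauge2(
+//       "/redisai/queue_depth", status, "Pending ops per device", "device", "backend");
+//   TFE_MonitoringIntGaugeCellSet(
+//       TFE_MonitoringGetCellIntGauge2(g, "GPU:0", "TF"), 3);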
+typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2( + TFE_MonitoringIntGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell* +TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge, + const char* label1, const char* label2); + +typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet( + TFE_MonitoringStringGaugeCell* cell, const char* value); +// Retrieves the string value and saves it in buffer. +TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue( + TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf); + +// APIs for String Gauge without label. +typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0( + TFE_MonitoringStringGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge); + +// APIs for String Gauge with 1 label. +typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1( + TFE_MonitoringStringGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge, + const char* label1); + +// APIs for String Gauge with 2 label. +typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2( + TFE_MonitoringStringGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell* +TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge, + const char* label1, const char* label2); + +typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell; +TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet( + TFE_MonitoringBoolGaugeCell* cell, bool value); +TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue( + TFE_MonitoringBoolGaugeCell* cell); + +// APIs for Bool Gauge without label. +typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0( + const char* name, TF_Status* out_status, const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0( + TFE_MonitoringBoolGauge0* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge); + +// APIs for Bool Gauge with 1 label. 
+typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1( + const char* name, TF_Status* out_status, const char* description, + const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1( + TFE_MonitoringBoolGauge1* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge, + const char* label1); + +// APIs for Bool Gauge with 2 label. +typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2; +TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2( + const char* name, TF_Status* out_status, const char* description, + const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2( + TFE_MonitoringBoolGauge2* gauge); +TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell* +TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge, + const char* label1, const char* label2); + +// ----------------------------------------------------------------------------- +// Monitoring Sampler APIs. +// These APIs de-templated monitoring Sampler for swig. + +typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell; + +// Atomically add the value of the cell. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd( + TFE_MonitoringSamplerCell* cell, double value); + +// Retrieves the current value of the cell. The return value is a HistogramProto +// saved in buffer. +TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue( + TFE_MonitoringSamplerCell* cell, TF_Buffer* buf); + +// APIs for sampler buckets +typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets; +TF_CAPI_EXPORT extern TFE_MonitoringBuckets* +TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor, + int bucket_count); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets( + TFE_MonitoringBuckets* buckets); + +// APIs for Sampler without label. +typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0; +TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0( + TFE_MonitoringSampler0* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( + TFE_MonitoringSampler0* sampler); + +// APIs for Sampler with 1 label. +typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1; +TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1( + TFE_MonitoringSampler1* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( + TFE_MonitoringSampler1* sampler, const char* label1); + +// APIs for Sampler with 2 label. 
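+// Bucket/sampler lifecycle sketch (metric and label names are illustrative):
+//   TFE_MonitoringBuckets* b = TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 16);
+//   TFE_MonitoringSampler2* s = TFE_MonitoringNewSampler2(
+//       "/redisai/run_latency_ms", b, status, "Run latency", "backend", "device");
+//   TFE_MonitoringSamplerCellAdd(
+//       TFE_MonitoringGetCellSampler2(s, "TF", "GPU:0"), 12.5);
+//   /* TFE_MonitoringSamplerCellValue then serializes a HistogramProto into a TF_Buffer */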
+typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2; +TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( + const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status, + const char* description, const char* label1, const char* label2); +TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2( + TFE_MonitoringSampler2* sampler); +TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( + TFE_MonitoringSampler2* sampler, const char* label1, const char* label2); + +// Sets whether to copy the remote inputs of a function lazily. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( + TFE_ContextOptions*, bool lazy_copy); + +// Sets whether to use TFRT +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, + bool use_tfrt); + +// Returns the context_id from the EagerContext which is used by the +// EagerService to maintain consistency between client and worker. The +// context_id is initialized with a dummy value and is later set when the worker +// is initialized (either locally or remotely). The context_id can change during +// the process lifetime although this should cause the worker to be +// reinitialized (e.g. cleared caches) as well. +TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx); + +// ----------------------------------------------------------------------------- +// Cancellation APIs. + +typedef struct TFE_CancellationManager TFE_CancellationManager; +TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager(); +TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel( + TFE_CancellationManager*); +TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager( + TFE_CancellationManager*); + +// Associates the given `cancellation_manager` with `op`, so that invoking +// `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the +// execution of `op`. +typedef struct TFE_CancellationManager TFE_CancellationManager; +TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager( + TFE_Op* op, TFE_CancellationManager* cancellation_manager, + TF_Status* status); + +// ----------------------------------------------------------------------------- +// Eager Executor APIs. +typedef struct TFE_Executor TFE_Executor; + +// Creates a new eager Executor. Nodes in one executor are guaranteed to be +// executed in sequence. Assigning nodes to different executors allows executing +// nodes in parallel. +TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async); + +// Deletes the eager Executor without waiting for enqueued nodes. Please call +// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to +// make sure all nodes are finished. +TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*); + +// Returns true if the executor is in async mode. +TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*); + +// Causes the calling thread to block till all ops dispatched in this executor +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. 
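+//
+// Typical per-thread usage (sketch; assumes a live TFE_Context* ctx and
+// TF_Status* status):
+//   TFE_Executor* ex = TFE_NewExecutor(/*is_async=*/true);
+//   TFE_ContextSetExecutorForThread(ctx, ex);
+//   /* ... enqueue eager ops from this thread ... */
+//   TFE_ExecutorWaitForAllPendingNodes(ex, status);
+//   TFE_DeleteExecutor(ex);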
+TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( + TFE_Executor*, TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*); + +// Sets a custom Executor for current thread. All nodes created by this thread +// will be added to this Executor. It will override current executor. +TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*, + TFE_Executor*); + +// Returns the Executor for current thread. +TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( + TFE_Context*); + +// ----------------------------------------------------------------------------- +// Dynamic cluster API. + +// Update an existing context with a new set of servers defined in a ServerDef +// proto. Servers can be added to and removed from the list of remote workers +// in the context. New set of servers identified by the ServerDef must be up +// when the context is updated. +// +// This API is for experimental usage and may be subject to change. +TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx, + int keep_alive_secs, + const void* proto, + size_t proto_len, + TF_Status* status); + +// Checks whether a remote worker is alive or not. This will return true even if +// the context doesn't exist on the remote worker. +TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx, + const char* worker_name, + TF_Status* status); + +// Sync pending nodes in local executors (including the context default executor +// and thread executors) and streaming requests to remote executors, and get the +// combined status. +TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, + TF_Status* status); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. The pointer +// returned will be on the device in which the TFE_TensorHandle resides (so e.g. +// for a GPU tensor this will return a pointer to GPU memory). The pointer is +// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this +// TensorHandle. Only supports POD data types. +TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*, + TF_Status*); + +// This function will block till the operation that produces `h` has +// completed. This is only valid on local TFE_TensorHandles. Returns the size in +// bytes of the memory pointed to by the device pointer returned above. +TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*, + TF_Status*); + +// Creates a new TensorHandle from memory residing in device_name. Takes +// ownership of the memory, and will call deleter to release it after TF +// no longer needs it or in case of error. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( + TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims, + int num_dims, void* data, size_t len, + void (*deallocator)(void* data, size_t len, void* arg), + void* deallocator_arg, TF_Status* status); + +// Retrieves the address space (i.e. job, replia, task) of the local host and +// saves it in the buffer. 
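+// Sketch (the exact string is implementation-defined; the form shown is only
+// indicative):
+//   TF_Buffer* buf = TF_NewBuffer();
+//   TFE_HostAddressSpace(ctx, buf);
+//   /* buf->data now holds something like "/job:localhost/replica:0/task:0" */
+//   TF_DeleteBuffer(buf);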
+TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx, + TF_Buffer* buf); + +// APIs for generically dealing with op attributes (e.g. when forwarding them +// through custom device implementations). +// +// TODO(allenl): Currently these are black boxes, but we should have some way to +// inspect values. This would let people e.g. copy over most attributes and then +// modify some based on their values. + +// A reference to an op's name -> attribute mapping +typedef struct TFE_OpAttrs TFE_OpAttrs; + +// Fetch a reference to `op`'s attributes. The returned reference is only valid +// while `op` is alive. +TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op); +// Add attributes in `attrs` to `op`. +// +// Does not overwrite or update existing attributes, but adds new ones. +TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs); + +// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`), +// containing the op name and a map of its attributes. +TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, + TF_Buffer* buf, + TF_Status* status); + +// Set an op's attribute from a serialized AttrValue protocol buffer. +// +// Analogous to TF_SetAttrValueProto for building graph operations. +TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, + const char* attr_name, + const void* proto, + size_t proto_len, + TF_Status* status); + +// TODO(b/166642410): It would be nice, for custom devices and for other users, +// to have a non-string representation of devices (TF_Device) extracted from +// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. + +#define TFE_CUSTOM_DEVICE_VERSION 3 + +// Struct to be filled in +typedef struct TFE_CustomDevice { + int version; // = TFE_CUSTOM_DEVICE_VERSION; + // Method to copy a tensor to the custom device. + TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, + TFE_TensorHandle* tensor, + TF_Status* status, + void* device_info); // = nullptr; + + // Method to copy a tensor from the custom device to a target device. + TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, + TFE_TensorHandle* tensor, + const char* target_device_name, + TF_Status* status, + void* device_info); + + // Method to execute an operation. + // + // Arguments provide enough information to reconstruct the original `TFE_Op`, + // or construct a transformed version, by inspecting the passed `op`. + // + // TFE_OpGetDevice(op) records the original placement of the operation. It may + // be an empty string if no device was explicitly requested, but will + // otherwise be the name of this custom device. Ops are placed onto a custom + // device if any of their inputs are on that custom device, but custom devices + // are free to set a bad status in order to require explicit placement. + void (*execute)(const TFE_Op* op, int* num_outputs, + TFE_TensorHandle** outputs, TF_Status* s, void* device_info); + + // Method to delete a device. + void (*delete_device)(void* device_info); +} TFE_CustomDevice; + +// Registers a custom device for use with eager execution. +// +// Eager operations may be placed on this device, e.g. `with +// tf.device("CUSTOM"):` from Python if `device_name` for this call is +// "/job:localhost/replica:0/task:0/device:CUSTOM:0". +// +// The custom device defines copy operations for moving TensorHandles on and +// off, and an an execution operation for named operations. 
Often execution will +// simply wrap op execution on one or more physical devices. +// +// device_info is an opaque caller-defined type stored with the custom device +// which is passed to the functions referenced in the TFE_CustomDevice struct +// `device` (execute, delete_device, etc.). It can for example contain the +// names of wrapped devices. +// +// There are currently no graph semantics implemented for registered custom +// devices, so executing tf.functions which contain operations placed on custom +// devices will fail. +// +// `device_name` must not name an existing physical or custom device. It must +// follow the format: +// +// /job:/replica:/task:/device:: +// +// If the device is successfully registered, `status` is set to TF_OK. Otherwise +// the device is not usable. In case of a bad status, `device.delete_device` is +// still called on `device_info` (i.e. the caller does not retain ownership). +// +// This API is highly experimental, and in particular is expected to change when +// it starts supporting operations with attributes and when tf.function support +// is added. +TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx, + TFE_CustomDevice device, + const char* device_name, + void* device_info, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx, + const char* function_name, + TF_Buffer* buf, + TF_Status* status); + +// Allocate and return a new Tensor on the host. +// +// The caller must set the Tensor values by writing them to the pointer returned +// by TF_TensorData with length TF_TensorByteSize. +TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, + TF_DataType dtype, + const int64_t* dims, + int num_dims, + TF_Status* status); + +// Given a Tensor, wrap it with a TensorHandle +// +// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context. +// The context should be identical to that of the Tensor. +TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor( + TFE_Context* ctx, TF_Tensor* t, TF_Status* status); + +// Create a packed TensorHandle with the given list of TensorHandles. +// If `handles` are on the same device, assign the same device to the packed +// handle; if `handles` are on different deivces, assign a CompositeDevice to +// it. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle( + TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles, + TF_Status* status); + +// Configure soft device placement policy for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Configure device placement policy logging for the eager executor. Note this +// policy is applied to any subsequent op executions. +TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, + unsigned char enable, + TF_Status* status); + +// Returns the device type of the operation that produced `h`. +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType( + TFE_TensorHandle* h, TF_Status* status); + +// Returns the device ID of the operation that produced `h`. 
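+// Combined with TFE_TensorHandleDeviceType above, this is enough to rebuild a
+// device string for the producing op, e.g. (sketch):
+//   const char* type = TFE_TensorHandleDeviceType(h, status);  /* "CPU", "GPU", ... */
+//   int id = TFE_TensorHandleDeviceID(h, status);
+//   /* format as "/device:<type>:<id>" if a full name is needed */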
+TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, + TF_Status* status); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_ diff --git a/src/DAG/dag.c b/src/DAG/dag.c index 0205f9c8d..fb07ee6f8 100644 --- a/src/DAG/dag.c +++ b/src/DAG/dag.c @@ -560,7 +560,9 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc return REDISMODULE_OK; } - if (RAI_GetErrorCode(rinfo->err) == RAI_EDAGRUN) { + if (RAI_GetErrorCode(rinfo->err) == RAI_EDAGRUN || + RAI_GetErrorCode(rinfo->err) == RAI_EMODELRUN || + RAI_GetErrorCode(rinfo->err) == RAI_ESCRIPTRUN) { RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(rinfo->err)); return REDISMODULE_OK; } diff --git a/src/backends/tensorflow.c b/src/backends/tensorflow.c index 72c4a7368..0007af5d1 100644 --- a/src/backends/tensorflow.c +++ b/src/backends/tensorflow.c @@ -6,6 +6,11 @@ #include "model.h" #include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" + +#define RAI_TF_FN_NAME "rai_tf_forward" int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) { get_api_fn("RedisModule_Alloc", ((void **)&RedisModule_Alloc)); @@ -17,199 +22,324 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) { return REDISMODULE_OK; } -TF_DataType RAI_GetTFDataTypeFromDL(DLDataType dtype) { +struct TFDLManagedTensorCtx { + TFE_TensorHandle *reference; + int64_t ndim; + int64_t *shape; + int64_t *strides; + DLManagedTensor tensor; +}; +typedef struct TFDLManagedTensorCtx TFDLManagedTensorCtx; + +TFDLManagedTensorCtx *TFDLManagedTensorCtx_Create(TFE_TensorHandle *h, TF_Status *status) { + TFDLManagedTensorCtx *ctx = RedisModule_Alloc(sizeof(TFDLManagedTensorCtx)); + ctx->reference = h; + ctx->ndim = TFE_TensorHandleNumDims(h, status); + ctx->shape = RedisModule_Calloc(ctx->ndim, sizeof(int64_t)); + ctx->strides = RedisModule_Calloc(ctx->ndim, sizeof(int64_t)); + for (int i = 0; i < ctx->ndim; i++) { + ctx->shape[i] = TFE_TensorHandleDim(h, i, status); + ctx->strides[i] = 1; + } + for (int i = ctx->ndim - 2; i >= 0; i--) { + ctx->strides[i] = ctx->shape[i + 1] * ctx->strides[i + 1]; + } + return ctx; +} - if (dtype.code == kDLFloat) { - switch (dtype.bits) { - case 32: - return TF_FLOAT; - break; - case 64: - return TF_DOUBLE; - break; - default: - return 0; - } - } else if (dtype.code == kDLInt) { - switch (dtype.bits) { - case 8: - return TF_INT8; - break; - case 16: - return TF_INT16; - break; - case 32: - return TF_INT32; - break; - case 64: - return TF_INT64; - break; - default: - return 0; - } - } else if (dtype.code == kDLUInt) { - switch (dtype.bits) { - case 8: - return TF_UINT8; - break; - case 16: - return TF_UINT16; - break; - default: - return 0; - } - } - return 0; +void TFDLManagedTensorCtx_Free(TFDLManagedTensorCtx *ctx) { + RedisModule_Free(ctx->shape); + RedisModule_Free(ctx->strides); + RedisModule_Free(ctx); } -DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) { - switch (dtype) { +void DLManagedTensorDeleter(DLManagedTensor *arg) { + TFDLManagedTensorCtx *owner = (TFDLManagedTensorCtx *)(arg->manager_ctx); + TFE_DeleteTensorHandle(owner->reference); + TFDLManagedTensorCtx_Free(owner); +} + +DLDataType GetDLDataType(TF_DataType data_type, TF_Status *status, RAI_Error *error) { + DLDataType dtype; + dtype.lanes = 1; + dtype.bits = TF_DataTypeSize(data_type) * 8; + switch (data_type) { + case TF_HALF: case TF_FLOAT: - return 
(DLDataType){.code = kDLFloat, .bits = 32, .lanes = 1}; case TF_DOUBLE: - return (DLDataType){.code = kDLFloat, .bits = 64, .lanes = 1}; + dtype.code = kDLFloat; + break; case TF_INT8: - return (DLDataType){.code = kDLInt, .bits = 8, .lanes = 1}; case TF_INT16: - return (DLDataType){.code = kDLInt, .bits = 16, .lanes = 1}; case TF_INT32: - return (DLDataType){.code = kDLInt, .bits = 32, .lanes = 1}; case TF_INT64: - return (DLDataType){.code = kDLInt, .bits = 64, .lanes = 1}; + dtype.code = kDLInt; + break; + case TF_BOOL: case TF_UINT8: - return (DLDataType){.code = kDLUInt, .bits = 8, .lanes = 1}; case TF_UINT16: - return (DLDataType){.code = kDLUInt, .bits = 16, .lanes = 1}; + case TF_UINT32: + case TF_UINT64: + dtype.code = kDLUInt; + break; + case TF_BFLOAT16: + dtype.code = kDLBfloat; + break; default: - return (DLDataType){.bits = 0}; + RAI_SetError(error, RAI_EMODELIMPORT, "Unsupported data type in DLPack"); + break; } - return (DLDataType){.bits = 0}; + return dtype; } -RAI_Tensor *RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, - long long batch_size) { - RAI_Tensor *ret = RAI_TensorNew(); - - DLDevice device = (DLDevice){.device_type = kDLCPU, .device_id = 0}; +DLDevice GetDLDevice(TFE_TensorHandle *h, TF_Status *status, RAI_Error *error) { + DLDevice device; + const char *device_name = TFE_TensorHandleBackingDeviceName(h, status); - const size_t ndims = TF_NumDims(tensor); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + RedisModule_Free(errorMessage); + return device; + } - int64_t total_batch_size = TF_Dim(tensor, 0); - total_batch_size = total_batch_size > 0 ? total_batch_size : 1; + if (error->code != RAI_OK) { + return device; + } - int64_t *shape = RedisModule_Calloc(ndims, sizeof(*shape)); - int64_t *strides = RedisModule_Calloc(ndims, sizeof(*strides)); - for (int64_t i = 0; i < ndims; ++i) { - shape[i] = TF_Dim(tensor, i); - strides[i] = 1; + char device_type[5]; + int device_id = 0; + size_t device_len = strlen(device_name); + char *device_name_substr = strstr(device_name, "CPU"); + if (device_name_substr == NULL) { + device_name_substr = strstr(device_name, "GPU"); + } + if (device_name_substr == NULL) { + RAI_SetError(error, RAI_EMODELRUN, "Unsupported device type for DLPack"); + return device; } - if (batch_size != -1) { - shape[0] = batch_size; + strncpy(device_type, device_name_substr, 3); + if (strlen(device_name_substr) > 4) { + device_id = atoi(device_name_substr + 4); + } + + if (strcasecmp(device_type, "CPU") == 0) { + device.device_id = 0; + device.device_type = kDLCPU; + } else if (strcasecmp(device_type, "GPU") == 0) { + device.device_id = device_id; + device.device_type = kDLGPU; } else { - batch_size = total_batch_size; - } - for (int64_t i = ndims - 2; i >= 0; --i) { - strides[i] *= strides[i + 1] * shape[i + 1]; - } - - const size_t sample_bytesize = TF_TensorByteSize(tensor) / total_batch_size; - - // FIXME: In TF, RunSession allocates memory for output tensors - // This means that we either memcpy the tensor data and let - // Redis be responsible for the memory, or we reuse the TF - // allocated memory, which might not be optimal down the road - // Note: on YOLO this has no impact on perf -#ifdef RAI_COPY_RUN_OUTPUT - const size_t len = sample_bytesize * batch_size; - char *data = RedisModule_Calloc(len, sizeof(*data)); - memcpy(data, TF_TensorData(tensor) + sample_bytesize * batch_offset, len); -#endif - - // TODO: use 
manager_ctx to ensure TF tensor doesn't get deallocated - // This applies to outputs - - ret->tensor = (DLManagedTensor){ - .dl_tensor = (DLTensor){.device = device, -#ifdef RAI_COPY_RUN_OUTPUT - .data = data, -#else - .data = TF_TensorData(tensor), -#endif - .ndim = ndims, - .dtype = RAI_GetDLDataTypeFromTF(TF_TensorType(tensor)), - .shape = shape, - .strides = strides, - .byte_offset = 0}, - .manager_ctx = NULL, - .deleter = NULL}; + RAI_SetError(error, RAI_EMODELRUN, "Unsupported device type for DLPack"); + return device; + } - return ret; + return device; } -void RAI_TFDeallocator(void *data, size_t len, void *arg) { - // printf("DEALLOCATOR CALLED\n"); - // do nothing, memory is managed by Redis +int DeviceNameFromDLContext(const DLDevice *device, char device_name[64]) { + switch (device->device_type) { + case kDLCPU: + strcpy(device_name, "CPU:0"); + return REDISMODULE_OK; + case kDLGPU: + sprintf(device_name, "GPU:%d", device->device_id); + return REDISMODULE_OK; + } + return REDISMODULE_ERR; } -TF_Tensor *RAI_TFTensorFromTensor(RAI_Tensor *t) { -#ifdef RAI_COPY_RUN_INPUT - TF_Tensor *out = TF_AllocateTensor(RAI_GetTFDataTypeFromDL(t->tensor.dl_tensor.dtype), - t->tensor.dl_tensor.shape, t->tensor.dl_tensor.ndim, - RAI_TensorByteSize(t)); - memcpy(TF_TensorData(out), t->tensor.dl_tensor.data, TF_TensorByteSize(out)); - return out; -#else - return TF_NewTensor(RAI_GetTFDataTypeFromDL(t->tensor.dl_tensor.dtype), - t->tensor.dl_tensor.shape, t->tensor.dl_tensor.ndim, - t->tensor.dl_tensor.data, RAI_TensorByteSize(t), &RAI_TFDeallocator, NULL); -#endif /* RAI_COPY_RUN_INPUT */ +int TFDataTypeFromDLDataType(const DLDataType *dtype, TF_DataType *tf_dtype) { + switch (dtype->code) { + case kDLUInt: + switch (dtype->bits) { + case 8: + *tf_dtype = TF_UINT8; + return REDISMODULE_OK; + case 16: + *tf_dtype = TF_UINT16; + return REDISMODULE_OK; + case 32: + *tf_dtype = TF_UINT32; + return REDISMODULE_OK; + case 64: + *tf_dtype = TF_UINT64; + return REDISMODULE_OK; + default: + return REDISMODULE_ERR; + } + break; + case kDLInt: + switch (dtype->bits) { + case 8: + *tf_dtype = TF_INT8; + return REDISMODULE_OK; + case 16: + *tf_dtype = TF_INT16; + return REDISMODULE_OK; + case 32: + *tf_dtype = TF_INT32; + return REDISMODULE_OK; + case 64: + *tf_dtype = TF_INT64; + return REDISMODULE_OK; + default: + return REDISMODULE_ERR; + } + break; + case kDLFloat: + switch (dtype->bits) { + case 16: + *tf_dtype = TF_HALF; + return REDISMODULE_OK; + case 32: + *tf_dtype = TF_FLOAT; + return REDISMODULE_OK; + case 64: + *tf_dtype = TF_DOUBLE; + return REDISMODULE_OK; + default: + return REDISMODULE_ERR; + } + break; + case kDLBfloat: + switch (dtype->bits) { + case 16: + *tf_dtype = TF_BFLOAT16; + return REDISMODULE_OK; + default: + return REDISMODULE_ERR; + } + break; + default: + return REDISMODULE_ERR; + } +} + +void DeallocatorWrapperFunc(void *data, size_t len, void *dlmt_vptr) { + // NOTE: in the original TF implementation, the TFE_NewTensorHandleFromDeviceMemory + // function takes ownership of the device memory. The following function call is + // performed in order to deallocate the underlying DLPack structure + // In our case we are making the call from TFE_HandleFromDLPack, so the memory + // is already managed by the DLPack managed tensor that originally created it, + // and it is regulated by reference counting from within RedisAI. + // Therefore the present function should do nothing, the comment is retained + // for clarity only. 
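+    // (The opposite direction, TFE_HandleToDLPack below, does transfer ownership:
+    // there the DLManagedTensor's deleter is DLManagedTensorDeleter, which drops
+    // the TFE handle once RedisAI releases the resulting RAI_Tensor.)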
+ // TFE_CallDLManagedTensorDeleter(dlmt_vptr); } -TF_Tensor *RAI_TFTensorFromTensors(RAI_Tensor **ts, size_t count) { +bool IsValidStrideCompactRowMajorData(int64_t *shape_arr, int64_t *stride_arr, int ndim) { + if (ndim >= 1 && stride_arr[ndim - 1] != 1) { + return false; + } + for (int i = ndim - 2; i >= 0; --i) { + if (stride_arr[i] != shape_arr[i + 1] * stride_arr[i + 1]) { + return false; + } + } + return true; +} + +void TFE_CallDLManagedTensorDeleter(void *dlm_ptr) { + DLManagedTensor *dlMTensor = (DLManagedTensor *)dlm_ptr; + if (dlMTensor->deleter != NULL) { + dlMTensor->deleter(dlMTensor); + } +} - if (count == 0) { +DLManagedTensor *TFE_HandleToDLPack(TFE_TensorHandle *h, TF_Status *status, RAI_Error *error) { + DLDevice tf_dlm_device = GetDLDevice(h, status, error); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + RedisModule_Free(errorMessage); return NULL; } - size_t batch_size = 0; - size_t batch_byte_size = 0; + if (error->code != RAI_OK) { + return NULL; + } - for (size_t i = 0; i < count; i++) { - batch_size += ts[i]->tensor.dl_tensor.shape[0]; - batch_byte_size += RAI_TensorByteSize(ts[i]); + void *tf_dlm_data = TFE_TensorHandleDevicePointer(h, status); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + RedisModule_Free(errorMessage); + return NULL; } - RAI_Tensor *t0 = ts[0]; + TF_DataType data_type = TFE_TensorHandleDataType(h); - const int ndim = t0->tensor.dl_tensor.ndim; - int64_t batched_shape[ndim]; + DLDataType tf_dlm_type = GetDLDataType(data_type, status, error); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + RedisModule_Free(errorMessage); + return NULL; + } + if (error->code != RAI_OK) { + return NULL; + } + + TFDLManagedTensorCtx *tf_dlm_tensor_ctx = TFDLManagedTensorCtx_Create(h, status); - for (size_t i = 0; i < ndim; i++) { - batched_shape[i] = t0->tensor.dl_tensor.shape[i]; + DLManagedTensor *dlm_tensor = &tf_dlm_tensor_ctx->tensor; + dlm_tensor->manager_ctx = tf_dlm_tensor_ctx; + dlm_tensor->deleter = &DLManagedTensorDeleter; + dlm_tensor->dl_tensor.device = tf_dlm_device; + dlm_tensor->dl_tensor.ndim = tf_dlm_tensor_ctx->ndim; + dlm_tensor->dl_tensor.data = tf_dlm_data; + dlm_tensor->dl_tensor.dtype = tf_dlm_type; + dlm_tensor->dl_tensor.shape = tf_dlm_tensor_ctx->shape; + dlm_tensor->dl_tensor.strides = tf_dlm_tensor_ctx->strides; + dlm_tensor->dl_tensor.byte_offset = 0; + + return (void *)dlm_tensor; +} + +TFE_TensorHandle *TFE_HandleFromDLPack(void *dlm, TF_Status *status, TFE_Context *ctx, + RAI_Error *error) { + DLManagedTensor *dlmt = (DLManagedTensor *)dlm; + DLTensor *dl_tensor = &dlmt->dl_tensor; + char device_name[64]; + int ret = DeviceNameFromDLContext(&dl_tensor->device, device_name); + if (ret != 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Unsupported device type for TFE"); + return NULL; } + TF_DataType dtype; + ret = TFDataTypeFromDLDataType(&dl_tensor->dtype, &dtype); + if (ret != 0) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Unsupported data type in DLPack conversion to TFE"); + return NULL; + } + int num_dims = dl_tensor->ndim; + const int64_t *dims = dl_tensor->shape; + void *data = dl_tensor->data; - batched_shape[0] = batch_size; + size_t total_bytes = dl_tensor->dtype.bits / 8; + for (int i = 0; i < num_dims; i++) { 
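+        // total_bytes = element size in bytes (bits / 8) times the product of all
+        // dims; the compact row-major layout validated just below makes this the
+        // full length of the underlying buffer.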
+ total_bytes *= dims[i]; + } - TF_Tensor *out = NULL; + if (dl_tensor->strides != NULL && + !IsValidStrideCompactRowMajorData(dl_tensor->shape, dl_tensor->strides, num_dims)) { + RAI_SetError(error, RAI_EMODELRUN, "ERR Invalid strides array from DLPack"); + return NULL; + } - if (count > 1) { - out = TF_AllocateTensor(RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype), batched_shape, - t0->tensor.dl_tensor.ndim, batch_byte_size); + TFE_TensorHandle *handle = + TFE_NewTensorHandleFromDeviceMemory(ctx, device_name, dtype, dims, num_dims, data, + total_bytes, &DeallocatorWrapperFunc, dlmt, status); - size_t offset = 0; - for (size_t i = 0; i < count; i++) { - size_t tbytesize = RAI_TensorByteSize(ts[i]); - memcpy(TF_TensorData(out) + offset, ts[i]->tensor.dl_tensor.data, tbytesize); - offset += tbytesize; - } - } else { - out = TF_NewTensor(RAI_GetTFDataTypeFromDL(t0->tensor.dl_tensor.dtype), - t0->tensor.dl_tensor.shape, t0->tensor.dl_tensor.ndim, - t0->tensor.dl_tensor.data, RAI_TensorByteSize(t0), &RAI_TFDeallocator, - NULL); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + RedisModule_Free(errorMessage); + return NULL; } - return out; + return handle; } RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_ModelOpts opts, @@ -221,21 +351,18 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod if (!parseDeviceStr(devicestr, &device, &deviceid)) { RAI_SetError(error, RAI_EMODELIMPORT, "ERR unsupported device"); + return NULL; } - TF_Graph *model = TF_NewGraph(); + TF_Graph *graph = TF_NewGraph(); + TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions(); TF_Status *status = TF_NewStatus(); TF_Buffer *tfbuffer = TF_NewBuffer(); - TF_ImportGraphDefOptions *options = TF_NewImportGraphDefOptions(); - TF_Status *optionsStatus = NULL; - TF_SessionOptions *sessionOptions = NULL; - TF_Status *sessionStatus = NULL; - TF_Session *session = NULL; tfbuffer->length = modellen; tfbuffer->data = modeldef; - TF_GraphImportGraphDef(model, tfbuffer, options, status); + TF_GraphImportGraphDef(graph, tfbuffer, options, status); if (TF_GetCode(status) != TF_OK) { char *errorMessage = RedisModule_Strdup(TF_Message(status)); @@ -245,26 +372,26 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod } for (size_t i = 0; i < ninputs; ++i) { - TF_Operation *oper = TF_GraphOperationByName(model, inputs[i]); + TF_Operation *oper = TF_GraphOperationByName(graph, inputs[i]); if (oper == NULL || strcmp(TF_OperationOpType(oper), "Placeholder") != 0) { size_t len = strlen(inputs[i]); char *msg = RedisModule_Calloc(60 + len, sizeof(*msg)); sprintf(msg, "ERR Input node named \"%s\" not found in TF graph.", inputs[i]); RAI_SetError(error, RAI_EMODELIMPORT, msg); RedisModule_Free(msg); - goto cleanup; + return NULL; } } for (size_t i = 0; i < noutputs; ++i) { - TF_Operation *oper = TF_GraphOperationByName(model, outputs[i]); + TF_Operation *oper = TF_GraphOperationByName(graph, outputs[i]); if (oper == NULL) { size_t len = strlen(outputs[i]); char *msg = RedisModule_Calloc(60 + len, sizeof(*msg)); sprintf(msg, "ERR Output node named \"%s\" not found in TF graph", outputs[i]); RAI_SetError(error, RAI_EMODELIMPORT, msg); RedisModule_Free(msg); - goto cleanup; + return NULL; } } @@ -272,13 +399,48 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod options = NULL; TF_DeleteBuffer(tfbuffer); tfbuffer = 
NULL; - TF_DeleteStatus(status); - status = NULL; - optionsStatus = TF_NewStatus(); - sessionOptions = TF_NewSessionOptions(); + TF_Output tf_inputs[ninputs]; + TF_Output tf_outputs[noutputs]; + + for (size_t i = 0; i < ninputs; ++i) { + TF_Output port; + port.oper = TF_GraphOperationByName(graph, inputs[i]); + port.index = 0; + if (port.oper == NULL) { + return NULL; + } + tf_inputs[i] = port; + } + + for (size_t i = 0; i < noutputs; ++i) { + TF_Output port; + port.oper = TF_GraphOperationByName(graph, outputs[i]); + port.index = 0; + if (port.oper == NULL) { + return NULL; + } + tf_outputs[i] = port; + } + + TF_Function *function = + TF_GraphToFunction(graph, // fn_body + RAI_TF_FN_NAME, 0, // fn_name, append_hash_to_fn_name, + -1, NULL, // num_opers, opers + ninputs, tf_inputs, // ninputs, inputs, + noutputs, tf_outputs, // noutputs, outputs + outputs, // output_names, + NULL, // opts + NULL, // description + status // status + ); + + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status))); + goto cleanup; + } - // For setting config options in session from the C API see: + // For setting additional config options in session from the C API see: // https://github.com/tensorflow/tensorflow/issues/13853 // import tensorflow as tf // config = tf.ConfigProto(device_count = {'GPU': 0}) @@ -286,68 +448,53 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod // result = list(map(hex, serialized)) // print(result) - if (device == RAI_DEVICE_CPU) { - // Set number of GPU to 0 with - // config.device_count = {'GPU': 0} - uint8_t config[] = {0x0a, 0x07, 0x0a, 0x03, 0x47, 0x50, 0x55, 0x10, 0x00}; - TF_SetConfig(sessionOptions, (void *)config, sizeof(config), optionsStatus); + TFE_ContextOptions *context_opts = TFE_NewContextOptions(); - if (TF_GetCode(optionsStatus) != TF_OK) { - RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(optionsStatus))); + TF_Buffer *config_proto = TF_CreateConfig(0 /*unsigned char enable_xla_compilation*/, + 1 /*unsigned char gpu_memory_allow_growth*/, + 1 /*unsigned int num_cpu_devices*/); + TFE_ContextOptionsSetConfig(context_opts, (void *)config_proto, sizeof(config_proto), status); + + if (opts.backends_intra_op_parallelism > 0) { + uint8_t proto[] = {0x10, (uint8_t)opts.backends_intra_op_parallelism}; + TFE_ContextOptionsSetConfig(context_opts, proto, sizeof(proto), status); + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status))); goto cleanup; } + } - if (opts.backends_intra_op_parallelism > 0) { - uint8_t proto[] = {0x10, (uint8_t)opts.backends_intra_op_parallelism}; - TF_SetConfig(sessionOptions, proto, sizeof(proto), optionsStatus); - if (TF_GetCode(optionsStatus) != TF_OK) { - RAI_SetError(error, RAI_EMODELCONFIGURE, - RedisModule_Strdup(TF_Message(optionsStatus))); - goto cleanup; - } + if (opts.backends_inter_op_parallelism > 0) { + uint8_t proto1[] = {0x28, (uint8_t)opts.backends_inter_op_parallelism}; + TFE_ContextOptionsSetConfig(context_opts, proto1, sizeof(proto1), status); + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status))); + goto cleanup; } + } - if (opts.backends_inter_op_parallelism > 0) { - uint8_t proto1[] = {0x28, (uint8_t)opts.backends_inter_op_parallelism}; - TF_SetConfig(sessionOptions, proto1, sizeof(proto1), optionsStatus); - if (TF_GetCode(optionsStatus) != TF_OK) { - RAI_SetError(error, 
RAI_EMODELCONFIGURE, - RedisModule_Strdup(TF_Message(optionsStatus))); - goto cleanup; - } - } - } else if (device == RAI_DEVICE_GPU) { - if (deviceid == -1) { - // Set - // config.gpu_options.allow_growth = True - uint8_t config[4] = {0x32, 0x02, 0x20, 0x01}; - TF_SetConfig(sessionOptions, (void *)config, 4, optionsStatus); - } else { - // Set - // config.gpu_options.allow_growth = True - // config.gpu_options.visible_device_list = '' - uint8_t config[7] = {0x32, 0x05, 0x20, 0x01, 0x2a, 0x01, 0x30}; - config[6] += (uint8_t)deviceid; - TF_SetConfig(sessionOptions, (void *)config, 7, optionsStatus); - } + TFE_ContextOptionsSetAsync(context_opts, 0); + TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT); + + TFE_Context *context = TFE_NewContext(context_opts, status); + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status))); + goto cleanup; } - if (TF_GetCode(optionsStatus) != TF_OK) { - RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(optionsStatus))); + TFE_ContextAddFunction(context, function, status); + if (TF_GetCode(status) != TF_OK) { + RAI_SetError(error, RAI_EMODELCONFIGURE, RedisModule_Strdup(TF_Message(status))); goto cleanup; } - TF_DeleteStatus(optionsStatus); - optionsStatus = NULL; - sessionStatus = TF_NewStatus(); - session = TF_NewSession(model, sessionOptions, sessionStatus); + TFE_DeleteContextOptions(context_opts); - TF_Status *deviceListStatus = TF_NewStatus(); - TF_DeviceList *deviceList = TF_SessionListDevices(session, deviceListStatus); + TF_DeviceList *deviceList = TFE_ContextListDevices(context, status); const int num_devices = TF_DeviceListCount(deviceList); int foundNoGPU = 1; for (int i = 0; i < num_devices; ++i) { - const char *device_type = TF_DeviceListType(deviceList, i, deviceListStatus); + const char *device_type = TF_DeviceListType(deviceList, i, status); int cmp = strcmp(device_type, "GPU"); if (cmp == 0) { foundNoGPU = 0; @@ -357,19 +504,17 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod if (foundNoGPU == 1 && device == RAI_DEVICE_GPU) { RAI_SetError(error, RAI_EMODELCREATE, "ERR GPU requested but TF couldn't find CUDA"); TF_DeleteDeviceList(deviceList); - TF_DeleteStatus(deviceListStatus); + TF_DeleteStatus(status); goto cleanup; } TF_DeleteDeviceList(deviceList); - TF_DeleteStatus(deviceListStatus); - if (TF_GetCode(sessionStatus) != TF_OK) { + if (TF_GetCode(status) != TF_OK) { RAI_SetError(error, RAI_EMODELCREATE, RedisModule_Strdup(TF_Message(status))); goto cleanup; } - TF_DeleteSessionOptions(sessionOptions); - TF_DeleteStatus(sessionStatus); + TF_DeleteStatus(status); char **inputs_ = array_new(char *, ninputs); for (long long i = 0; i < ninputs; i++) { @@ -385,8 +530,8 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod memcpy(buffer, modeldef, modellen); RAI_Model *ret = RedisModule_Calloc(1, sizeof(*ret)); - ret->model = model; - ret->session = session; + ret->model = graph; + ret->session = context; ret->backend = backend; ret->devicestr = RedisModule_Strdup(devicestr); ret->ninputs = ninputs; @@ -401,37 +546,20 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod return ret; cleanup: - TF_DeleteGraph(model); + TF_DeleteGraph(graph); if (options) TF_DeleteImportGraphDefOptions(options); if (tfbuffer) TF_DeleteBuffer(tfbuffer); if (status) TF_DeleteStatus(status); - if (sessionOptions) - TF_DeleteSessionOptions(sessionOptions); 
- if (sessionStatus) - TF_DeleteStatus(sessionStatus); return NULL; } void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) { - TF_Status *status = TF_NewStatus(); - TF_CloseSession(model->session, status); - - if (TF_GetCode(status) != TF_OK) { - RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status))); - return; - } - - TF_DeleteSession(model->session, status); + TFE_DeleteContext(model->session); model->session = NULL; - if (TF_GetCode(status) != TF_OK) { - RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status))); - return; - } - TF_DeleteGraph(model->model); model->model = NULL; @@ -456,8 +584,6 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) { if (model->data) { RedisModule_Free(model->data); } - - TF_DeleteStatus(status); } int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) { @@ -466,15 +592,20 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) { const size_t nbatches = array_len(mctxs); if (nbatches == 0) { RAI_SetError(error, RAI_EMODELRUN, "ERR No batches to run"); - return 1; + return REDISMODULE_ERR; } const size_t ninputs = array_len(mctxs[0]->inputs); const size_t noutputs = array_len(mctxs[0]->outputs); - TF_Tensor *inputTensorsValues[ninputs]; - TF_Output inputs[ninputs]; - TF_Tensor *outputTensorsValues[noutputs]; - TF_Output outputs[noutputs]; + TFE_TensorHandle *inputTensorsHandles[ninputs]; + TFE_TensorHandle *outputTensorsHandles[noutputs]; + TFE_TensorHandle *deviceInputTensorsHandles[ninputs]; + TFE_TensorHandle *deviceOutputTensorsHandles[noutputs]; + + bool on_cpu = false; + if (strncasecmp(mctxs[0]->model->devicestr, "CPU", 3) == 0) { + on_cpu = true; + } size_t batch_sizes[nbatches]; size_t batch_offsets[nbatches]; @@ -490,38 +621,89 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) { } } + char tf_devicestr[256]; + int devicestr_len = strlen(mctxs[0]->model->devicestr); + if (on_cpu) { + sprintf(tf_devicestr, "/device:CPU:0"); + } else if (devicestr_len == 3) { + sprintf(tf_devicestr, "/device:%s:0", mctxs[0]->model->devicestr); + } else { + sprintf(tf_devicestr, "/device:%s", mctxs[0]->model->devicestr); + } + for (size_t i = 0; i < ninputs; ++i) { RAI_Tensor *batched_input_tensors[nbatches]; for (size_t b = 0; b < nbatches; ++b) { batched_input_tensors[b] = mctxs[b]->inputs[i].tensor; } - inputTensorsValues[i] = RAI_TFTensorFromTensors(batched_input_tensors, nbatches); - TF_Output port; - port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->inputs[i].name); - port.index = 0; - if (port.oper == NULL) { - return 1; + + if (nbatches > 1) { + RAI_Tensor *batched_tensor = + RAI_TensorCreateByConcatenatingTensors(batched_input_tensors, nbatches); + inputTensorsHandles[i] = + TFE_HandleFromDLPack(batched_tensor, status, mctxs[0]->model->session, error); + } else { + inputTensorsHandles[i] = TFE_HandleFromDLPack(batched_input_tensors[0], status, + mctxs[0]->model->session, error); } - inputs[i] = port; - } - for (size_t i = 0; i < noutputs; ++i) { - TF_Output port; - port.oper = TF_GraphOperationByName(mctxs[0]->model->model, mctxs[0]->outputs[i].name); - port.index = 0; - if (port.oper == NULL) { - return 1; + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + return REDISMODULE_ERR; + } + + if (error->code != RAI_OK) { + return REDISMODULE_ERR; + } + + deviceInputTensorsHandles[i] = 
TFE_TensorHandleCopyToDevice( + inputTensorsHandles[i], mctxs[0]->model->session, tf_devicestr, status); + + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + return REDISMODULE_ERR; } - outputs[i] = port; } - TF_SessionRun(mctxs[0]->model->session, NULL /* run_options */, inputs, inputTensorsValues, - ninputs, outputs, outputTensorsValues, noutputs, NULL /* target_opers */, - 0 /* ntargets */, NULL /* run_Metadata */, status); + TFE_Op *fn_op = TFE_NewOp(mctxs[0]->model->session, RAI_TF_FN_NAME, status); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + return REDISMODULE_ERR; + } + + TFE_OpAddInputList(fn_op, deviceInputTensorsHandles, ninputs, status); + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + return REDISMODULE_ERR; + } + + int noutputs_ = noutputs; + TFE_Execute(fn_op, deviceOutputTensorsHandles, &noutputs_, status); + + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + return REDISMODULE_ERR; + } for (size_t i = 0; i < ninputs; ++i) { - TF_DeleteTensor(inputTensorsValues[i]); + TFE_DeleteTensorHandle(inputTensorsHandles[i]); + TFE_DeleteTensorHandle(deviceInputTensorsHandles[i]); } if (TF_GetCode(status) != TF_OK) { @@ -529,36 +711,60 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) { RAI_SetError(error, RAI_EMODELRUN, errorMessage); TF_DeleteStatus(status); RedisModule_Free(errorMessage); - return 1; + return REDISMODULE_ERR; } for (size_t i = 0; i < noutputs; ++i) { + outputTensorsHandles[i] = TFE_TensorHandleCopyToDevice( + deviceOutputTensorsHandles[i], mctxs[0]->model->session, "/device:CPU:0", status); + + DLManagedTensor *outputDLTensor = + TFE_HandleToDLPack(outputTensorsHandles[i], status, error); + + if (TF_GetCode(status) != TF_OK) { + char *errorMessage = RedisModule_Strdup(TF_Message(status)); + RAI_SetError(error, RAI_EMODELRUN, errorMessage); + TF_DeleteStatus(status); + RedisModule_Free(errorMessage); + break; + } + + if (error->code != RAI_OK) { + break; + } + + RAI_Tensor *outputTensor = RAI_TensorCreateFromDLTensor(outputDLTensor); + if (nbatches > 1) { - if (TF_NumDims(outputTensorsValues[i]) == 0) { + if (RAI_TensorNumDims(outputTensor) == 0) { continue; } - if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) { - TF_DeleteTensor(outputTensorsValues[i]); + if (RAI_TensorDim(outputTensor, 0) != total_batch_size) { + RAI_TensorFree(outputTensor); TF_DeleteStatus(status); RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size"); - return 1; + return REDISMODULE_ERR; } for (size_t b = 0; b < nbatches; b++) { - mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor( - outputTensorsValues[i], batch_offsets[b], batch_sizes[b]); + mctxs[b]->outputs[i].tensor = + RAI_TensorCreateBySlicingTensor(outputTensor, batch_offsets[b], batch_sizes[b]); } } else { - mctxs[0]->outputs[i].tensor = - RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1); + 
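+            // Single-batch case: expose the whole output through a refcounted
+            // shallow copy, so no per-batch slicing or data copy is needed.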
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(outputTensor); } - TF_DeleteTensor(outputTensorsValues[i]); + RAI_TensorFree(outputTensor); + TFE_DeleteTensorHandle(deviceOutputTensorsHandles[i]); } TF_DeleteStatus(status); - return 0; + if (error->code != RAI_OK) { + return REDISMODULE_ERR; + } + + return REDISMODULE_OK; } int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) { @@ -577,7 +783,7 @@ int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error RAI_SetError(error, RAI_EMODELSERIALIZE, "ERR Error serializing TF model"); TF_DeleteBuffer(tf_buffer); TF_DeleteStatus(status); - return 1; + return REDISMODULE_ERR; } *buffer = RedisModule_Alloc(tf_buffer->length); @@ -588,7 +794,7 @@ int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error TF_DeleteStatus(status); } - return 0; + return REDISMODULE_OK; } const char *RAI_GetBackendVersionTF(void) { return TF_Version(); }
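Note on the serialized ConfigProto options used in RAI_ModelCreateTF above: each small byte array handed to TFE_ContextOptionsSetConfig is a hand-encoded tensorflow.ConfigProto containing a single varint field, whose first byte is the protobuf tag. A sketch of that encoding, assuming the thread counts fit in one varint byte (i.e. are below 128), which the existing cast to uint8_t already presumes:

    /* ConfigProto{ intra_op_parallelism_threads: n }  -- field 2, wire type 0 -> tag 0x10 */
    uint8_t intra_proto[] = {0x10, (uint8_t)n};
    /* ConfigProto{ inter_op_parallelism_threads: n }  -- field 5, wire type 0 -> tag 0x28 */
    uint8_t inter_proto[] = {0x28, (uint8_t)n};
    TFE_ContextOptionsSetConfig(context_opts, intra_proto, sizeof(intra_proto), status);
    TFE_ContextOptionsSetConfig(context_opts, inter_proto, sizeof(inter_proto), status);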