Skip to content

Commit

Permalink
[VE] Add ApplyKerasMomentum
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ishizaka committed Jun 1, 2021
1 parent 6738dd2 commit dafc392
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 7 deletions.
81 changes: 74 additions & 7 deletions tensorflow/core/kernels/training_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,21 @@ using GPUDevice = Eigen::GpuDevice;
using Index = Eigen::Index;

#ifdef TENSORFLOW_USE_VE
typedef Eigen::VeDevice VEDevice;
using VEDevice = Eigen::VeDevice;

// Explicit specialization for <VEDevice, float>: forwards every argument
// unchanged to VEMaybeLockVariableInputMutexesInOrder<float> and returns the
// lock holder it produces.
template <>
VariableInputLockHolder MaybeLockVariableInputMutexesInOrder<VEDevice, float> (
OpKernelContext* ctx, bool do_lock, bool sparse,
const std::vector<int>& input_ids) {
return VEMaybeLockVariableInputMutexesInOrder<float>(
ctx, do_lock, sparse, input_ids);
}

// Explicit specialization for <VEDevice, float>: forwards every argument
// unchanged to VEGetInputTensorFromVariable<float> and returns its Status.
template <>
Status GetInputTensorFromVariable<VEDevice, float>(
OpKernelContext* ctx, int input, bool lock_held, bool sparse, Tensor* out) {
return VEGetInputTensorFromVariable<float>(ctx, input, lock_held, sparse, out);
}
#endif // TENSORFLOW_USE_VE

namespace {
Expand Down Expand Up @@ -3543,9 +3557,9 @@ TF_CALL_COMPLEX_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_KERNELS

template <typename Device, typename T>
class ApplyKerasMomentumOp : public OpKernel {
class ApplyKerasMomentumOpBase : public OpKernel {
public:
explicit ApplyKerasMomentumOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
// Reads the "use_locking" and "use_nesterov" attributes into
// use_exclusive_lock_ and use_nesterov_; each lookup is checked with
// OP_REQUIRES_OK.
explicit ApplyKerasMomentumOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
}
Expand Down Expand Up @@ -3590,16 +3604,34 @@ class ApplyKerasMomentumOp : public OpKernel {
errors::InvalidArgument("momentum is not a scalar: ",
momentum.shape().DebugString()));

const Device& device = ctx->template eigen_device<Device>();
functor::ApplyKerasMomentum<Device, T>()(
device, var.flat<T>(), accum.flat<T>(), lr.scalar<T>(), grad.flat<T>(),
momentum.scalar<T>(), use_nesterov_);
_Compute(ctx, var, accum, lr, grad, momentum, use_nesterov_);
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}

private:
bool use_exclusive_lock_;
bool use_nesterov_;

// Device-specific hook: Compute() calls this with var, accum, lr, grad,
// momentum and the stored use_nesterov_ flag; each subclass supplies the
// implementation.
// NOTE(review): an identifier starting with an underscore followed by an
// uppercase letter is reserved for the implementation in C++; consider
// renaming this hook together with all of its overrides.
virtual void _Compute(OpKernelContext* ctx, Tensor& var, Tensor& accum,
const Tensor& lr, const Tensor& grad,
const Tensor& momentum, bool use_nesterov) = 0;
};

// Kernel for devices driven through Eigen: implements the base-class hook by
// invoking functor::ApplyKerasMomentum<Device, T> on the context's Eigen
// device.
template <typename Device, typename T>
class ApplyKerasMomentumOp : public ApplyKerasMomentumOpBase<Device, T> {
  using Base = ApplyKerasMomentumOpBase<Device, T>;

 public:
  explicit ApplyKerasMomentumOp(OpKernelConstruction* ctx) : Base(ctx) {}

 private:
  // Passes flat views of var/accum/grad and scalar views of lr/momentum,
  // plus the nesterov flag, to the functor.
  void _Compute(OpKernelContext* ctx, Tensor& var, Tensor& accum,
                const Tensor& lr, const Tensor& grad, const Tensor& momentum,
                bool use_nesterov) override {
    const Device& eigen_dev = ctx->template eigen_device<Device>();
    functor::ApplyKerasMomentum<Device, T> apply_update;
    apply_update(eigen_dev,
                 var.flat<T>(),
                 accum.flat<T>(),
                 lr.scalar<T>(),
                 grad.flat<T>(),
                 momentum.scalar<T>(),
                 use_nesterov);
  }
};

#define REGISTER_KERNELS(D, T) \
Expand Down Expand Up @@ -3639,6 +3671,41 @@ REGISTER_KERNELS(GPU, double);
REGISTER_KERNELS(GPU, complex64);
REGISTER_KERNELS(GPU, complex128);
#endif

#ifdef TENSORFLOW_USE_VE
// VE kernel: implements the base-class hook by packing the operands into a
// VEOpKernelHelper argument list and dispatching the "ApplyKerasMomentum"
// call through VEOpKernelHelper::Call.
template <typename T>
class VEApplyKerasMomentumOp : public ApplyKerasMomentumOpBase<VEDevice, T> {
  using Base = ApplyKerasMomentumOpBase<VEDevice, T>;

 public:
  explicit VEApplyKerasMomentumOp(OpKernelConstruction* ctx) : Base(ctx) {}

 private:
  void _Compute(OpKernelContext* ctx, Tensor& var, Tensor& accum,
                const Tensor& lr, const Tensor& grad, const Tensor& momentum,
                bool use_nesterov) override {
    VEOpKernelHelper::ArgsImpl<> call_args;

    // Argument order: var, accum, lr, grad, momentum, nesterov flag.
    call_args.addArg<Tensor>(var);
    call_args.addArg<Tensor>(accum);
    call_args.addArg<Tensor>(lr);
    call_args.addArg<Tensor>(grad);
    call_args.addArg<Tensor>(momentum);
    // The boolean flag is passed as a 64-bit integer (1 or 0).
    call_args.addArg<int64_t>(use_nesterov ? 1 : 0);

    VEOpKernelHelper::Call(ctx, "ApplyKerasMomentum", call_args);
  }
};

// Registers VEApplyKerasMomentumOp<T> as the DEVICE_VE kernel for the
// "ResourceApplyKerasMomentum" op, with the "var" and "accum" inputs marked
// as host memory. Instantiated below for float only; the macro is undefined
// again right after use.
#define REGISTER_VE_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("ResourceApplyKerasMomentum") \
.Device(DEVICE_VE) \
.HostMemory("var") \
.HostMemory("accum") \
.TypeConstraint<T>("T"), \
VEApplyKerasMomentumOp<T>);
REGISTER_VE_KERNELS(float);
#undef REGISTER_VE_KERNELS
#endif // TENSORFLOW_USE_VE

#undef REGISTER_CPU_KERNELS
#undef REGISTER_KERNELS

Expand Down
Binary file modified third_party/veoffload/veorun_tf
Binary file not shown.

0 comments on commit dafc392

Please sign in to comment.