[feature] Add tensor folder map to graph #10525

Draft · wants to merge 1 commit into base: master
Changes from all commits
26 changes: 26 additions & 0 deletions oneflow/api/common/folder_rule_table.h
@@ -0,0 +1,26 @@
#ifndef ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_
#define ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_

#include "oneflow/core/common/singleton.h"
#include "oneflow/core/framework/folder_rule_table.h"

namespace oneflow {

inline std::vector<std::string>& GetFolderRuleTable() {
  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
  return folder_rule_table->GetRules();
}

inline void AppendRuleToFolderRuleTable(const std::string& new_rule) {
  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
  folder_rule_table->Append(new_rule);
}

inline void ResetFolderRuleTable() {
  auto* folder_rule_table = Singleton<FolderRuleTable>::Get();
  folder_rule_table->Reset();
}

} // namespace oneflow

#endif // ONEFLOW_API_COMMON_FOLDER_RULE_TABLE_H_
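
A minimal usage sketch of this helper API (illustrative only, not part of the diff; it assumes Singleton<FolderRuleTable>::New() has already run, as MultiClientSessionContext::TryInit does below):

#include "oneflow/api/common/folder_rule_table.h"

// Sketch: record two folding rules, where the second subsumes the first.
void FolderRuleTableSketch() {
  oneflow::ResetFolderRuleTable();
  oneflow::AppendRuleToFolderRuleTable("Sqrt ( model.w )");
  // The composite rule contains the previous rule as a substring, so Append
  // overwrites the old entry instead of adding a second one.
  oneflow::AppendRuleToFolderRuleTable("( Sqrt ( model.w ) ) BroadcastMul ( model.b )");
  const auto& rules = oneflow::GetFolderRuleTable();  // rules.size() == 1
}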
15 changes: 15 additions & 0 deletions oneflow/api/python/framework/folder_rule_table.cpp
@@ -0,0 +1,15 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/common/folder_rule_table.h"
#include "oneflow/api/python/of_api_registry.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("GetFolderRuleTable", &GetFolderRuleTable, py::return_value_policy::reference_internal);
}

} // namespace oneflow
40 changes: 40 additions & 0 deletions oneflow/core/framework/folder_rule_table.h
@@ -0,0 +1,40 @@
#ifndef ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_
#define ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_

#include <vector>
#include <string>
#include "oneflow/core/common/util.h"

namespace oneflow {

template<typename T, typename Kind>
class Singleton;

class FolderRuleTable final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(FolderRuleTable);
  ~FolderRuleTable() = default;

  // Record a folding rule. If the new rule already contains an existing rule
  // as a substring, it subsumes that rule and replaces it in place; otherwise
  // it is appended as a new entry.
  void Append(const std::string& new_rule) {
    for (auto& rule : infix_rules_) {
      if (new_rule != rule && new_rule.find(rule) != std::string::npos) {
        rule = new_rule;
        return;
      }
    }
    infix_rules_.push_back(new_rule);
  }
  void Reset() { infix_rules_.clear(); }
  std::vector<std::string>& GetRules() { return infix_rules_; }

 private:
  friend class Singleton<FolderRuleTable>;
  FolderRuleTable() = default;
  std::vector<std::string> infix_rules_;
};

} // namespace oneflow

#endif // ONEFLOW_CORE_FRAMEWORK_FOLDER_RULE_TABLE_H_
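
A hedged behavior sketch of Append (the rule strings and variable names are hypothetical; the singleton lifecycle mirrors MultiClientSessionContext::TryInit below, and the code runs inside namespace oneflow):

Singleton<FolderRuleTable>::New();
auto* table = Singleton<FolderRuleTable>::Get();
table->Append("Sqrt ( w )");
table->Append("Sqrt ( b )");                     // unrelated rule: appended as a new entry
table->Append("Reshape([2,3]) ( Sqrt ( w ) )");  // subsumes "Sqrt ( w )": replaces it in place
// GetRules() now yields { "Reshape([2,3]) ( Sqrt ( w ) )", "Sqrt ( b )" }
table->Reset();  // clears all recorded rules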
2 changes: 2 additions & 0 deletions oneflow/core/framework/multi_client_session_context.cpp
@@ -39,6 +39,7 @@ limitations under the License.
#include "oneflow/core/job/collective_boxing/scheduler.h"
#include "oneflow/core/graph/task_stream_index_manager.h"
#include "oneflow/core/framework/variable_tensor_mgr.h"
#include "oneflow/core/framework/folder_rule_table.h"
#ifdef WITH_CUDA
#include <cuda.h>
#endif // WITH_CUDA
@@ -113,6 +114,7 @@ Maybe<void> MultiClientSessionContext::TryInit(const ConfigProto& config_proto)
Singleton<summary::EventsWriter>::New();
Singleton<boxing::collective::Scheduler>::New();
Singleton<VariableTensorMgr>::New();
Singleton<FolderRuleTable>::New();
}

is_inited_ = true;
70 changes: 53 additions & 17 deletions oneflow/ir/lib/OneFlow/OneFlowOpFolders.cpp
@@ -30,6 +30,8 @@ limitations under the License.
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/framework/variable_tensor_mgr.h"
#include "oneflow/api/common/folder_rule_table.h"


namespace mlir {
namespace oneflow {
@@ -44,6 +46,21 @@ StringAttr GenNewVariableOpName(MLIRContext* ctx, const std::string& key = "") {
  return StringAttr::get(ctx, "variable_" + key + "_" + ::oneflow::NewUniqueId());
}

StringAttr GenNewUnaryVariableOpName(MLIRContext* ctx, const std::string& operand_name,
                                     const std::string& operator_name) {
  std::string infix_rule = operator_name + " ( " + operand_name + " )";
  ::oneflow::AppendRuleToFolderRuleTable(infix_rule);
  return StringAttr::get(ctx, infix_rule);
}

StringAttr GenNewBinaryVariableOpName(MLIRContext* ctx, const std::string& lhs_operand_name,
                                      const std::string& rhs_operand_name,
                                      const std::string& operator_name) {
  std::string infix_rule =
      "( " + lhs_operand_name + " ) " + operator_name + " ( " + rhs_operand_name + " )";
  ::oneflow::AppendRuleToFolderRuleTable(infix_rule);
  return StringAttr::get(ctx, infix_rule);
}

bool MLIRDataTypesAreSame(const std::vector<DataType>& data_types) {
if (data_types.empty() || data_types.size() == 1) { return true; }
bool result = true;
@@ -63,6 +80,7 @@ bool DictionaryAttrsHaveSameDataType(const std::vector<mlir::DictionaryAttr>& at
}

OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
const std::string& operator_name,
const std::function<MaybeTensor(const TensorPtr&)>& f) {
::oneflow::LazyMode::Guard guard{false};
if (!operands.front()) { return {}; } // Important!
@@ -74,14 +92,17 @@ OpFoldResult UnaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
attr_dict.get(OpTrait::IsOpConfCompatible<void>::getDeviceNameAttr()));
const auto result = f(tensor).GetPtrOrThrow();
attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
  auto operand_name = attr_dict.get("op_name").cast<mlir::StringAttr>().getValue().str();
  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
            GenNewUnaryVariableOpName(ctx, operand_name, operator_name));
attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),
attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));

return attrs.getDictionary(ctx);
}

OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
const std::string& operator_name,
const std::function<MaybeTensor(const TensorPtr&, const TensorPtr&)>& f) {
::oneflow::LazyMode::Guard guard{false};
if (!(operands.front() && operands.back())) { return {}; } // Important!
@@ -107,7 +128,10 @@ OpFoldResult BinaryFold(MLIRContext* ctx, ArrayRef<Attribute> operands,
const auto result = f(lhs_tensor, rhs_tensor).GetPtrOrThrow();

attrs.set("value", support::TensorToDenseElementsAttr(result, ctx));
  auto lhs_operand_name = lhs_attr_dict.get("op_name").cast<mlir::StringAttr>().getValue().str();
  auto rhs_operand_name = rhs_attr_dict.get("op_name").cast<mlir::StringAttr>().getValue().str();
  attrs.set(OpTrait::IsOpConfCompatible<void>::getOpNameAttr(),
            GenNewBinaryVariableOpName(ctx, lhs_operand_name, rhs_operand_name, operator_name));
attrs.set(OpTrait::TensorSource<void>::getDataTypeAttrName(),
lhs_attr_dict.get(OpTrait::TensorSource<void>::getDataTypeAttrName()));

@@ -129,30 +153,41 @@ OpFoldResult FrozenVariableOp::fold(FoldAdaptor adaptor) {
return DictionaryAttr::get(getContext(), attrs);
}

template<typename T>
std::string VectorToString(const std::vector<T>& vec) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i > 0) { ss << ","; }  // comma-separate without a trailing comma
    ss << vec[i];
  }
  ss << "]";
  return ss.str();
}

OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  std::vector<int32_t> perm_;
  for (auto& x : getPerm().getValue()) { perm_.emplace_back(x.cast<IntegerAttr>().getSInt()); }
  return UnaryFold(getContext(), operands, "Transpose(" + VectorToString(perm_) + ")",
                   [this, &perm_](const auto& tensor) {
                     return functional::Transpose(tensor, perm_);
                   });
}

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  std::vector<int64_t> shape_vec;
  for (auto& x : getShape().getValue()) {
    shape_vec.emplace_back(x.cast<mlir::IntegerAttr>().getValue().getSExtValue());
  }
  return UnaryFold(getContext(), operands, "Reshape(" + VectorToString(shape_vec) + ")",
                   [this, &shape_vec](const auto& tensor) {
                     return functional::Reshape(
                         tensor, ::oneflow::Shape(::oneflow::DimVector(shape_vec.begin(), shape_vec.end())));
                   });
}

OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  // Name the folded op after whichever scalar operand is actually present.
  const std::string scalar_str = getHasIntOperand()
                                     ? std::to_string(getIntOperand())
                                     : std::to_string(getFloatOperand().convertToDouble());
  return UnaryFold(getContext(), operands, "ScalarAdd(" + scalar_str + ")",
                   [this](const auto& tensor) -> MaybeTensor {
    if (getHasIntOperand()) { return functional::ScalarAdd(tensor, getIntOperand(), 1, false); }
    if (getHasFloatOperand()) {
      return functional::ScalarAdd(tensor, getFloatOperand().convertToDouble(), 1, false);
@@ -164,24 +199,25 @@ OpFoldResult ScalarAddOp::fold(FoldAdaptor adaptor) {

OpFoldResult SqrtOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return UnaryFold(getContext(), operands, "Sqrt", functional::Sqrt);
}

OpFoldResult BroadcastMulOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, "BroadcastMul", functional::Mul);
}

OpFoldResult BroadcastDivOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, "BroadcastDiv", functional::Div);
}

OpFoldResult BroadcastSubOp::fold(FoldAdaptor adaptor) {
  auto operands = adaptor.getOperands();
  return BinaryFold(getContext(), operands, "BroadcastSub",
                    [](const auto& lhs, const auto& rhs) -> MaybeTensor {
                      return functional::Sub(lhs, rhs, /*alpha=*/1.0, false);
                    });
}

} // namespace oneflow
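
To make the net effect concrete, a hedged trace with hypothetical variable names w and w2, folding the expression w2 * sqrt(w):

// Illustrative trace, not part of the patch.
// 1. SqrtOp::fold rewrites sqrt(w) into a frozen variable whose op_name is
//      "Sqrt ( w )"
//    and appends that string to the FolderRuleTable.
// 2. BroadcastMulOp::fold then folds the multiply into a variable named
//      "( w2 ) BroadcastMul ( Sqrt ( w ) )"
//    Because this string contains the earlier rule "Sqrt ( w )" as a
//    substring, FolderRuleTable::Append replaces the old entry, leaving one
//    composite rule per fully folded tensor. That rule list is what
//    graph.py's TensorFolderMap consumes below.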
5 changes: 5 additions & 0 deletions python/oneflow/nn/graph/graph.py
@@ -61,6 +61,8 @@
from oneflow.nn.optimizer.lr_scheduler import LRScheduler
from oneflow.optim.optimizer import Optimizer

from oneflow.nn.graph.tensor_folder_map import TensorFolderMap


class Graph(object):
r"""Base class for training or evaluating a neural network in static graph mode.
@@ -282,6 +284,9 @@ def __call__(self, *args, **kwargs):

        if not self._is_compiled:
            self._compile(*args, **kwargs)

        # Generate the tensor folder map from the folding rules recorded during compilation.
        self.tensor_folder_map = TensorFolderMap(oneflow._oneflow_internal.GetFolderRuleTable())

        return self.__run(*args, **kwargs)
