Skip to content

Commit

Permalink
[tflchef] Introduce ModelChef (#13632)
Browse files Browse the repository at this point in the history
This will introduce struct ModelChef to provide model chef functionality
as beginning of refactoring to provide huge model generation.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Aug 9, 2024
1 parent 06be246 commit 0fc4de4
Showing 1 changed file with 82 additions and 56 deletions.
138 changes: 82 additions & 56 deletions compiler/tflchef/core/src/ModelChef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,6 @@ using namespace souschef;
namespace
{

class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
{
public:
GeneratedModelImpl(std::unique_ptr<flatbuffers::FlatBufferBuilder> &&builder)
: _builder{std::move(builder)}
{
// DO NOTHING
}

public:
const char *base(void) const override
{
// Return the base address of generated flatbuffer model
return reinterpret_cast<const char *>(_builder->GetBufferPointer());
}

public:
size_t size(void) const override
{
// Return the size of generated flatbuffer model
return _builder->GetSize();
}

private:
std::unique_ptr<flatbuffers::FlatBufferBuilder> _builder;
};

} // namespace

namespace
{

struct DataChefRegistry final : public Registry<DataChefFactory>
{
};
Expand Down Expand Up @@ -209,6 +177,8 @@ std::set<std::string> gather_customcode_set(const ::tflchef::ModelRecipe &model_
namespace
{

// TODO remove
#if 0
struct CookParams
{
std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec;
Expand All @@ -219,6 +189,20 @@ struct CookParams
std::vector<std::string> &custom_code_vec;
std::string noname;
};
#endif

struct ModelChef
{
std::unique_ptr<flatbuffers::FlatBufferBuilder> flatbuffer_builder;

std::vector<flatbuffers::Offset<::tflite::SignatureDef>> signdef_vec;
std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
std::map<tflite::BuiltinOperator, int32_t> builtin_code_map;
std::vector<std::string> custom_code_vec;
std::string graph_name;
};

std::vector<flatbuffers::Offset<tflite::DimensionMetadata>>
make_dim_metadata_vec(flatbuffers::FlatBufferBuilder *flatbuffer_builder, int32_t dims_count,
Expand Down Expand Up @@ -255,16 +239,17 @@ make_dim_metadata_vec(flatbuffers::FlatBufferBuilder *flatbuffer_builder, int32_
return dim_metadata_vec;
}

template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, CookParams &cp)
template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph, ModelChef &mc)
{
LOGGER(l);

std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = cp.buffer_vec;
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = cp.code_vec;
std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = cp.subgraph_vec;
std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = cp.flatbuffer_builder;
std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = cp.builtin_code_map;
std::vector<std::string> &custom_code_vec = cp.custom_code_vec;
// TODO remove references
std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = mc.buffer_vec;
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = mc.code_vec;
std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = mc.subgraph_vec;
std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = mc.flatbuffer_builder;
std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = mc.builtin_code_map;
std::vector<std::string> &custom_code_vec = mc.custom_code_vec;

// Operand-related
std::vector<flatbuffers::Offset<::tflite::Tensor>> tensor_vec;
Expand All @@ -273,7 +258,7 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,
std::vector<flatbuffers::Offset<::tflite::Operator>> operator_vec;

// default name for graph
std::string graph_name = cp.noname;
std::string graph_name = mc.graph_name;
if (graph.has_name())
graph_name = graph.name();

Expand Down Expand Up @@ -722,6 +707,40 @@ template <typename T> std::map<std::string, int32_t> cook_graph(const T &graph,

} // namespace

namespace
{

class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl
{
public:
GeneratedModelImpl()
{
// DO NOTHING
}

public:
const char *base(void) const override
{
// Return the base address of generated flatbuffer model
return reinterpret_cast<const char *>(_mc.flatbuffer_builder->GetBufferPointer());
}

public:
size_t size(void) const override
{
// Return the size of generated flatbuffer model
return _mc.flatbuffer_builder->GetSize();
}

public:
ModelChef &model_chef(void) { return _mc; }

private:
ModelChef _mc;
};

} // namespace

namespace tflchef
{

Expand All @@ -743,26 +762,35 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
#include "DataChef.def"
#undef DATA_CHEF

std::unique_ptr<GeneratedModelImpl> gen_model(new GeneratedModelImpl());

ModelChef &mc = gen_model->model_chef();

mc.flatbuffer_builder =
std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));

// TODO remove references

//
// Create FlatBufferBuilder
//
auto flatbuffer_builder =
std::unique_ptr<flatbuffers::FlatBufferBuilder>(new flatbuffers::FlatBufferBuilder(1024));
std::unique_ptr<flatbuffers::FlatBufferBuilder> &flatbuffer_builder = mc.flatbuffer_builder;

// Operand-related
std::vector<flatbuffers::Offset<::tflite::Buffer>> buffer_vec;
std::vector<flatbuffers::Offset<::tflite::Buffer>> &buffer_vec = mc.buffer_vec;

// Operation-related
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> code_vec;
std::vector<flatbuffers::Offset<::tflite::OperatorCode>> &code_vec = mc.code_vec;

// SignatureDef-related
std::vector<flatbuffers::Offset<::tflite::SignatureDef>> signdef_vec;
std::vector<flatbuffers::Offset<::tflite::SignatureDef>> &signdef_vec = mc.signdef_vec;

// Graphs-related
std::vector<flatbuffers::Offset<::tflite::SubGraph>> subgraph_vec;
std::vector<flatbuffers::Offset<::tflite::SubGraph>> &subgraph_vec = mc.subgraph_vec;

// Create OperatorCode with Builtin Operator
auto builtin_code_map = gather_builtincode_map(model_recipe);
mc.builtin_code_map = gather_builtincode_map(model_recipe);
std::map<tflite::BuiltinOperator, int32_t> &builtin_code_map = mc.builtin_code_map;
for (auto const &opcode : builtin_code_map)
{
tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder};
Expand All @@ -788,7 +816,8 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)

// Create OperatorCode with Custom Operator
std::set<std::string> custom_code_set = gather_customcode_set(model_recipe);
std::vector<std::string> custom_code_vec{custom_code_set.begin(), custom_code_set.end()};
mc.custom_code_vec = {custom_code_set.begin(), custom_code_set.end()};
std::vector<std::string> &custom_code_vec = mc.custom_code_vec;

for (auto opcode : custom_code_vec)
{
Expand Down Expand Up @@ -818,10 +847,9 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
//
// Create Main graph
//
CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
builtin_code_map, custom_code_vec, "main"};

auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, cp);
mc.graph_name = "main";
auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, mc);
symbol_tables.push_back(table);

//
Expand All @@ -834,10 +862,9 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
std::ostringstream stringStream;
stringStream << "sub_" << (g + 1);

CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder,
builtin_code_map, custom_code_vec, stringStream.str()};
mc.graph_name = stringStream.str();

auto table = cook_graph<::tflchef::Graph>(graph, cp);
auto table = cook_graph<::tflchef::Graph>(graph, mc);
symbol_tables.push_back(table);
}

Expand Down Expand Up @@ -946,8 +973,7 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe)
::tflite::FinishModelBuffer(*flatbuffer_builder, model);

// Return "GenerateModel"
return GeneratedModel{
std::unique_ptr<GeneratedModelImpl>(new GeneratedModelImpl(std::move(flatbuffer_builder)))};
return GeneratedModel{std::move(gen_model)};
}

} // namespace tflchef

0 comments on commit 0fc4de4

Please sign in to comment.