Skip to content

Commit

Permalink
Load and save problem_cache from a file (#3203)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Sep 7, 2024
1 parent 4a73c69 commit 92f8f8a
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 11 deletions.
6 changes: 6 additions & 0 deletions docs/dev/env_vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ Set to "1" to print benchmarking trace.
Set to "2" to print detailed benchmarking trace.
Set to "3" to print compiled traces.

.. envvar:: MIGRAPHX_PROBLEM_CACHE

Set to path to json file to load and save problem cache.
This will load the json file into the problem cache if it exists, and when
compilation finishes it will save the problem cache.

MLIR vars
-------------

Expand Down
24 changes: 24 additions & 0 deletions docs/sphinx/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,27 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#####################################################################################
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
Expand Down
5 changes: 5 additions & 0 deletions src/file_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ std::string read_string(const fs::path& filename)
return generic_read_file<std::string>(filename);
}

void write_string(const fs::path& filename, const std::string& buffer)
{
write_buffer(filename, buffer.data(), buffer.size());
}

void write_buffer(const fs::path& filename, const char* buffer, std::size_t size)
{
std::ofstream os(filename, std::ios::out | std::ios::binary);
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/file_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ MIGRAPHX_EXPORT std::vector<char>
read_buffer(const fs::path& filename, size_t offset = 0, size_t nbytes = 0);
MIGRAPHX_EXPORT std::string read_string(const fs::path& filename);

MIGRAPHX_EXPORT void write_string(const fs::path& filename, const std::string& buffer);
MIGRAPHX_EXPORT void write_buffer(const fs::path& filename, const char* buffer, std::size_t size);
MIGRAPHX_EXPORT void write_buffer(const fs::path& filename, const std::vector<char>& buffer);

Expand Down
43 changes: 37 additions & 6 deletions src/include/migraphx/serialize.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -144,6 +144,12 @@ auto to_value_impl(rank<12>, const T& x)
return v;
}

template <class T, MIGRAPHX_REQUIRES(std::is_same<T, value>{})>
value to_value_impl(rank<13>, const T& x)
{
return x;
}

template <class T, MIGRAPHX_REQUIRES(std::is_empty<T>{})>
void from_value_impl(rank<0>, const value& v, T& x)
{
Expand Down Expand Up @@ -191,9 +197,28 @@ template <class T>
auto from_value_impl(rank<4>, const value& v, T& x)
-> decltype(x.insert(*x.begin()), std::declval<typename T::mapped_type>(), void())
{
x.clear();
for(auto&& e : v)
x.emplace(e.get_key(), from_value<typename T::mapped_type>(e));
if(v.is_object())
{
x.clear();
for(auto&& e : v)
x.emplace(from_value<typename T::key_type>(e.get_key()),
from_value<typename T::mapped_type>(e));
}
else if(v.is_array())
{
x.clear();
for(auto&& e : v)
{
if(e.size() != 2)
MIGRAPHX_THROW("Expected a pair");
x.emplace(from_value<typename T::key_type>(e[0]),
from_value<typename T::mapped_type>(e[1]));
}
}
else
{
MIGRAPHX_THROW("Expected object or array");
}
}

template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
Expand Down Expand Up @@ -233,18 +258,24 @@ auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_v
migraphx_from_value(v, x);
}

template <class T, MIGRAPHX_REQUIRES(std::is_same<T, value>{})>
void from_value_impl(rank<11>, const value& v, T& x)
{
x = v;
}

} // namespace detail

template <class T>
value to_value(const T& x)
{
return detail::to_value_impl(rank<12>{}, x);
return detail::to_value_impl(rank<13>{}, x);
}

template <class T>
void from_value(const value& v, T& x)
{
detail::from_value_impl(rank<10>{}, v, x);
detail::from_value_impl(rank<11>{}, v, x);
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
26 changes: 23 additions & 3 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,25 @@ struct hip_device

struct context
{
struct auto_save_problem_cache : problem_cache
{
auto_save_problem_cache() : problem_cache{} {}

bool auto_save = false;

auto_save_problem_cache(const auto_save_problem_cache&) = delete;
auto_save_problem_cache& operator=(const auto_save_problem_cache&) = delete;
virtual ~auto_save_problem_cache()
{
if(auto_save)
this->save();
}
};
context(std::size_t device_id = 0, std::size_t n = value_of(MIGRAPHX_NSTREAMS{}, 1))
: current_device(std::make_shared<hip_device>(device_id, n)),
begin_event(create_event()),
finish_event(create_event())
finish_event(create_event()),
pc(std::make_shared<auto_save_problem_cache>())
{
}

Expand Down Expand Up @@ -334,7 +349,12 @@ struct context
return result;
}

problem_cache& get_problem_cache() { return pc; }
problem_cache& get_problem_cache() { return *pc; }
void load_problem_cache()
{
pc->load();
pc->auto_save = true;
}

private:
// TODO: Make this a vector to support multiple devices
Expand All @@ -348,7 +368,7 @@ struct context
// for stream syncronization
shared<hip_event_ptr> begin_event = nullptr;
shared<hip_event_ptr> finish_event = nullptr;
problem_cache pc{};
std::shared_ptr<auto_save_problem_cache> pc = nullptr;
};

inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
Expand Down
5 changes: 4 additions & 1 deletion src/targets/gpu/include/migraphx/gpu/problem_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/gpu/export.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct problem_cache
struct MIGRAPHX_GPU_EXPORT problem_cache
{
bool has(const std::string& name, const value& problem) const;
void insert(const std::string& name, const value& problem, const value& solution);
void mark(const std::string& name, const value& problem);
optional<value> get(const std::string& name, const value& problem) const;
void load();
void save() const;
std::unordered_map<value, value> cache;
};

Expand Down
27 changes: 27 additions & 0 deletions src/targets/gpu/problem_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,38 @@
*/
#include <migraphx/gpu/problem_cache.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/json.hpp>
#include <migraphx/env.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/file_buffer.hpp>
#include <iostream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_PROBLEM_CACHE)

void problem_cache::load()
{
auto pc_path = string_value_of(MIGRAPHX_PROBLEM_CACHE{});
if(pc_path.empty())
return;
if(not fs::exists(pc_path))
{
std::cout << "Problem cache not found. Creating new file.\n";
return;
}
from_value(from_json_string(read_string(pc_path)), cache);
}
void problem_cache::save() const
{
auto pc_path = string_value_of(MIGRAPHX_PROBLEM_CACHE{});
if(pc_path.empty())
return;
write_string(pc_path, to_pretty_json_string(to_value(cache)));
}

static value create_key(const std::string& name, const value& problem)
{
return {{"name", name}, {"problem", problem}};
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
{
auto& ctx = any_cast<context>(gctx);
ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
ctx.load_problem_cache();
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type);
unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
Expand Down
9 changes: 9 additions & 0 deletions test/include/test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ inline Stream& operator<<(Stream& s, std::nullptr_t)
return s;
}

template <class Stream, class T, class U>
inline Stream& operator<<(Stream& s, const std::pair<T, U>& p)
{
s << "{";
s << p.first << ", " << p.second;
s << "}";
return s;
}

template <class Stream,
class Range,
class = typename std::enable_if<not std::is_convertible<Range, std::string>{}>::type>
Expand Down
40 changes: 39 additions & 1 deletion test/serialize_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,8 @@
#include <test.hpp>

#include <numeric>
#include <map>
#include <vector>

struct empty_type
{
Expand Down Expand Up @@ -145,6 +147,42 @@ TEST_CASE(serialize_optional)
EXPECT(v.to<int>() == 2);
}

TEST_CASE(serialize_map)
{
std::map<int, int> m = {{1, 1}, {2, 4}, {3, 9}};
migraphx::value v = migraphx::to_value(m);
EXPECT(v.is_array());
EXPECT(m == migraphx::from_value<std::map<int, int>>(v));
}

TEST_CASE(serialize_invalid_map1)
{
migraphx::value v = {{1, 1}, {2, 4}, {3, 9, 27}};
EXPECT(test::throws([&] { migraphx::from_value<std::map<int, int>>(v); }));
}

TEST_CASE(serialize_invalid_map2)
{
migraphx::value v = 2;
EXPECT(test::throws([&] { migraphx::from_value<std::map<int, int>>(v); }));
}

TEST_CASE(serialize_struct_to_map)
{
migraphx::value v = migraphx::to_value(reflectable_type{});
EXPECT(v.is_object());
std::map<std::string, migraphx::value> m;
migraphx::from_value(v, m);
EXPECT(m.size() == v.size());
}

TEST_CASE(serialize_vector_value)
{
std::vector<migraphx::value> x = {{1}, {2}};
migraphx::value v = migraphx::to_value(x);
EXPECT(x == migraphx::from_value<std::vector<migraphx::value>>(v));
}

TEST_CASE(from_value_binary)
{
std::vector<std::uint8_t> data(10);
Expand Down

0 comments on commit 92f8f8a

Please sign in to comment.