Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update the lock mechanism in the user offline tuning tool #1383

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace TensileLite

if(problemSolution.second > 0)
{
auto sol_iter = m_override.find(problemSolution.first);
auto sol_iter = m_override.find_range(problemSolution.first);
for(auto sol_idx = sol_iter.first; sol_idx != sol_iter.second; sol_idx++)
{
if(sol_idx->second == problemSolution.second)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,27 +160,23 @@ namespace TensileLite

int size()
{
std::shared_lock<std::shared_timed_mutex> lock(m_mutex);
auto size = m_override.size();
return size;
}

auto find(const ProblemOverride& prob_key)
auto find_range(const ProblemOverride& prob_key)
{
std::shared_lock<std::shared_timed_mutex> lock(m_mutex);
auto iter = m_override.equal_range(prob_key);
return iter;
}

void add(const std::pair<ProblemOverride, int>& problemSolution)
{
std::lock_guard<std::shared_timed_mutex> lock(m_mutex);
m_override.insert(problemSolution);
}

void erase(std::multimap<ProblemOverride, int>::iterator& sol_idx)
{
std::lock_guard<std::shared_timed_mutex> lock(m_mutex);
m_override.erase(sol_idx);
}

Expand All @@ -192,7 +188,6 @@ namespace TensileLite
private:
std::multimap<ProblemOverride, int> m_override;
std::mutex m_guard;
std::shared_timed_mutex m_mutex;
};
} // namespace Tensile

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ bool problem_override_from_file(rocblaslt_handle& handle,
std::vector<rocblaslt_matmul_heuristic_result> overrideResults;
std::vector<int> solutionIndex(1);
TensileLite::ProblemOverride prob_key(RocblasltContractionProblem2ProblemOverride(problem));
auto sol_iter = m_override.find(prob_key);
auto sol_iter = m_override.find_range(prob_key);

for(auto sol_idx = std::make_reverse_iterator(sol_iter.second);
!success && sol_idx != std::make_reverse_iterator(sol_iter.first);
Expand Down Expand Up @@ -196,7 +196,7 @@ bool problem_override_from_file_cpp(
std::vector<rocblaslt_matmul_heuristic_result> overrideResults;
std::vector<int> solutionIndex(1);
TensileLite::ProblemOverride prob_key(TensileDataGemm2ProblemOverride(gemmData));
auto sol_iter = m_override.find(prob_key);
auto sol_iter = m_override.find_range(prob_key);

for(auto sol_idx = std::make_reverse_iterator(sol_iter.second);
!success && sol_idx != std::make_reverse_iterator(sol_iter.first);
Expand Down
Loading