diff --git a/src/amberscript/parser.cc b/src/amberscript/parser.cc index 7eb3b9647..57f9f2e59 100644 --- a/src/amberscript/parser.cc +++ b/src/amberscript/parser.cc @@ -1816,7 +1816,7 @@ Result Parser::ParseDebug() { return res; } - auto dbg = MakeUnique(); + auto dbg = debug::Script::Create(); for (auto token = tokenizer_->NextToken();; token = tokenizer_->NextToken()) { if (token->IsEOL()) continue; @@ -1841,11 +1841,6 @@ Result Parser::ParseDebug() { } Result Parser::ParseDebugThread(debug::Events* dbg) { - Result result; - auto parseThread = [&](debug::Thread* thread) { - result = ParseDebugThreadBody(thread); - }; - auto token = tokenizer_->NextToken(); if (token->AsString() == "GLOBAL_INVOCATION_ID") { uint32_t invocation[3] = {}; @@ -1855,19 +1850,33 @@ Result Parser::ParseDebugThread(debug::Events* dbg) { return Result("expected invocation index"); invocation[i] = token->AsUint32(); } + + auto thread = debug::ThreadScript::Create(); + auto result = ParseDebugThreadBody(thread.get()); + if (!result.IsSuccess()) { + return result; + } + dbg->BreakOnComputeGlobalInvocation(invocation[0], invocation[1], - invocation[2], parseThread); + invocation[2], thread); } else if (token->AsString() == "VERTEX_INDEX") { token = tokenizer_->NextToken(); if (!token->IsInteger()) return Result("expected vertex index"); auto vertex_index = token->AsUint32(); - dbg->BreakOnVertexIndex(vertex_index, parseThread); + + auto thread = debug::ThreadScript::Create(); + auto result = ParseDebugThreadBody(thread.get()); + if (!result.IsSuccess()) { + return result; + } + + dbg->BreakOnVertexIndex(vertex_index, thread); } else { return Result("expected GLOBAL_INVOCATION_ID or VERTEX_INDEX"); } - return result; + return Result(); } Result Parser::ParseDebugThreadBody(debug::Thread* thread) { diff --git a/src/debug.cc b/src/debug.cc index 7d6c2265f..b1b79da71 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -14,8 +14,10 @@ #include "src/debug.h" +#include #include #include +#include #include "src/make_unique.h" @@ -24,15 +26,44 @@ namespace debug { namespace { -// ThreadScript is an implementation of amber::debug::Thread that records all -// calls made on it, which can be later replayed using ThreadScript::Run(). -class ThreadScript : public Thread { +class ScriptImpl : public Script { public: - void Run(Thread* thread) { + void Run(Events* e) const override { + for (auto f : sequence_) { + f(e); + } + } + + void BreakOnComputeGlobalInvocation( + uint32_t x, + uint32_t y, + uint32_t z, + const std::shared_ptr& thread) override { + sequence_.emplace_back([=](Events* events) { + events->BreakOnComputeGlobalInvocation(x, y, z, thread); + }); + } + + void BreakOnVertexIndex( + uint32_t index, + const std::shared_ptr& thread) override { + sequence_.emplace_back( + [=](Events* events) { events->BreakOnVertexIndex(index, thread); }); + } + + private: + using Event = std::function; + std::vector sequence_; +}; + +class ThreadScriptImpl : public ThreadScript { + public: + void Run(Thread* thread) const override { for (auto f : sequence_) { f(thread); } } + // Thread compliance void StepOver() override { sequence_.emplace_back([](Thread* t) { t->StepOver(); }); @@ -78,37 +109,12 @@ class ThreadScript : public Thread { Thread::~Thread() = default; Events::~Events() = default; -void Script::Run(Events* e) { - for (auto f : sequence_) { - f(e); - } +std::unique_ptr