diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f51c2f..9e780b4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,6 +52,33 @@ packageProject( DEPENDENCIES ) +CPMAddPackage( + NAME libcopp + GITHUB_REPOSITORY owt5008137/libcopp + GIT_TAG 1.2.1 +) + +packageProject( + NAME copp + VERSION ${PROJECT_VERSION} + BINARY_DIR ${libcopp_BINARY_DIR} + INCLUDE_DIR ${libcopp_SOURCE_DIR}/include + INCLUDE_DESTINATION include/${PROJECT_NAME}-${PROJECT_VERSION} + DEPENDENCIES +) + +packageProject( + NAME cotask + VERSION ${PROJECT_VERSION} + BINARY_DIR ${libcopp_BINARY_DIR} + INCLUDE_DIR ${libcopp_SOURCE_DIR}/include + INCLUDE_DESTINATION include/${PROJECT_NAME}-${PROJECT_VERSION} + DEPENDENCIES +) + +target_include_directories(copp PUBLIC ${libcopp_SOURCE_DIR}/include) +target_include_directories(cotask PUBLIC ${libcopp_SOURCE_DIR}/include) + if (NOT TARGET lua) CPMFindPackage( NAME lua @@ -70,10 +97,12 @@ if (NOT TARGET lua) add_library(LuaForGlue ${lua_sources} ${lua_headers}) set_target_properties(LuaForGlue PROPERTIES LINKER_LANGUAGE C) - target_include_directories(LuaForGlue - PUBLIC - $ - ) + # target_include_directories(LuaForGlue + # PUBLIC + # $ + # ) + + target_include_directories(LuaForGlue PUBLIC ${lua_SOURCE_DIR}) if(ANDROID) target_compile_definitions(LuaForGlue PRIVATE LUA_USE_POSIX LUA_USE_DLOPEN) @@ -117,7 +146,7 @@ target_compile_options(Smark PUBLIC "$<$:/permissive->") # Link dependencies (if required) # target_link_libraries(Smark PUBLIC cxxopts) -target_link_libraries(Smark PUBLIC uv_a HttpParser LuaForGlue) +target_link_libraries(Smark PUBLIC uv_a HttpParser LuaForGlue copp cotask) target_include_directories(Smark PUBLIC @@ -134,5 +163,5 @@ packageProject( BINARY_DIR ${PROJECT_BINARY_DIR} INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include INCLUDE_DESTINATION include/${PROJECT_NAME}-${PROJECT_VERSION} - DEPENDENCIES "uv_a" "HttpParser" "LuaForGlue" + DEPENDENCIES ) diff --git a/include/client.h b/include/client.h index f312d3d..78c7d0a 100644 --- a/include/client.h +++ b/include/client.h @@ -4,21 +4,32 @@ #include #include +#include "tasks.h" #include "util.h" namespace smark { class TCPClient : public smark::util::Socket { public: - TCPClient(smark::util::EventLoop* el); + explicit TCPClient(smark::util::EventLoop* el); }; class HttpClient : public TCPClient { public: - HttpClient(smark::util::EventLoop* el); + explicit HttpClient(smark::util::EventLoop* el); void Request(util::HttpRequest* request); std::function)> on_response; private: util::HttpReponseParser parser_; }; + + class HttpAsyncClient : protected HttpClient { + // use protected inherit to prevent user changes on_complete during requesting. + public: + explicit HttpAsyncClient(smark::util::EventLoop* el); + std::shared_ptr> ConnectAsync(std::string ip, int16_t port); + std::shared_ptr>> RequestAsync( + util::HttpRequest* request); + void Close(); + }; } // namespace smark \ No newline at end of file diff --git a/include/tasks.h b/include/tasks.h new file mode 100644 index 0000000..a2e3af7 --- /dev/null +++ b/include/tasks.h @@ -0,0 +1,114 @@ +#pragma once +#include + +#include +#include +#include +#include +#include +#include + +#include "debug.h" +#include "util.h" + +namespace smark::tasks { + class TaskManager; + class Task; + + typedef std::function)> TaskProc; + + // TODO: add exception handler + class Task : public smark::util::enable_shared_from_this { + public: + typedef std::function)> ProcType; + enum State { New, Runable, Dead }; + State state = State::New; + explicit Task(TaskProc proc); + Task() = default; + void SetProc(TaskProc proc); + void Start(); + void Yield(); + void Resume(); + void Wait(std::shared_ptr task); + void WaitAll(const std::vector>* task_list); + void Stop(); + virtual ~Task() = default; + + protected: + template + void SetProcContext_(std::shared_ptr task_ptr, + std::function)> proc) { + task_ptr_ = cotask::task<>::create([&, task_ptr, proc]() { + // void task will not run to this. + DEFER(std::dynamic_pointer_cast(task_ptr)->state = State::Dead; + UnregisterTaskFromMap2Task();) + RegisterTaskToMap2Task(std::dynamic_pointer_cast(task_ptr)); + proc(task_ptr); + }); + } + + private: + cotask::task<>::ptr_t task_ptr_; + void* result_; + void RegisterTaskToMap2Task(std::shared_ptr task_ptr); + void UnregisterTaskFromMap2Task(); + }; + + template class ValueTask : public Task { + public: + typedef std::function>)> ProcType; + explicit ValueTask(ProcType proc) : Task() { SetProc(proc); } + ValueTask() = default; + void SetProc(ProcType proc) { + SetProcContext_>(shared_from_this>(), proc); + } + inline void Complete(T result) { + result_ = result; + Task::Stop(); + } + inline T GetResult() { return result_; } + inline State GetState() { return state; } + + private: + T result_; + }; + + class TaskManager { + public: + void Wait(std::shared_ptr waiter, std::shared_ptr waitting); + void StopTask(std::shared_ptr task); + int RunOnce(); + void Stop(); + bool IsEmpty(); + bool is_stopped = false; + + private: + std::set> completed_tasks_; + std::map, std::shared_ptr> waitting_tasks_; + std::queue> starting_tasks_; + }; + + std::shared_ptr GetCurrentTask(); + + extern thread_local TaskManager task_mgr; + +#define _async(...) GET_MACRO_V2(__VA_ARGS__, vt_async, task_async)(__VA_ARGS__) +#define task_async(proc) smark::util::make_shared(proc) +#define vt_async(T, proc) \ + smark::util::make_shared>(proc) +#define void_task(...) GET_MACRO_V0_1(_0, ##__VA_ARGS__, vt_void_task, task_void_task)(__VA_ARGS__) +#define task_void_task() smark::util::make_shared() +#define vt_void_task(T) smark::util::make_shared>() + + template std::shared_ptr await(std::shared_ptr task) { + auto current_task = GetCurrentTask(); + task_mgr.Wait(current_task, std::dynamic_pointer_cast(task)); + + current_task->Yield(); + + return std::dynamic_pointer_cast(task); + } +#ifdef DEBUG + extern thread_local std::map*, std::shared_ptr> map2task; +#endif +} // namespace smark::tasks \ No newline at end of file diff --git a/include/util.h b/include/util.h index 8a71ad3..5206d39 100644 --- a/include/util.h +++ b/include/util.h @@ -19,7 +19,48 @@ extern "C" { = std::unique_ptr>{reinterpret_cast(1), \ [&](void*) { X }}; +// macro overload util +#define GET_MACRO_V2(_1, _2, NAME, ...) NAME +#define GET_MACRO_V0_1(_0, _1, NAME, ...) NAME + namespace smark::util { + // enabe shared_from_this in constructor + template void deleter(PtrT* __ptr) { + auto t = static_cast(__ptr); + t->~T(); + free(t); + } + + /** + * Base class allowing use of member function shared_from_this. + */ + template class enable_shared_from_this { + public: + std::shared_ptr* _construct_pself; + std::weak_ptr _construct_self; + + template std::shared_ptr shared_from_this() { + if (_construct_pself) { + return std::static_pointer_cast(*_construct_pself); // in constructor + } else { + return std::static_pointer_cast(_construct_self.lock()); + } + } + }; + + template + std::shared_ptr make_shared(Params&&... args) { + std::shared_ptr rtn; + T* t = (T*)calloc(1, sizeof(T)); + rtn.reset(t, deleter); + t->_construct_pself = &rtn; + t = new (t) T(std::forward(args)...); + t->_construct_pself = NULL; + t->_construct_self = rtn; + + return std::static_pointer_cast(rtn); + } + typedef std::function CallbackType; // void(int status) class EventLoop; class IEventObj { diff --git a/source/client.cpp b/source/client.cpp index 7cbe996..5d776fe 100644 --- a/source/client.cpp +++ b/source/client.cpp @@ -25,4 +25,33 @@ namespace smark { }); } + HttpAsyncClient::HttpAsyncClient(smark::util::EventLoop* el) : HttpClient(el) {} + + std::shared_ptr> HttpAsyncClient::ConnectAsync(std::string ip, + int16_t port) { + auto task = _async(int, [=](std::shared_ptr> this_task) { + DLOG("Try to connect:" << LOG_VALUE(ip) << LOG_VALUE(port)); + Connect(ip, port, [=](int status) { + DLOG("Connected result:" << LOG_VALUE(status)); + this_task->Complete(status); + }); + }); + return task; + } + + std::shared_ptr>> + HttpAsyncClient::RequestAsync(util::HttpRequest* request) { + auto task = _async( + std::shared_ptr, + [=](std::shared_ptr>> this_task) { + on_response = [this_task](auto, std::shared_ptr res) { + this_task->Complete(res); + }; + HttpClient::Request(request); + }); + return task; + } + + void HttpAsyncClient::Close() { Socket::Close(); } + } // namespace smark diff --git a/source/tasks.cpp b/source/tasks.cpp new file mode 100644 index 0000000..23c83bf --- /dev/null +++ b/source/tasks.cpp @@ -0,0 +1,103 @@ +#include "tasks.h" + +#include + +namespace smark::tasks { + thread_local std::map*, std::shared_ptr> map2task; + + Task::Task(TaskProc proc) { SetProc(proc); } + + void Task::SetProc(TaskProc proc) { SetProcContext_(shared_from_this(), proc); } + + void Task::RegisterTaskToMap2Task(std::shared_ptr task_ptr) { + map2task[cotask::this_task::get>()] = task_ptr; + } + + void Task::UnregisterTaskFromMap2Task() { + map2task.erase(cotask::this_task::get>()); + } + + void Task::Start() { + state = State::Runable; + task_ptr_->start(); + } + + void Task::Yield() { task_ptr_->yield(); } + + void Task::Resume() { task_ptr_->resume(); } + + void Task::Wait(std::shared_ptr task) { task_mgr.Wait(shared_from_this(), task); } + + void Task::WaitAll(const std::vector>* task_list) { + for (auto iter = task_list->begin(); iter != task_list->end(); iter++) { + task_mgr.Wait(shared_from_this(), *iter); + } + } + + void Task::Stop() { + state = State::Dead; + task_mgr.StopTask(shared_from_this()); + } + + void TaskManager::Wait(std::shared_ptr waiter, std::shared_ptr waitting) { + if (waitting->state == Task::State::Dead) { // waitting taks is completed. + completed_tasks_.erase(waitting); + starting_tasks_.push(waiter); // start waiter. + return; + } + + // waitting task is no completed currently. + waitting_tasks_[waitting] = waiter; + starting_tasks_.push(waitting); + } + + void TaskManager::StopTask(std::shared_ptr task) { + auto iter = waitting_tasks_.find(task); + if (iter != waitting_tasks_.end()) { + starting_tasks_.push(iter->second); + waitting_tasks_.erase(iter); + return; + } + + // if no waitter is waitting, add to completed_tasks_ to avoid being freed when task is + // complete. + completed_tasks_.insert(task); + } + + int TaskManager::RunOnce() { + if (is_stopped) return 0; + int run_task_count = 0; + while (starting_tasks_.size()) { + auto task = starting_tasks_.front(); // use front before size check is undefined behaivour. + starting_tasks_.pop(); + switch (task->state) { + case Task::State::New: + task->Start(); + run_task_count++; + break; + + case Task::State::Runable: + task->Resume(); + run_task_count++; + break; + + default: + break; + } + } + return run_task_count; + } + + bool TaskManager::IsEmpty() { + if (completed_tasks_.size() || waitting_tasks_.size() || starting_tasks_.size()) return false; + return true; + } + + void TaskManager::Stop() { is_stopped = true; } + + std::shared_ptr GetCurrentTask() { + return map2task[cotask::this_task::get>()]; + } + + thread_local TaskManager task_mgr; +} // namespace smark::tasks \ No newline at end of file diff --git a/source/util.cpp b/source/util.cpp index 0b61d2e..0601ceb 100644 --- a/source/util.cpp +++ b/source/util.cpp @@ -5,11 +5,20 @@ #include "debug.h" #include "platform.h" +#include "tasks.h" namespace smark::util { EventLoop::EventLoop() { uv_loop_init(loop_.get()); } - void EventLoop::Wait() { uv_run(loop_.get(), UV_RUN_DEFAULT); } + void EventLoop::Wait() { + while (true) { + int rtc = smark::tasks::task_mgr.RunOnce(); + int uv_res = uv_run(loop_.get(), UV_RUN_ONCE); + DLOG("Run task count:" << LOG_VALUE(rtc) << LOG_VALUE(uv_res) + << LOG_NV("IsEmpty", smark::tasks::task_mgr.IsEmpty())); + if (smark::tasks::task_mgr.IsEmpty() && !uv_run(loop_.get(), UV_RUN_ONCE)) break; + }; + } void EventLoop::Stop() { uv_stop(loop_.get()); } diff --git a/test/source/smark.cpp b/test/source/smark.cpp index 0910098..f0ec2b1 100644 --- a/test/source/smark.cpp +++ b/test/source/smark.cpp @@ -10,6 +10,7 @@ DISABLE_SOME_WARNINGS #include "debug.h" #include "script.h" +#include "tasks.h" #include "util.h" #if defined(_WIN32) || defined(WIN32) @@ -19,16 +20,6 @@ DISABLE_SOME_WARNINGS #include "testsvr.h" -#define INIT_TASK int __task_count = __COUNTER__ -#define SUB_TASK(task) \ - (void)__COUNTER__; \ - task++ -#define END_TASK __task_count = __COUNTER__ - __task_count - 1 - -// do not use '==' to compare string -// do not use string.compare: fail on "This is a response" -#define STR_COMPARE(str, value) strcmp(str.c_str(), value) == 0 - using namespace smark_tests; uint16_t port = SVR_PORT; @@ -54,7 +45,7 @@ TEST_CASE("TCPClient") { int task = 0; INIT_TASK; - util::EventLoop el; + smark::util::EventLoop el; TCPClient cli(&el); const char data[] = "Hello world"; @@ -66,11 +57,11 @@ TEST_CASE("TCPClient") { }; cli.Connect("127.0.0.1", port, [&cli, &task, &data](int status) { if (status) { - ERR("Connect error:" << util::EventLoop::GetErrorStr(status)); + ERR("Connect error:" << smark::util::EventLoop::GetErrorStr(status)); } cli.Write(data, sizeof(data), [](int status) { if (status) { - ERR("Write error:" << util::EventLoop::GetErrorStr(status)); + ERR("Write error:" << smark::util::EventLoop::GetErrorStr(status)); } }); }); @@ -87,12 +78,12 @@ TEST_CASE("FailConnect") { int task = 0; INIT_TASK; - util::EventLoop el; + smark::util::EventLoop el; TCPClient cli(&el); cli.Connect("127.0.0.1", port, [&task](int status) { SUB_TASK(task); if (status) { - DLOG("Test fail connect:" << util::EventLoop::GetErrorStr(status)); + DLOG("Test fail connect:" << smark::util::EventLoop::GetErrorStr(status)); } // Use macro instead of actual status code. @@ -132,18 +123,18 @@ TEST_CASE("HttpClient") { DLOG("Run Http server on port:" << p); INIT_TASK; int task = 0; - auto req = std::make_shared(); + auto req = std::make_shared(); req->method = "Get"; req->request_uri = "/test"; - auto test_header = std::make_shared(); + auto test_header = std::make_shared(); test_header->name = "test-header"; test_header->value = "test_value"; req->headers.push_back(test_header); req->body = "This is a request"; - util::EventLoop el; + smark::util::EventLoop el; HttpClient cli(&el); - cli.on_response = [&task, &el, &cli](auto, std::shared_ptr res) { + cli.on_response = [&task, &el, &cli](auto, std::shared_ptr res) { SUB_TASK(task); CHECK(STR_COMPARE(res->status_code, "OK")); int header_count = res->headers.size(); @@ -156,7 +147,7 @@ TEST_CASE("HttpClient") { }; cli.Connect("127.0.0.1", p, [&cli, &req](int status) { if (status) { - ERR("Connect error:" << util::EventLoop::GetErrorStr(status)); + ERR("Connect error:" << smark::util::EventLoop::GetErrorStr(status)); } cli.Request(req.get()); }); @@ -166,6 +157,55 @@ TEST_CASE("HttpClient") { CHECK(task == __task_count); } +TEST_CASE("HttpAsyncClient") { + auto svr = new SimpleHttpServer(); + DEFER(delete svr;) + std::thread* thread = nullptr; + uint16_t p = RunServer(svr, &thread); + // DEFER(delete thread;) + DLOG("Run Http server on port:" << p); + INIT_TASK; + int task = 0; + + std::thread([=, &task]() { + auto req = std::make_shared(); + req->method = "Get"; + req->request_uri = "/test"; + auto test_header = std::make_shared(); + test_header->name = "test-header"; + test_header->value = "test_value"; + req->headers.push_back(test_header); + req->body = "This is a request"; + + smark::util::EventLoop el; + auto proc = [&](auto) { + HttpAsyncClient cli(&el); + auto status = await(cli.ConnectAsync("127.0.0.1", p))->GetResult(); + if (status) { + ERR("Connect error:" << smark::util::EventLoop::GetErrorStr(status)); + } + DLOG("Connected"); + auto res = await(cli.RequestAsync(req.get()))->GetResult(); + DLOG("Get response."); + CHECK(STR_COMPARE(res->status_code, "OK")); + int header_count = res->headers.size(); + CHECK(header_count == 1); + auto test_header = res->headers[0]; + CHECK(STR_COMPARE(test_header->name, "test-header")); + CHECK(STR_COMPARE(test_header->value, "test_value")); + CHECK(STR_COMPARE(res->body, "This is a response")); + cli.Close(); + SUB_TASK(task); + }; + _async(proc)->Start(); + + el.Wait(); + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + TEST_CASE("Script_Setup") { LuaThread thread; Script script; @@ -246,10 +286,10 @@ TEST_CASE("Script_Response") { " pass=true\n" "end"; script.Run(code); - util::HttpResponse res; + smark::util::HttpResponse res; res.body = "content"; res.status_code = "200"; - auto header = std::make_shared(); + auto header = std::make_shared(); header->name = "test"; header->value = "value"; res.headers.push_back(header); diff --git a/test/source/tasks.cpp b/test/source/tasks.cpp new file mode 100644 index 0000000..cdec92f --- /dev/null +++ b/test/source/tasks.cpp @@ -0,0 +1,182 @@ +#include "platform.h" +DISABLE_SOME_WARNINGS +#include + +#include "tasks.h" +#include "util.h" + +using namespace smark::tasks; +using namespace smark::util; + +TEST_CASE("Task_Simple") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + auto co_task = make_shared([&](std::shared_ptr this_task) { + SUB_TASK(task); + this_task->Yield(); + SUB_TASK(task); + }); + CHECK(co_task->state == Task::State::New); + co_task->Start(); + CHECK(co_task->state == Task::State::Runable); + co_task->Resume(); + CHECK(co_task->state == Task::State::Dead); + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_Map2Task") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + SUB_TASK(task); + cotask::task<>* t = nullptr; + auto co_task = make_shared([&t](std::shared_ptr this_task) { + t = cotask::this_task::get>(); + CHECK(GetCurrentTask().get() == this_task.get()); + this_task->Yield(); + CHECK(GetCurrentTask().get() == this_task.get()); + }); + co_task->Start(); + auto r1 = map2task[t]; + CHECK(r1.get() == co_task.get()); + co_task->Resume(); + CHECK(map2task.find(t) == map2task.end()); + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_Async") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + auto t1 = _async([&](std::shared_ptr this_task) { + (void)this_task; + auto child_task = task_async([&](std::shared_ptr this_task) { + SUB_TASK(task); + this_task->Stop(); + }); + SUB_TASK(task); + await(child_task); + SUB_TASK(task); + }); + t1->Start(); + + while (task_mgr.RunOnce()) + ; + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_StopFromOutside") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + std::function func([]() {}); + + auto t1 = _async([&](std::shared_ptr this_task) { + (void)this_task; + auto child_task = task_async([&](std::shared_ptr this_task) { + SUB_TASK(task); + func = [this_task]() { this_task->Stop(); }; + }); + SUB_TASK(task); + await(child_task); + SUB_TASK(task); + }); + t1->Start(); + + task_mgr.RunOnce(); // run child_task + func(); // set result of child_task + task_mgr.RunOnce(); // let child_task trigger parent_task + + CHECK(task_mgr.RunOnce() == 0); // ensure no running task remain. + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_StopBeforeWait") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + auto proc = [&](std::shared_ptr this_task) { + (void)this_task; + auto child_task = void_task(); + child_task->Stop(); + SUB_TASK(task); + await(child_task); + SUB_TASK(task); + }; + auto t1 = _async(proc); + t1->Start(); + + task_mgr.RunOnce(); // restart t1 + + CHECK(task_mgr.RunOnce() == 0); // ensure no running task remain. + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_ValueTask") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + auto co_task = _async(int, [&](std::shared_ptr> this_task) { + SUB_TASK(task); + this_task->Yield(); + SUB_TASK(task); + }); + CHECK(co_task->GetState() == Task::State::New); + co_task->Start(); + CHECK(co_task->GetState() == Task::State::Runable); + co_task->Resume(); + CHECK(co_task->GetState() == Task::State::Dead); + }).join(); + + END_TASK; + CHECK(task == __task_count); +} + +TEST_CASE("Task_ValueTaskAsync") { + int task = 0; + INIT_TASK; + + std::thread([&task]() { + auto proc = [&](std::shared_ptr this_task) { + (void)this_task; + auto child_task = _async(int, [&](std::shared_ptr> t) { + SUB_TASK(task); + t->Complete(10); + }); + SUB_TASK(task); + CHECK(await(child_task)->GetResult() == 10); + SUB_TASK(task); + }; // TODO: why? lambda-expression in template-argument only available with ‘-std=c++2a’ or + // ‘-std=gnu++2a’ + auto t1 = _async(proc); + t1->Start(); + + while (task_mgr.RunOnce()) + ; + }).join(); + + END_TASK; + CHECK(task == __task_count); +} \ No newline at end of file diff --git a/test/source/util.h b/test/source/util.h new file mode 100644 index 0000000..a571bb7 --- /dev/null +++ b/test/source/util.h @@ -0,0 +1,9 @@ +#define INIT_TASK int __task_count = __COUNTER__ +#define SUB_TASK(task) \ + (void)__COUNTER__; \ + task++ +#define END_TASK __task_count = __COUNTER__ - __task_count - 1 + +// do not use '==' to compare string +// do not use string.compare: fail on "This is a response" +#define STR_COMPARE(str, value) strcmp(str.c_str(), value) == 0 \ No newline at end of file