Skip to content

Commit b5eae57

Browse files
committed
Reinforcement Learning Model Added (Version 1.5.0)
1 parent 2753768 commit b5eae57

File tree

10 files changed

+758
-90
lines changed

10 files changed

+758
-90
lines changed

CMakeLists.txt

Lines changed: 204 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,228 @@
1-
cmake_minimum_required(VERSION 3.10)
2-
project(VelocityVortex)
1+
# cmake_minimum_required(VERSION 3.15)
2+
# project(VelocityVortex)
3+
4+
# set(CMAKE_CXX_STANDARD 20)
5+
# set(CMAKE_CXX_STANDARD_REQUIRED True)
6+
# set(CMAKE_PREFIX_PATH "/usr/local/libtorch")
7+
8+
9+
# find_package(Torch REQUIRED)
10+
# include_directories(include)
11+
# find_package(Python3 COMPONENTS Interpreter Development NumPy REQUIRED)
12+
13+
# include_directories(${Python3_INCLUDE_DIRS})
14+
# include_directories(${Python3_NumPy_INCLUDE_DIRS})
15+
16+
17+
# find_package(OpenSSL REQUIRED)
18+
# find_package(Boost REQUIRED COMPONENTS system)
19+
# find_package(jsoncpp REQUIRED)
20+
# find_package(CURL REQUIRED)
21+
22+
23+
# find_library(HIREDIS_LIB hiredis REQUIRED)
24+
# find_library(SQLITE3_LIB sqlite3 REQUIRED)
25+
# find_package(Threads REQUIRED)
26+
27+
28+
# add_library(websocketpp INTERFACE)
29+
# target_include_directories(websocketpp INTERFACE ${CMAKE_SOURCE_DIR}/websocketpp)
30+
31+
32+
# file(GLOB_RECURSE ALGO_ENGINE_SOURCES "src/AlgoEngine-Core/**/*.cpp")
33+
# file(GLOB_RECURSE DATA_FETCHER_SOURCES "src/Data-Fetcher-Core/*.cpp")
34+
# file(GLOB_RECURSE IO_BROKER_SOURCES "src/IO-Broker-Core/*.cpp")
35+
# file(GLOB_RECURSE ORDER_MANAGER_SOURCES "src/Order-Manager-Core/*.cpp")
36+
# file(GLOB_RECURSE RISK_ANALYSIS_SOURCES "src/Risk-Analysis-Core/*.cpp")
37+
# file(GLOB_RECURSE VELOCITY_BOT_SOURCES "src/Velocity-Bot/*.cpp")
38+
# file(GLOB_RECURSE BACKTESTING_BOT_SOURCES "src/Backtesting-Bot/*.cpp")
39+
# file(GLOB_RECURSE ORDERBOOK_SOURCES "src/Orderbook/*.cpp")
40+
# file(GLOB_RECURSE UTILITY_SOURCES "src/Utilities/*.cpp")
41+
42+
43+
# set(UTILITIES_SOURCES
44+
# include/Utilities/Bar.hpp
45+
# include/Utilities/OHLCV.hpp
46+
# include/Utilities/Quote.hpp
47+
# include/Utilities/SignalResult.hpp
48+
# include/Utilities/Trade.hpp
49+
# include/Utilities/Utilities.hpp
50+
# )
51+
52+
53+
# set(SOURCES
54+
# ${ALGO_ENGINE_SOURCES}
55+
# ${DATA_FETCHER_SOURCES}
56+
# ${IO_BROKER_SOURCES}
57+
# ${ORDER_MANAGER_SOURCES}
58+
# ${RISK_ANALYSIS_SOURCES}
59+
# ${VELOCITY_BOT_SOURCES}
60+
# ${BACKTESTING_BOT_SOURCES}
61+
# ${ORDERBOOK_SOURCES}
62+
# ${UTILITY_SOURCES}
63+
# src/main.cpp
64+
# )
65+
66+
67+
# add_executable(VelocityVortex ${SOURCES})
68+
69+
70+
# target_link_libraries(VelocityVortex
71+
# ${TORCH_LIBRARIES}
72+
# CURL::libcurl
73+
# jsoncpp_lib
74+
# ${HIREDIS_LIB}
75+
# ${SQLITE3_LIB}
76+
# ${Python3_LIBRARIES}
77+
# websocketpp
78+
# Threads::Threads
79+
# OpenSSL::SSL
80+
# OpenSSL::Crypto
81+
# ${Boost_LIBRARIES}
82+
# )
83+
84+
85+
# set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/..)
86+
87+
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
88+
89+
90+
cmake_minimum_required(VERSION 3.15)
91+
project(VelocityVortex
92+
VERSION 1.0.0
93+
DESCRIPTION "High-frequency trading system"
94+
LANGUAGES CXX)
95+
396

497
set(CMAKE_CXX_STANDARD 20)
5-
set(CMAKE_CXX_STANDARD_REQUIRED True)
98+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
99+
set(CMAKE_CXX_EXTENSIONS OFF)
100+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
101+
102+
103+
include(CheckIPOSupported)
104+
check_ipo_supported(RESULT IPO_SUPPORTED)
105+
if(IPO_SUPPORTED)
106+
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
107+
endif()
108+
109+
110+
if(NOT CMAKE_BUILD_TYPE)
111+
set(CMAKE_BUILD_TYPE Release)
112+
endif()
113+
114+
115+
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
116+
add_compile_options(
117+
-O3
118+
-march=native
119+
-Wall
120+
-Wextra
121+
-Wpedantic
122+
-ffast-math
123+
)
124+
elseif(MSVC)
125+
add_compile_options(
126+
/O2
127+
/W4
128+
/arch:AVX2
129+
)
130+
endif()
131+
132+
133+
set(SOURCE_GROUPS
134+
"AlgoEngine-Core"
135+
"Data-Fetcher-Core"
136+
"IO-Broker-Core"
137+
"Order-Manager-Core"
138+
"Risk-Analysis-Core"
139+
"Velocity-Bot"
140+
"Backtesting-Bot"
141+
"Orderbook"
142+
"Utilities"
143+
)
6144

7-
include_directories(include)
8145

9-
find_package(Python3 COMPONENTS Interpreter Development NumPy REQUIRED)
146+
set(ALL_SOURCES src/main.cpp)
147+
foreach(GROUP ${SOURCE_GROUPS})
148+
file(GLOB_RECURSE GROUP_SOURCES
149+
"src/${GROUP}/*.cpp"
150+
"src/${GROUP}/*.hpp"
151+
)
152+
list(APPEND ALL_SOURCES ${GROUP_SOURCES})
153+
source_group(${GROUP} FILES ${GROUP_SOURCES})
154+
endforeach()
10155

11156

157+
add_executable(${PROJECT_NAME} ${ALL_SOURCES})
12158

13-
include_directories(${Python3_INCLUDE_DIRS})
14-
include_directories(${Python3_NumPy_INCLUDE_DIRS})
15-
include_directories(${OPENSSL_INCLUDE_DIR})
16-
include_directories(${Boost_INCLUDE_DIRS})
17-
include_directories(${JSONCPP_INCLUDE_DIRS})
159+
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
160+
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
161+
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
18162

19163

164+
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH} "/usr/local/libtorch")
165+
find_package(Torch REQUIRED)
166+
find_package(Python3 COMPONENTS Interpreter Development NumPy REQUIRED)
167+
find_package(OpenSSL REQUIRED)
168+
find_package(Boost REQUIRED COMPONENTS system)
169+
find_package(jsoncpp REQUIRED)
20170
find_package(CURL REQUIRED)
171+
find_package(Threads REQUIRED)
172+
173+
174+
21175
find_library(HIREDIS_LIB hiredis REQUIRED)
22176
find_library(SQLITE3_LIB sqlite3 REQUIRED)
23-
find_package(jsoncpp REQUIRED)
24177

25178

26-
find_package(Threads REQUIRED)
27-
find_package(OpenSSL REQUIRED)
28-
find_package(Boost REQUIRED COMPONENTS system)
29-
30179

31180
add_library(websocketpp INTERFACE)
32-
target_include_directories(websocketpp INTERFACE ${CMAKE_SOURCE_DIR}/websocketpp)
33-
34-
35-
file(GLOB_RECURSE ALGO_ENGINE_SOURCES "src/AlgoEngine-Core/indicators/*.cpp")
36-
file(GLOB_RECURSE DATA_FETCHER_SOURCES "src/Data-Fetcher-Core/*.cpp")
37-
file(GLOB_RECURSE IO_BROKER_SOURCES "src/IO-Broker-Core/*.cpp")
38-
file(GLOB_RECURSE ORDER_MANAGER_SOURCES "src/Order-Manager-Core/*.cpp")
39-
file(GLOB_RECURSE RISK_ANALYSIS_SOURCES "src/Risk-Analysis-Core/*.cpp")
40-
file(GLOB_RECURSE VELOCITY_BOT_SOURCES "src/Velocity-Bot/*.cpp")
41-
file(GLOB_RECURSE BACKTESTING_BOT_SOURCES "src/Backtesting-Bot/*.cpp")
42-
file(GLOB_RECURSE ORDERBOOK_SOURCES "src/Orderbook/*.cpp")
43-
file(GLOB_RECURSE UTILITY_SOURCES "src/Utilities/*.cpp")
44-
45-
46-
set(UTILITIES_SOURCES
47-
include/Utilities/Bar.hpp
48-
include/Utilities/OHLCV.hpp
49-
include/Utilities/Quote.hpp
50-
include/Utilities/SignalResult.hpp
51-
include/Utilities/Trade.hpp
52-
include/Utilities/Utilities.hpp
181+
target_include_directories(websocketpp
182+
INTERFACE
183+
${CMAKE_SOURCE_DIR}/websocketpp
53184
)
54185

55186

56-
set(SOURCES
57-
${ALGO_ENGINE_SOURCES}
58-
${DATA_FETCHER_SOURCES}
59-
${IO_BROKER_SOURCES}
60-
${ORDER_MANAGER_SOURCES}
61-
${RISK_ANALYSIS_SOURCES}
62-
${VELOCITY_BOT_SOURCES}
63-
${BACKTESTING_BOT_SOURCES}
64-
${ORDERBOOK_SOURCES}
65-
${UTILITY_SOURCES}
66-
src/main.cpp
187+
target_include_directories(${PROJECT_NAME}
188+
PRIVATE
189+
${CMAKE_SOURCE_DIR}/include
190+
${Python3_INCLUDE_DIRS}
191+
${Python3_NumPy_INCLUDE_DIRS}
67192
)
68193

69-
70-
add_executable(VelocityVortex ${SOURCES})
71-
72-
73-
target_link_libraries(VelocityVortex
74-
CURL::libcurl
75-
jsoncpp_lib
76-
${HIREDIS_LIB}
77-
${SQLITE3_LIB}
78-
${Python3_LIBRARIES}
79-
websocketpp
80-
Threads::Threads
81-
OpenSSL::SSL
82-
OpenSSL::Crypto
83-
${Boost_LIBRARIES}
194+
target_link_libraries(${PROJECT_NAME}
195+
PRIVATE
196+
${TORCH_LIBRARIES}
197+
CURL::libcurl
198+
jsoncpp_lib
199+
${HIREDIS_LIB}
200+
${SQLITE3_LIB}
201+
${Python3_LIBRARIES}
202+
websocketpp
203+
Threads::Threads
204+
OpenSSL::SSL
205+
OpenSSL::Crypto
206+
${Boost_LIBRARIES}
84207
)
85208

86209

87-
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/..)
210+
target_compile_options(${PROJECT_NAME} PRIVATE ${TORCH_CXX_FLAGS})
211+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pipe")
212+
213+
if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.16)
214+
target_precompile_headers(${PROJECT_NAME}
215+
PRIVATE
216+
<vector>
217+
<string>
218+
<memory>
219+
<map>
220+
<unordered_map>
221+
)
222+
endif()
223+
224+
install(TARGETS ${PROJECT_NAME}
225+
RUNTIME DESTINATION bin
226+
LIBRARY DESTINATION lib
227+
ARCHIVE DESTINATION lib
228+
)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#pragma once
2+
#include <torch/torch.h>
3+
#include <vector>
4+
#include <random>
5+
#include <fstream>
6+
#include <sstream>
7+
#include <memory>
8+
#include <sstream>
9+
#include "AlgoEngine-Core/Reinforcement_models/Deep_Q-Learning.hpp"
10+
#include "Utilities/ModelTransfer.hpp"
11+
#include "Utilities/ReplayMemory.hpp"
12+
#include <torch/torch.h>
13+
#include <fstream>
14+
#include <memory>
15+
16+
class Agent
17+
{
18+
public:
19+
Agent(int replay_mem_size = 10000,
20+
int batch_size = 40,
21+
double gamma = 0.98,
22+
double eps_start = 1.0,
23+
double eps_end = 0.12,
24+
int eps_steps = 300,
25+
double learning_rate = 0.001,
26+
int input_dim = 24,
27+
int action_number = 3,
28+
bool is_double_dqn = true);
29+
30+
torch::Tensor select_action(torch::Tensor state);
31+
void optimize_model();
32+
void store_transition(torch::Tensor state, torch::Tensor action,
33+
torch::Tensor next_state, torch::Tensor reward);
34+
void update_target_network();
35+
36+
private:
37+
void optimize_double_dqn(const torch::Tensor &states, torch::Tensor &next_state_values,
38+
const std::vector<bool> &non_final_mask,
39+
const std::vector<torch::Tensor> &next_states);
40+
void optimize_regular_dqn(torch::Tensor &next_state_values,
41+
const std::vector<bool> &non_final_mask,
42+
const std::vector<torch::Tensor> &next_states);
43+
44+
int replay_mem_size;
45+
int batch_size;
46+
double gamma;
47+
double eps_start;
48+
double eps_end;
49+
int eps_steps;
50+
std::unique_ptr<ConvDuelingDQN> policy_net;
51+
std::unique_ptr<ConvDuelingDQN> target_net;
52+
torch::optim::Adam optimizer;
53+
ReplayMemory memory;
54+
std::mt19937 rng;
55+
int steps_done;
56+
bool is_double_dqn;
57+
torch::Device device;
58+
};
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#pragma once
2+
#ifndef CONV_NETWORKS_HPP
3+
#define CONV_NETWORKS_HPP
4+
5+
#include <torch/torch.h>
6+
7+
class ConvDQN : public torch::nn::Module
8+
{
9+
private:
10+
torch::nn::Sequential layers{nullptr};
11+
int input_dimension;
12+
int action_number;
13+
14+
public:
15+
ConvDQN(int input_dim, int action_number);
16+
17+
torch::Tensor forward(torch::Tensor x);
18+
};
19+
20+
class ConvDuelingDQN : public torch::nn::Module
21+
{
22+
private:
23+
torch::nn::Sequential value_stream{nullptr};
24+
torch::nn::Sequential advantage_stream{nullptr};
25+
int input_dimension;
26+
int action_number;
27+
28+
public:
29+
ConvDuelingDQN(int input_dim, int action_number);
30+
31+
torch::Tensor forward(torch::Tensor x);
32+
};
33+
34+
#endif

0 commit comments

Comments
 (0)