Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions CUDA/includes/attention.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "tensor.cuh"

#include <cuda_runtime.h>

namespace quadtrix {
namespace cuda {

// Declaration of the attention forward pass; the definition is not in this header.
//
// Parameters:
//   input_qkv  - read-only input tensor view (taken by const reference).
//   preatt     - tensor view passed by value; presumably receives the pre-softmax
//                attention scores — TODO confirm against the implementation.
//   att        - tensor view passed by value; presumably receives the attention
//                weights — TODO confirm.
//   output     - tensor view passed by value; presumably receives the result — TODO confirm.
//   num_heads  - number of attention heads.
//   stream     - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
// NOTE(review): expected tensor shapes and dtypes are not visible from this
// header — verify against the .cu definition before relying on them.
Status attention_forward(
const TensorView& input_qkv,
TensorView preatt,
TensorView att,
TensorView output,
int num_heads,
cudaStream_t stream = nullptr);

// Declaration of the attention backward pass; the definition is not in this header.
//
// Parameters:
//   grad_output     - read-only gradient tensor view (const reference).
//   input_qkv       - read-only tensor view (const reference); same name as the
//                     forward-pass input.
//   att             - read-only tensor view (const reference); same name as the
//                     forward-pass `att` argument.
//   grad_input_qkv  - tensor view passed by value; presumably receives the gradient
//                     with respect to input_qkv — TODO confirm.
//   grad_preatt     - tensor view passed by value; presumably gradient/scratch for
//                     the pre-softmax scores — TODO confirm.
//   grad_att        - tensor view passed by value; presumably gradient/scratch for
//                     the attention weights — TODO confirm.
//   num_heads       - number of attention heads.
//   stream          - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
// NOTE(review): whether gradients are accumulated or overwritten is not visible
// here — verify against the implementation.
Status attention_backward(
const TensorView& grad_output,
const TensorView& input_qkv,
const TensorView& att,
TensorView grad_input_qkv,
TensorView grad_preatt,
TensorView grad_att,
int num_heads,
cudaStream_t stream = nullptr);

} // namespace cuda
} // namespace quadtrix
25 changes: 25 additions & 0 deletions CUDA/includes/checkpoint.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include "tensor.cuh"

namespace quadtrix {
namespace cuda {

// Plain record of model hyperparameters associated with a checkpoint.
// Every field is value-initialised to 0 by its default member initialiser.
// NOTE(review): the on-disk layout that populates these fields is not defined
// in this header.
struct CheckpointMetadata {
// Size of the token vocabulary.
int vocab_size = 0;
// Maximum sequence length.
int max_sequence_length = 0;
// Number of layers.
int num_layers = 0;
// Number of attention heads.
int num_heads = 0;
// Channel count; presumably the model/embedding width — TODO confirm.
int channels = 0;
};

// Stub: ignores both arguments (they are unnamed) and always returns false.
// Nothing is read and the CheckpointMetadata pointee is never written.
// NOTE(review): presumably a placeholder until checkpoint loading is
// implemented; callers currently always observe failure.
inline bool load_checkpoint_metadata(const char*, CheckpointMetadata*) {
return false;
}

// Stub: ignores both arguments (they are unnamed) and always returns false.
// Nothing is written anywhere.
// NOTE(review): presumably a placeholder until checkpoint saving is
// implemented; callers currently always observe failure.
inline bool save_tensor_checkpoint(const char*, const TensorView&) {
return false;
}

} // namespace cuda
} // namespace quadtrix
120 changes: 120 additions & 0 deletions CUDA/includes/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#pragma once

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <limits>

namespace quadtrix {
namespace cuda {

// Element data types, stored in a single byte (std::uint8_t underlying type).
// dtype_name() and dtype_size() in this file map each enumerator to its
// short name and per-element byte size.
enum class DType : std::uint8_t {
// "f32", 4 bytes per element.
F32,
// "f16", 2 bytes per element.
F16,
// "bf16", 2 bytes per element.
BF16,
// "i32", 4 bytes per element.
I32,
// "u8", 1 byte per element.
U8,
};

// Device category tag, stored in a single byte (std::uint8_t underlying type).
// NOTE(review): nothing in this header uses it; presumably it records where a
// tensor's memory lives — confirm in tensor.cuh.
enum class DeviceKind : std::uint8_t {
CPU,
CUDA,
};

struct Status {
bool ok;
cudaError_t cuda_error;
const char* message;

static Status success() {
return {true, cudaSuccess, "ok"};
}

static Status failure(cudaError_t error, const char* message) {
return {false, error, message};
}
};

// Returns the short lowercase name of `dtype` as a string literal
// ("f32", "f16", "bf16", "i32", "u8"). A value that matches no enumerator
// yields "unknown".
inline const char* dtype_name(DType dtype) {
    // Start from the fallback so an out-of-range value needs no extra branch.
    const char* label = "unknown";
    switch (dtype) {
        case DType::F32:
            label = "f32";
            break;
        case DType::F16:
            label = "f16";
            break;
        case DType::BF16:
            label = "bf16";
            break;
        case DType::I32:
            label = "i32";
            break;
        case DType::U8:
            label = "u8";
            break;
    }
    return label;
}

// Returns the size in bytes of one element of `dtype`.
//
// A value that matches no enumerator is treated as a programming error: a
// diagnostic is written to stderr and the process is aborted.
inline std::size_t dtype_size(DType dtype) {
    // Enumerators with the same width share a return statement.
    switch (dtype) {
        case DType::F32:
        case DType::I32:
            return 4;
        case DType::F16:
        case DType::BF16:
            return 2;
        case DType::U8:
            return 1;
    }

    // Only reachable when `dtype` holds a value outside the enumerator set.
    const unsigned int raw_value = static_cast<unsigned int>(dtype);
    std::fprintf(stderr, "Unknown CUDA dtype value %u\n", raw_value);
    std::abort();
}

// Multiplies `lhs` by `rhs` with overflow detection.
//
// On success stores the product in `*out` and returns true. If the product
// would not fit in std::size_t, returns false and leaves `*out` untouched.
// `out` is dereferenced without a null check on the success path.
inline bool checked_mul(std::size_t lhs, std::size_t rhs, std::size_t* out) {
    constexpr std::size_t kLimit = std::numeric_limits<std::size_t>::max();

    // The division is only evaluated when lhs is non-zero (short-circuit),
    // which also covers the trivial lhs == 0 case: it can never overflow.
    const bool would_overflow = (lhs != 0) && (rhs > kLimit / lhs);
    if (!would_overflow) {
        *out = lhs * rhs;
    }
    return !would_overflow;
}

// Converts a CUDA runtime result into a Status.
//
// Parameters:
//   error       - result code to inspect.
//   expression  - source text of the checked call; stored as the failure message.
//   file, line  - location reported in the diagnostic.
//
// Returns Status::success() for cudaSuccess. Otherwise writes a diagnostic to
// stderr and returns Status::failure(error, expression). The process is not
// terminated.
inline Status check_cuda(cudaError_t error, const char* expression, const char* file, int line) {
    if (error != cudaSuccess) {
        const char* reason = cudaGetErrorString(error);
        std::fprintf(stderr, "CUDA error at %s:%d: %s failed with %s\n", file, line, expression, reason);
        return Status::failure(error, expression);
    }
    return Status::success();
}

// Terminates the process when a CUDA runtime call has failed.
//
// Parameters:
//   error       - result code to inspect.
//   expression  - source text of the checked call, echoed in the diagnostic.
//   file, line  - location reported in the diagnostic.
//
// Returns normally only for cudaSuccess; any other code prints a diagnostic
// to stderr and calls std::abort().
inline void abort_on_cuda(cudaError_t error, const char* expression, const char* file, int line) {
    if (error != cudaSuccess) {
        const char* reason = cudaGetErrorString(error);
        std::fprintf(stderr, "Fatal CUDA error at %s:%d: %s failed with %s\n", file, line, expression, reason);
        std::abort();
    }
}

} // namespace cuda
} // namespace quadtrix

// Evaluates `expr` once and forwards its result to check_cuda together with
// the stringified expression and the call site's __FILE__ / __LINE__.
// The macro expands to an expression of type Status; on failure a diagnostic
// is printed and execution continues.
#define QUADTRIX_CUDA_CHECK(expr) \
::quadtrix::cuda::check_cuda((expr), #expr, __FILE__, __LINE__)

// Evaluates `expr` once and forwards its result to abort_on_cuda together with
// the stringified expression and the call site's __FILE__ / __LINE__.
// The macro expands to a void expression; on failure a diagnostic is printed
// and the process is aborted.
#define QUADTRIX_CUDA_ABORT(expr) \
::quadtrix::cuda::abort_on_cuda((expr), #expr, __FILE__, __LINE__)
29 changes: 29 additions & 0 deletions CUDA/includes/dataloader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <cstddef>
#include <cstdint>

namespace quadtrix {
namespace cuda {

// Describes one batch of token data through two read-only pointers plus its
// dimensions. A default-constructed view has null pointers and zero sizes.
// NOTE(review): presumably a non-owning view whose buffers each hold
// batch_size * sequence_length entries — confirm against the producer.
struct TokenBatchView {
// Read-only pointer to the input token ids; nullptr by default.
const std::int32_t* inputs = nullptr;
// Read-only pointer to the target token ids; nullptr by default.
const std::int32_t* targets = nullptr;
// Number of sequences in the batch.
int batch_size = 0;
// Number of tokens per sequence.
int sequence_length = 0;
};

class DataLoader {
public:
DataLoader() = default;

bool next(TokenBatchView* batch) {
if (batch != nullptr) {
*batch = {};
}
return false;
}
};

} // namespace cuda
} // namespace quadtrix
31 changes: 31 additions & 0 deletions CUDA/includes/gelu.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once

#include "tensor.cuh"

#include <cuda_runtime.h>

#include <cstdint>

namespace quadtrix {
namespace cuda {

// Selects which GELU formulation the gelu_* functions use; stored in a single
// byte. The gelu_forward / gelu_backward declarations default to Approximate.
// NOTE(review): presumably Exact is the erf-based form and Approximate the
// tanh-based form — confirm against the kernel implementation.
enum class GeluMode : std::uint8_t {
Exact,
Approximate,
};

// Declaration of the GELU forward pass; the definition is not in this header.
//
// Parameters:
//   input   - read-only input tensor view (const reference).
//   output  - tensor view passed by value; presumably receives the activations
//             — TODO confirm.
//   mode    - GELU formulation to use; defaults to GeluMode::Approximate.
//   stream  - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
// NOTE(review): supported dtypes and shape requirements are not visible here.
Status gelu_forward(
const TensorView& input,
TensorView output,
GeluMode mode = GeluMode::Approximate,
cudaStream_t stream = nullptr);

// Declaration of the GELU backward pass; the definition is not in this header.
//
// Parameters:
//   grad_output  - read-only gradient tensor view (const reference).
//   input        - read-only tensor view (const reference); same name as the
//                  forward-pass input.
//   grad_input   - tensor view passed by value; presumably receives the gradient
//                  with respect to `input` — TODO confirm.
//   mode         - GELU formulation to use; defaults to GeluMode::Approximate.
//                  Presumably must match the mode used in the forward pass.
//   stream       - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
Status gelu_backward(
const TensorView& grad_output,
const TensorView& input,
TensorView grad_input,
GeluMode mode = GeluMode::Approximate,
cudaStream_t stream = nullptr);

} // namespace cuda
} // namespace quadtrix
26 changes: 26 additions & 0 deletions CUDA/includes/global_norm.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "tensor.cuh"

#include <cuda_runtime.h>

namespace quadtrix {
namespace cuda {

// Declaration of the squared global-norm computation; the definition is not in
// this header.
//
// Parameters:
//   grads         - read-only gradient tensor view (const reference).
//   partial_sums  - tensor view passed by value; presumably receives partial
//                   sums of squares that the caller still has to combine —
//                   TODO confirm the expected size and the final reduction step.
//   stream        - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
Status global_norm_squared(
const TensorView& grads,
TensorView partial_sums,
cudaStream_t stream = nullptr);

// Declaration of gradient clipping by global norm; the definition is not in
// this header.
//
// Parameters:
//   grads        - tensor view passed by value; presumably scaled in place —
//                  TODO confirm.
//   global_norm  - norm of the gradients, supplied by the caller.
//   max_norm     - clipping threshold.
//   stream       - CUDA stream argument; defaults to nullptr (the default stream).
//
// Returns a Status describing the outcome.
// NOTE(review): presumably applies the factor computed by clip_scale() in this
// header — verify against the implementation.
Status clip_gradients_by_global_norm(
TensorView grads,
float global_norm,
float max_norm,
cudaStream_t stream = nullptr);

// Computes the multiplier that brings a gradient norm down to `max_norm`.
//
// Returns max_norm / global_norm when global_norm is both greater than
// max_norm and strictly positive; otherwise returns 1.0f (no scaling).
// The positivity test ensures the division is never performed with a zero
// or negative global_norm.
inline float clip_scale(float global_norm, float max_norm) {
    const bool exceeds_limit = global_norm > max_norm;
    const bool norm_is_positive = global_norm > 0.0f;
    if (exceeds_limit && norm_is_positive) {
        return max_norm / global_norm;
    }
    return 1.0f;
}

} // namespace cuda
} // namespace quadtrix
Loading