ubatch : new splitting logic #14217

Open · wants to merge 1 commit into base: master

916 changes: 557 additions & 359 deletions src/llama-batch.cpp

Large diffs are not rendered by default.

168 changes: 98 additions & 70 deletions src/llama-batch.h
@@ -2,118 +2,146 @@

#include "llama.h"

#include "llama-cparams.h"

#include <array>
#include <vector>
#include <set>
#include <bitset>
#include <unordered_map>

// very similar to llama_batch,
// but has more metadata about sequences
// keep this struct lightweight
// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?

     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
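For pooled embeddings, seq_idx acts as the inverse of seq_id_unq: it maps a sequence id to its row index in [0, n_seqs_unq). Below is a minimal sketch of the lookup a consumer might write; the [n_embd, n_seqs_unq] output layout and the -1 sentinel for sequences absent from the ubatch are assumptions for illustration, not guarantees of this header:

    #include <cstddef> // size_t

    // return the pooled embedding row of `seq_id`, or nullptr if the sequence
    // does not appear in this ubatch (assuming seq_idx uses -1 as a sentinel)
    const float * seq_embd(const llama_ubatch & ub, const float * pooled,
                           uint32_t n_embd, llama_seq_id seq_id) {
        const int32_t idx = ub.seq_idx[seq_id]; // seq_id must be < LLAMA_MAX_SEQ
        return idx < 0 ? nullptr : pooled + (size_t) idx * n_embd;
    }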

-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);

     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);

     const llama_batch & get_batch() const;

     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;

+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
     // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;

+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequence sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);

 private:
     void clear();

+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;

     // only for debugging purposes
     const llama_vocab * vocab;

+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id

     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;

-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
+
     int debug;
 };
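Taken together, the splitting that used to live in llama_sbatch is now driven through llama_batch_allocr itself. Below is a hedged sketch of the intended call pattern, pieced together from the comments in the diff; the nullptr memory, the choice of split strategy, and the n_tokens == 0 termination test (stated for ubatch_add, assumed to hold for the split_* methods as well) are assumptions for illustration:

    static void decode_all(const llama_batch & batch, const llama_vocab & vocab,
                           uint32_t n_embd, uint32_t n_ubatch) {
        llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

        // sanitize and auto-fill the input batch; memory is optional per the comment above
        if (!balloc.init(batch, vocab, /*memory =*/ nullptr, n_embd, /*output_all =*/ false)) {
            return; // the batch failed validation
        }

        balloc.split_reset(); // reset the internal splitting state

        while (true) {
            // pick one strategy: split_simple, split_equal or split_seq
            llama_ubatch ub = balloc.split_simple(n_ubatch);
            if (ub.n_tokens == 0) {
                break; // the entire batch has been consumed
            }

            // ... run the graph on `ub`; output order is given by get_out_ids() ...
        }
    }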