feat: support shuffle dataset #47


Open · wants to merge 1 commit into dev
10 changes: 8 additions & 2 deletions infini_train/include/dataloader.h
@@ -5,12 +5,14 @@
#include <utility>

#include "infini_train/include/dataset.h"
#include "infini_train/include/sampler.h"
#include "infini_train/include/tensor.h"

namespace infini_train {
class DataLoaderIterator {
public:
DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx, size_t max_batch_idx);
DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx, size_t max_batch_idx,
const std::vector<size_t> &indices);

std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> operator*() const;

@@ -26,11 +28,13 @@ class DataLoaderIterator {
size_t batch_size_ = 0;
size_t batch_idx_ = 0;
size_t max_batch_idx_ = 0;
const std::vector<size_t> *indices_ = nullptr;
};

class DataLoader {
public:
DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size);
DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size, bool shuffle = false,
std::unique_ptr<Sampler> sampler = nullptr);

DataLoaderIterator begin() const;
DataLoaderIterator end() const;
@@ -39,5 +43,7 @@ class DataLoader {
std::shared_ptr<Dataset> dataset_;
size_t batch_size_ = 0;
size_t max_batch_idx_ = 0;
std::unique_ptr<Sampler> sampler_;
std::shared_ptr<std::vector<size_t>> indices_;
};
} // namespace infini_train
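
A minimal usage sketch of the extended constructor, based only on the signatures above (the ExampleUsage function and the batch size are made up for illustration):

#include <memory>

#include "infini_train/include/dataloader.h"
#include "infini_train/include/sampler.h"

void ExampleUsage(const std::shared_ptr<infini_train::Dataset> &dataset) {
    using namespace infini_train;
    DataLoader sequential(dataset, /*batch_size=*/32);                  // default: SequentialSampler
    DataLoader shuffled(dataset, /*batch_size=*/32, /*shuffle=*/true);  // builds a RandomSampler internally
    DataLoader custom(dataset, /*batch_size=*/32, /*shuffle=*/false,
                      std::make_unique<RandomSampler>(/*replacement=*/true, /*num_samples=*/1000));
    auto [data, label] = *shuffled.begin();  // first batch: pair of std::shared_ptr<Tensor>
}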
41 changes: 41 additions & 0 deletions infini_train/include/sampler.h
@@ -0,0 +1,41 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/dataset.h"
#include "infini_train/include/tensor.h"

namespace infini_train {
class Sampler {
public:
virtual ~Sampler() = default;
virtual std::vector<size_t> GetIndices(size_t dataset_size) = 0;
};

class RandomSampler : public Sampler {
public:
RandomSampler(bool replacement = false, size_t num_samples = 0)
: replacement_(replacement), num_samples_(num_samples) {}

std::vector<size_t> GetIndices(size_t dataset_size) override;

private:
bool replacement_;
size_t num_samples_;
};

class SequentialSampler : public Sampler {
public:
SequentialSampler() {}
std::vector<size_t> GetIndices(size_t dataset_size) override;
};

class SubsetRandomSampler : public Sampler {
public:
explicit SubsetRandomSampler(const std::vector<size_t> &indices) : indices_(std::make_shared<std::vector<size_t>>(indices)) {} // copy so the sampler owns its indices
std::vector<size_t> GetIndices(size_t dataset_size) override;

private:
std::shared_ptr<std::vector<size_t>> indices_;
};
} // namespace infini_train
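
A short sketch of what each sampler returns on its own; the behaviour is taken from the implementations in infini_train/src/sampler.cc below (the SamplerExamples function is made up for illustration):

#include <vector>

#include "infini_train/include/sampler.h"

void SamplerExamples() {
    using namespace infini_train;
    SequentialSampler seq;
    std::vector<size_t> a = seq.GetIndices(5);      // {0, 1, 2, 3, 4}

    RandomSampler permute;                          // replacement = false, use all samples
    std::vector<size_t> b = permute.GetIndices(5);  // random permutation of 0..4

    RandomSampler draw(/*replacement=*/true, /*num_samples=*/8);
    std::vector<size_t> c = draw.GetIndices(5);     // 8 draws from 0..4, duplicates possible

    std::vector<size_t> subset{1, 3, 4};
    SubsetRandomSampler sub(subset);
    std::vector<size_t> d = sub.GetIndices(5);      // {1, 3, 4} in random order; dataset_size is ignored
}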
35 changes: 28 additions & 7 deletions infini_train/src/dataloader.cc
@@ -9,6 +9,7 @@
#include "glog/logging.h"

#include "infini_train/include/dataset.h"
#include "infini_train/include/sampler.h"
#include "infini_train/include/tensor.h"

namespace infini_train {
@@ -36,8 +37,9 @@ std::shared_ptr<Tensor> Stack(const std::vector<std::shared_ptr<Tensor>> &tensor
} // namespace

DataLoaderIterator::DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx,
size_t max_batch_idx)
: dataset_(&dataset), batch_size_(batch_size), batch_idx_(batch_idx), max_batch_idx_(max_batch_idx){};
size_t max_batch_idx, const std::vector<size_t> &indices)
: dataset_(&dataset), batch_size_(batch_size), batch_idx_(batch_idx), max_batch_idx_(max_batch_idx),
indices_(&indices) {}

std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> DataLoaderIterator::operator*() const {
/*
@@ -49,7 +51,7 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> DataLoaderIterator::
std::vector<std::shared_ptr<Tensor>> data_vec;
std::vector<std::shared_ptr<Tensor>> label_vec;
for (int idx = batch_idx_ * batch_size_; idx < (batch_idx_ + 1) * batch_size_ && idx < dataset_->Size(); ++idx) {
auto &&[data, label] = dataset_->operator[](idx);
auto &&[data, label] = dataset_->operator[](indices_->at(idx));
data_vec.push_back(std::move(data));
label_vec.push_back(std::move(label));
}
@@ -77,12 +79,31 @@ bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs) {
return lhs.batch_idx_ == rhs.batch_idx_;
}

DataLoader::DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size)
: dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_) {}
DataLoader::DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size, bool shuffle,
std::unique_ptr<Sampler> sampler)
: dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_),
sampler_(std::move(sampler)) {

DataLoaderIterator DataLoader::begin() const { return DataLoaderIterator(*dataset_, batch_size_, 0, max_batch_idx_); }
if (sampler_ != nullptr && shuffle) {
LOG(FATAL) << "Sampler option is mutually exclusive with 'shuffle'";
}

if (sampler_ == nullptr) {
if (shuffle) {
sampler_ = std::make_unique<RandomSampler>(false);
} else {
sampler_ = std::make_unique<SequentialSampler>();
}
}

indices_ = std::make_shared<std::vector<size_t>>(sampler_->GetIndices(dataset_->Size()));
}

DataLoaderIterator DataLoader::begin() const {
return DataLoaderIterator(*dataset_, batch_size_, 0, max_batch_idx_, *indices_);
}

DataLoaderIterator DataLoader::end() const {
return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_);
return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_, *indices_);
}
} // namespace infini_train
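
One behavioural note from the constructor above: an explicit sampler and shuffle = true are mutually exclusive, and combining them aborts the process through LOG(FATAL). A sketch of the failing call (the FatalExample function is made up; dataset stands for any std::shared_ptr<Dataset>):

void FatalExample(const std::shared_ptr<infini_train::Dataset> &dataset) {
    using namespace infini_train;
    // Aborts at construction time via LOG(FATAL): shuffle = true and an explicit sampler together.
    DataLoader bad(dataset, /*batch_size=*/32, /*shuffle=*/true, std::make_unique<SequentialSampler>());
}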
46 changes: 46 additions & 0 deletions infini_train/src/sampler.cc
@@ -0,0 +1,46 @@
#include "glog/logging.h"

#include "infini_train/include/sampler.h"

namespace infini_train {
std::vector<size_t> RandomSampler::GetIndices(size_t dataset_size) {
std::vector<size_t> indices;
std::random_device rd;
std::mt19937 gen(rd());

const size_t n = num_samples_ > 0 ? num_samples_ : dataset_size;

if (replacement_) {
indices.resize(n);
std::uniform_int_distribution<size_t> dist(0, dataset_size - 1);
for (size_t i = 0; i < n; ++i) { indices[i] = dist(gen); }
} else {
indices.resize(dataset_size);
for (size_t i = 0; i < dataset_size; ++i) { indices[i] = i; }
std::shuffle(indices.begin(), indices.end(), gen);
if (n < dataset_size) {
indices.resize(n);
}
}

return indices;
}

std::vector<size_t> SequentialSampler::GetIndices(size_t dataset_size) {
std::vector<size_t> indices(dataset_size);

for (size_t i = 0; i < dataset_size; ++i) { indices[i] = i; }

return indices;
}

std::vector<size_t> SubsetRandomSampler::GetIndices(size_t dataset_size) {
std::random_device rd;
std::mt19937 gen(rd());
std::vector<size_t> indices(*indices_);

std::shuffle(indices.begin(), indices.end(), gen);

return indices;
}
} // namespace infini_train