diff --git a/infini_train/include/dataloader.h b/infini_train/include/dataloader.h
index 5e9cd43..d4bbc32 100644
--- a/infini_train/include/dataloader.h
+++ b/infini_train/include/dataloader.h
@@ -5,12 +5,14 @@
 #include <vector>
 
 #include "infini_train/include/dataset.h"
+#include "infini_train/include/sampler.h"
 #include "infini_train/include/tensor.h"
 
 namespace infini_train {
 class DataLoaderIterator {
 public:
-    DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx, size_t max_batch_idx);
+    DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx, size_t max_batch_idx,
+                       const std::vector<size_t> &indices);
 
     std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> operator*() const;
 
@@ -26,11 +28,13 @@ class DataLoaderIterator {
     size_t batch_size_ = 0;
     size_t batch_idx_ = 0;
     size_t max_batch_idx_ = 0;
+    const std::vector<size_t> *indices_ = nullptr;
 };
 
 class DataLoader {
 public:
-    DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size);
+    DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size, bool shuffle = false,
+               std::unique_ptr<Sampler> sampler = nullptr);
 
     DataLoaderIterator begin() const;
     DataLoaderIterator end() const;
@@ -39,5 +43,7 @@ class DataLoader {
     std::shared_ptr<Dataset> dataset_;
     size_t batch_size_ = 0;
     size_t max_batch_idx_ = 0;
+    std::unique_ptr<Sampler> sampler_;
+    std::shared_ptr<std::vector<size_t>> indices_;
 };
 } // namespace infini_train
diff --git a/infini_train/include/sampler.h b/infini_train/include/sampler.h
new file mode 100644
index 0000000..347caf9
--- /dev/null
+++ b/infini_train/include/sampler.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <cstddef>
+#include <memory>
+#include <vector>
+
+#include "infini_train/include/dataset.h"
+#include "infini_train/include/tensor.h"
+
+namespace infini_train {
+class Sampler {
+public:
+    virtual ~Sampler() = default;
+    virtual std::vector<size_t> GetIndices(size_t dataset_size) = 0;
+};
+
+class RandomSampler : public Sampler {
+public:
+    RandomSampler(bool replacement = false, size_t num_samples = 0)
+        : replacement_(replacement), num_samples_(num_samples) {}
+
+    std::vector<size_t> GetIndices(size_t dataset_size) override;
+
+private:
+    bool replacement_;
+    size_t num_samples_;
+};
+
+class SequentialSampler : public Sampler {
+public:
+    SequentialSampler() {}
+    std::vector<size_t> GetIndices(size_t dataset_size) override;
+};
+
+class SubsetRandomSampler : public Sampler {
+public:
+    // Copies the given indices so the sampler owns its own storage.
+    explicit SubsetRandomSampler(const std::vector<size_t> &indices)
+        : indices_(std::make_shared<std::vector<size_t>>(indices)) {}
+    std::vector<size_t> GetIndices(size_t dataset_size) override;
+
+private:
+    std::shared_ptr<std::vector<size_t>> indices_;
+};
+} // namespace infini_train
\ No newline at end of file
diff --git a/infini_train/src/dataloader.cc b/infini_train/src/dataloader.cc
index 65a2d70..c99023d 100644
--- a/infini_train/src/dataloader.cc
+++ b/infini_train/src/dataloader.cc
@@ -9,6 +9,7 @@
 #include "glog/logging.h"
 
 #include "infini_train/include/dataset.h"
+#include "infini_train/include/sampler.h"
 #include "infini_train/include/tensor.h"
 
 namespace infini_train {
@@ -36,8 +37,9 @@ std::shared_ptr<Tensor> Stack(const std::vector<std::shared_ptr<Tensor>> &tensor
 } // namespace
 
 DataLoaderIterator::DataLoaderIterator(const Dataset &dataset, size_t batch_size, size_t batch_idx,
-                                       size_t max_batch_idx)
-    : dataset_(&dataset), batch_size_(batch_size), batch_idx_(batch_idx), max_batch_idx_(max_batch_idx){};
+                                       size_t max_batch_idx, const std::vector<size_t> &indices)
+    : dataset_(&dataset), batch_size_(batch_size), batch_idx_(batch_idx), max_batch_idx_(max_batch_idx),
+      indices_(&indices) {}
 
 std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> DataLoaderIterator::operator*() const {
     /*
@@ -49,7 +51,7 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> DataLoaderIterator::
     std::vector<std::shared_ptr<Tensor>> data_vec;
     std::vector<std::shared_ptr<Tensor>> label_vec;
     for (int idx = batch_idx_ * batch_size_; idx < (batch_idx_ + 1) * batch_size_ && idx < dataset_->Size(); ++idx) {
-        auto &&[data, label] = dataset_->operator[](idx);
+        auto &&[data, label] = dataset_->operator[](indices_->at(idx));
         data_vec.push_back(std::move(data));
         label_vec.push_back(std::move(label));
     }
@@ -77,12 +79,31 @@ bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs) {
     return lhs.batch_idx_ == rhs.batch_idx_;
 }
 
-DataLoader::DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size)
-    : dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_) {}
+DataLoader::DataLoader(const std::shared_ptr<Dataset> &dataset, size_t batch_size, bool shuffle,
+                       std::unique_ptr<Sampler> sampler)
+    : dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_),
+      sampler_(std::move(sampler)) {
 
-DataLoaderIterator DataLoader::begin() const { return DataLoaderIterator(*dataset_, batch_size_, 0, max_batch_idx_); }
+    if (sampler_ != nullptr && shuffle) {
+        LOG(FATAL) << "Sampler option is mutually exclusive with 'shuffle'";
+    }
+
+    if (sampler_ == nullptr) {
+        if (shuffle) {
+            sampler_ = std::make_unique<RandomSampler>(false);
+        } else {
+            sampler_ = std::make_unique<SequentialSampler>();
+        }
+    }
+
+    indices_ = std::make_shared<std::vector<size_t>>(sampler_->GetIndices(dataset_->Size()));
+}
+
+DataLoaderIterator DataLoader::begin() const {
+    return DataLoaderIterator(*dataset_, batch_size_, 0, max_batch_idx_, *indices_);
+}
 
 DataLoaderIterator DataLoader::end() const {
-    return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_);
+    return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_, *indices_);
 }
 } // namespace infini_train
diff --git a/infini_train/src/sampler.cc b/infini_train/src/sampler.cc
new file mode 100644
index 0000000..a432b35
--- /dev/null
+++ b/infini_train/src/sampler.cc
@@ -0,0 +1,49 @@
+#include <algorithm>
+#include <random>
+
+#include "glog/logging.h"
+
+#include "infini_train/include/sampler.h"
+
+namespace infini_train {
+std::vector<size_t> RandomSampler::GetIndices(size_t dataset_size) {
+    std::vector<size_t> indices;
+    std::random_device rd;
+    std::mt19937 gen(rd());
+
+    const size_t n = num_samples_ > 0 ? num_samples_ : dataset_size;
+
+    if (replacement_) {
+        indices.resize(n);
+        std::uniform_int_distribution<size_t> dist(0, dataset_size - 1);
+        for (size_t i = 0; i < n; ++i) { indices[i] = dist(gen); }
+    } else {
+        indices.resize(dataset_size);
+        for (size_t i = 0; i < dataset_size; ++i) { indices[i] = i; }
+        std::shuffle(indices.begin(), indices.end(), gen);
+        if (n < dataset_size) {
+            indices.resize(n);
+        }
+    }
+
+    return indices;
+}
+
+std::vector<size_t> SequentialSampler::GetIndices(size_t dataset_size) {
+    std::vector<size_t> indices(dataset_size);
+
+    for (size_t i = 0; i < dataset_size; ++i) { indices[i] = i; }
+
+    return indices;
+}
+
+std::vector<size_t> SubsetRandomSampler::GetIndices(size_t dataset_size) {
+    std::random_device rd;
+    std::mt19937 gen(rd());
+    std::vector<size_t> indices(*indices_);
+
+    std::shuffle(indices.begin(), indices.end(), gen);
+
+    return indices;
+}
+} // namespace infini_train
\ No newline at end of file
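Usage sketch (editorial note, not part of the patch): the snippet below shows how the new DataLoader options compose, based only on the signatures added above. The wrapper function, the dataset argument, and the range-based loop (which assumes DataLoaderIterator's existing increment and inequality operators, not shown in this diff) are illustrative assumptions, not code from the repository.

    #include <memory>
    #include <vector>

    #include "infini_train/include/dataloader.h"
    #include "infini_train/include/sampler.h"

    using namespace infini_train;

    // Hypothetical helper for illustration only.
    void TrainLoopSketch(const std::shared_ptr<Dataset> &dataset) {
        // Default arguments: shuffle = false, sampler = nullptr -> SequentialSampler.
        DataLoader sequential_loader(dataset, /*batch_size=*/32);

        // shuffle = true internally builds a RandomSampler without replacement.
        DataLoader shuffled_loader(dataset, /*batch_size=*/32, /*shuffle=*/true);

        // A user-supplied sampler is mutually exclusive with shuffle = true (LOG(FATAL) otherwise).
        std::vector<size_t> subset{0, 2, 4, 6, 8};
        DataLoader subset_loader(dataset, /*batch_size=*/2, /*shuffle=*/false,
                                 std::make_unique<SubsetRandomSampler>(subset));

        // Each dereference yields a stacked (data, label) batch as a pair of Tensor pointers.
        for (const auto &[data, label] : shuffled_loader) {
            // ... forward/backward pass on the batch ...
        }
    }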