Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
lancerts committed Jan 24, 2024
1 parent 35e2770 commit d6ab7a1
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions cpp/dcgan/dcgan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const int64_t kNoiseSize = 100;
// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;
// The default number of epochs to train.
int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data";
Expand Down Expand Up @@ -95,6 +95,19 @@ nn::Sequential create_discriminator() {
}

int main(int argc, const char* argv[]) {

if (argc > 1) {
std::string arg = argv[1];
if (std::all_of(arg.begin(), arg.end(), ::isdigit)) {
try {
kNumberOfEpochs = std::stoll(arg);
} catch (const std::invalid_argument& ia) {
// If unable to parse, do nothing and keep the default value
}
}
}
std::cout << "Traning with number of epochs: " << kNumberOfEpochs << std::endl;

torch::manual_seed(1);

// Create the device we pass around based on whether CUDA is available.
Expand Down Expand Up @@ -172,7 +185,7 @@ int main(int argc, const char* argv[]) {
batch_index++;
if (batch_index % kLogInterval == 0) {
std::printf(
"\r[%2lld/%2lld][%3lld/%3lld] D_loss: %.4f | G_loss: %.4f\n",
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f\n",
epoch,
kNumberOfEpochs,
batch_index,
Expand Down

0 comments on commit d6ab7a1

Please sign in to comment.