From d193100632cb3bee2f09cf1f974b52f75be36574 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Mon, 25 Mar 2024 20:35:53 +0800 Subject: [PATCH] Update doc --- README.md | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/README.md b/README.md index 85eeb2d..c9f6f6d 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,89 @@ Mini Sora 开源社区定位为由社区同学自发组织的开源社区(** - [**empty**:](empty) + +## 数据集 + +- ImageNet-1K + +可以在 OpenDataLab 进行下载 [ImageNet-1K](https://opendatalab.org.cn/OpenDataLab/ImageNet-1K) + +```shell +pip install openxlab #安装 +pip install -U openxlab #版本升级 +openxlab login #进行登录,输入对应的AK/SK + +cd ${dataset_dir} +openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载 +``` + +## 复现步骤 + +目前已在 dev 分支提交了 DiT 在纯 torch 下的复现代码 [fast-DiT](https://github.com/chuanyangjin/fast-DiT),该版本使用了混合精度还有一些加速方案,可以极大程度降低显存,以及提升训练速度。 + +1. 环境安装 + +使用 dev 分支中的 `environment.yml` 可以复现环境 + +```bash +conda env create -f environment.yml +conda activate DiT +``` + +2. 数据集预处理 + +因为在原版 Meta 的 [DiT](https://github.com/facebookresearch/DiT) 中,每个 iter 都会对数据进行重复计算,为了节省训练的时间,可以先对图片进行预处理,在训练的时候可以节省这部分的时间 + +详见 dev 分支中的 [extract_features.py#L163](https://github.com/mini-sora/MiniSora-DiT/blob/ad13c58370842db333c77253709e3fbbc1e9a092/extract_features.py#L163-L177) ,处理需要时间较久,大概 1~2小时。 + +```python + for x, y in loader: + x = x.to(device) + y = y.to(device) + with torch.no_grad(): + # Map input images to latent space + normalize latents: + x = vae.encode(x).latent_dist.sample().mul_(0.18215) + + x = x.detach().cpu().numpy() # (1, 4, 32, 32) + np.save(f'{args.features_path}/imagenet256_features/{train_steps}.npy', x) + + y = y.detach().cpu().numpy() # (1,) + np.save(f'{args.features_path}/imagenet256_labels/{train_steps}.npy', y) + + train_steps += 1 + print(train_steps) +``` + +执行后会对每个图片生成一个 npy 文件,训练的时候直接读取 + +3. 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间 + +```python +class CustomDataset(Dataset): + def __init__(self, features_dir, labels_dir): + self.features_dir = features_dir + self.labels_dir = labels_dir + + self.features_files = sorted(os.listdir(features_dir)) + self.labels_files = sorted(os.listdir(labels_dir)) + + def __len__(self): + assert len(self.features_files) == len(self.labels_files), \ + "Number of feature files and label files should be same" + return len(self.features_files) + + def __getitem__(self, idx): + feature_file = self.features_files[idx] + label_file = self.labels_files[idx] + + features = np.load(os.path.join(self.features_dir, feature_file)) + labels = np.load(os.path.join(self.labels_dir, label_file)) + return torch.from_numpy(features), torch.from_numpy(labels) +``` + +4. 重写 loss 计算 +5. 使用 xtuner 调训练 pipeline + ## 论文共读计划 ### 论文共读发表者募集