Skip to content

Commit

Permalink
Merge pull request #5 from mini-sora/update_doc
Browse files Browse the repository at this point in the history
Add environment doc
  • Loading branch information
PeterH0323 committed Mar 25, 2024
2 parents e883a9a + 763ea49 commit e658540
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,18 @@ openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载

## 复现步骤

1. 数据集预处理
目前已在 dev 分支提交了 DiT 在纯 torch 下的复现代码 [fast-DiT](https://github.com/chuanyangjin/fast-DiT),该版本使用了混合精度还有一些加速方案,可以极大程度降低显存,以及提升训练速度。

1. 环境安装

使用 dev 分支中的 `environment.yml` 可以复现环境

```bash
conda env create -f environment.yml
conda activate DiT
```

2. 数据集预处理

因为在原版 Meta 的 [DiT](https://github.com/facebookresearch/DiT) 中,每个 iter 都会对数据进行重复计算,为了节省训练的时间,可以先对图片进行预处理,在训练的时候可以节省这部分的时间

Expand All @@ -87,7 +98,7 @@ openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载

执行后会对每个图片生成一个 npy 文件,训练的时候直接读取

2. 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间
3. 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间

```python
class CustomDataset(Dataset):
Expand All @@ -112,8 +123,8 @@ class CustomDataset(Dataset):
return torch.from_numpy(features), torch.from_numpy(labels)
```

3. 重写 loss 计算
4. 使用 xtuner 调训练 pipeline
4. 重写 loss 计算
5. 使用 xtuner 调训练 pipeline

## 模型架构

Expand Down Expand Up @@ -188,4 +199,4 @@ class CustomDataset(Dataset):
[issues-shield]: https://img.shields.io/github/issues/mini-sora/minisora.svg?style=flat-square
[issues-url]: https://img.shields.io/github/issues/mini-sora/minisora.svg
[license-shield]: https://img.shields.io/github/license/mini-sora/minisora.svg?style=flat-square
[license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE
[license-url]: https://github.com/mini-sora/minisora/blob/main/LICENSE

0 comments on commit e658540

Please sign in to comment.