Skip to content

Commit

Permalink
Merge pull request #4 from mini-sora/update_doc
Browse files Browse the repository at this point in the history
Update dataset doc
  • Loading branch information
PeterH0323 authored Mar 25, 2024
2 parents 7aa6083 + 488c279 commit e883a9a
Showing 1 changed file with 28 additions and 2 deletions.
30 changes: 28 additions & 2 deletions README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ English | [简体中文](./README_CN.md)

- ImageNet-1K

可以在 OpenDatalab 进行下载 [ImageNet-1K](https://opendatalab.org.cn/OpenDataLab/ImageNet-1K)
可以在 OpenDataLab 进行下载 [ImageNet-1K](https://opendatalab.org.cn/OpenDataLab/ImageNet-1K)

```shell
pip install openxlab #安装
Expand Down Expand Up @@ -85,7 +85,33 @@ openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载
print(train_steps)
```

2. 使用 mmengine 重写数据流
执行后会对每个图片生成一个 npy 文件,训练的时候直接读取

2. 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间

```python
class CustomDataset(Dataset):
def __init__(self, features_dir, labels_dir):
self.features_dir = features_dir
self.labels_dir = labels_dir

self.features_files = sorted(os.listdir(features_dir))
self.labels_files = sorted(os.listdir(labels_dir))

def __len__(self):
assert len(self.features_files) == len(self.labels_files), \
"Number of feature files and label files should be same"
return len(self.features_files)

def __getitem__(self, idx):
feature_file = self.features_files[idx]
label_file = self.labels_files[idx]

features = np.load(os.path.join(self.features_dir, feature_file))
labels = np.load(os.path.join(self.labels_dir, label_file))
return torch.from_numpy(features), torch.from_numpy(labels)
```

3. 重写 loss 计算
4. 使用 xtuner 调训练 pipeline

Expand Down

0 comments on commit e883a9a

Please sign in to comment.