Skip to content

Commit

Permalink
Merge pull request #6 from mini-sora/update_doc
Browse files Browse the repository at this point in the history
Update doc
  • Loading branch information
PeterH0323 authored Mar 25, 2024
2 parents e658540 + d193100 commit d589f9c
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,89 @@ Mini Sora 开源社区定位为由社区同学自发组织的开源社区(**

- [**empty**:](empty)


## 数据集

- ImageNet-1K

可以在 OpenDataLab 进行下载 [ImageNet-1K](https://opendatalab.org.cn/OpenDataLab/ImageNet-1K)

```shell
pip install openxlab #安装
pip install -U openxlab #版本升级
openxlab login #进行登录,输入对应的AK/SK

cd ${dataset_dir}
openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载
```

## 复现步骤

目前已在 dev 分支提交了 DiT 在纯 torch 下的复现代码 [fast-DiT](https://github.com/chuanyangjin/fast-DiT),该版本使用了混合精度还有一些加速方案,可以极大程度降低显存,以及提升训练速度。

1. 环境安装

使用 dev 分支中的 `environment.yml` 可以复现环境

```bash
conda env create -f environment.yml
conda activate DiT
```

2. 数据集预处理

因为在原版 Meta 的 [DiT](https://github.com/facebookresearch/DiT) 中,每个 iter 都会对数据进行重复计算,为了节省训练的时间,可以先对图片进行预处理,在训练的时候可以节省这部分的时间

详见 dev 分支中的 [extract_features.py#L163](https://github.com/mini-sora/MiniSora-DiT/blob/ad13c58370842db333c77253709e3fbbc1e9a092/extract_features.py#L163-L177) ,处理需要时间较久,大概 1~2小时。

```python
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
# Map input images to latent space + normalize latents:
x = vae.encode(x).latent_dist.sample().mul_(0.18215)

x = x.detach().cpu().numpy() # (1, 4, 32, 32)
np.save(f'{args.features_path}/imagenet256_features/{train_steps}.npy', x)

y = y.detach().cpu().numpy() # (1,)
np.save(f'{args.features_path}/imagenet256_labels/{train_steps}.npy', y)

train_steps += 1
print(train_steps)
```

执行后会对每个图片生成一个 npy 文件,训练的时候直接读取

3. 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间

```python
class CustomDataset(Dataset):
def __init__(self, features_dir, labels_dir):
self.features_dir = features_dir
self.labels_dir = labels_dir

self.features_files = sorted(os.listdir(features_dir))
self.labels_files = sorted(os.listdir(labels_dir))

def __len__(self):
assert len(self.features_files) == len(self.labels_files), \
"Number of feature files and label files should be same"
return len(self.features_files)

def __getitem__(self, idx):
feature_file = self.features_files[idx]
label_file = self.labels_files[idx]

features = np.load(os.path.join(self.features_dir, feature_file))
labels = np.load(os.path.join(self.labels_dir, label_file))
return torch.from_numpy(features), torch.from_numpy(labels)
```

4. 重写 loss 计算
5. 使用 xtuner 调训练 pipeline

## 论文共读计划

### 论文共读发表者募集
Expand Down

0 comments on commit d589f9c

Please sign in to comment.