Skip to content

Latest commit

 

History

History
331 lines (253 loc) · 15.5 KB

README_zh-CN.md

File metadata and controls

331 lines (253 loc) · 15.5 KB
 
OpenMMLab 官网 HOT      OpenMMLab 开放平台 TRY IT OUT
 

PyPI - Python Version PyPI license open issues issue resolution

📘使用文档 | 🛠️安装教程 | 🤔报告问题

English | 简体中文

简介

MMEngine 是一个基于 PyTorch 用于深度学习模型训练的基础库,支持在 Linux、Windows、macOS 上运行。它具有如下三个亮点:

  1. 通用:MMEngine 实现了一个高级的通用训练器,它能够:

    • 支持用少量代码训练不同的任务,例如仅使用 80 行代码就可以训练 imagenet(原始pytorch example 400 行)
    • 轻松兼容流行的算法库 (如 TIMM、TorchVision 和 Detectron2 ) 中的模型
  2. 统一:MMEngine 设计了一个接口统一的开放架构,使得:

    • 用户可以仅依赖一份代码实现所有任务的轻量化,例如 MMRazor 1.x 相比 MMRazor 0.x 优化了 40% 的代码量
    • 上下游的对接更加统一便捷,在为上层算法库提供统一抽象的同时,支持多种后端设备。目前 MMEngine 支持 Nvidia CUDA、Mac MPS、AMD、MLU 等设备进行模型训练。
  3. 灵活:MMEngine 实现了“乐高”式的训练流程,支持了:

    • 根据迭代数、 loss 和评测结果等动态调整的训练流程、优化策略和数据增强策略,例如早停(early stopping)机制等
    • 任意形式的模型权重平均,如 Exponential Momentum Average (EMA) 和 Stochastic Weight Averaging (SWA)
    • 训练过程中针对任意数据和任意节点的灵活可视化和日志控制
    • 对神经网络模型中各个层的优化配置进行细粒度调整
    • 混合精度训练的灵活控制

最近进展

最新版本 v0.3.2 在 2022.11.24 发布。

如果想了解更多版本更新细节和历史信息,请阅读更新日志

安装

在安装 MMengine 之前,请确保 PyTorch 已成功安装在环境中,可以参考 PyTorch 官方安装文档

安装 MMEngine

pip install -U openmim
mim install mmengine

验证是否安装成功

python -c 'from mmengine.utils.dl_utils import collect_env;print(collect_env())'

更多安装方式请阅读安装文档

快速上手

以在 CIFAR-10 数据集上训练一个 ResNet-50 模型为例,我们将使用 80 行以内的代码,利用 MMEngine 构建一个完整的、可配置的训练和验证流程。

构建模型

首先,我们需要构建一个模型,在 MMEngine 中,我们约定这个模型应当继承 BaseModel,并且其 forward 方法除了接受来自数据集的若干参数外,还需要接受额外的参数 mode:对于训练,我们需要 mode 接受字符串 "loss",并返回一个包含 "loss" 字段的字典;对于验证,我们需要 mode 接受字符串 "predict",并返回同时包含预测信息和真实信息的结果。

import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel

class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
构建数据集

其次,我们需要构建训练和验证所需要的数据集 (Dataset)和数据加载器 (DataLoader)。 对于基础的训练和验证功能,我们可以直接使用符合 PyTorch 标准的数据加载器和数据集。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=torchvision.datasets.CIFAR10(
                                  'data/cifar10',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([
                                      transforms.RandomCrop(32, padding=4),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(**norm_cfg)
                                  ])))
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=torchvision.datasets.CIFAR10(
                                'data/cifar10',
                                train=False,
                                download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(**norm_cfg)
                                ])))
构建评测指标

为了进行验证和测试,我们需要定义模型推理结果的评测指标。我们约定这一评测指标需要继承 BaseMetric,并实现 processcompute_metrics 方法。

from mmengine.evaluator import BaseMetric

class Accuracy(BaseMetric):
    def process(self, data_batch, data_samples):
        score, gt = data_samples
        # 将一个批次的中间结果保存至 `self.results`
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })
    def compute_metrics(self, results):
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        # 返回保存有评测指标结果的字典,其中键为指标名称
        return dict(accuracy=100 * total_correct / total_size)
构建执行器

最后,我们利用构建好的模型数据加载器评测指标构建一个执行器 (Runner),同时在其中配置 优化器工作路径训练与验证配置等选项

from torch.optim import SGD
from mmengine.runner import Runner

runner = Runner(
    # 用以训练和验证的模型,需要满足特定的接口需求
    model=MMResNet50(),
    # 工作路径,用以保存训练日志、权重文件信息
    work_dir='./work_dir',
    # 训练数据加载器,需要满足 PyTorch 数据加载器协议
    train_dataloader=train_dataloader,
    # 优化器包装,用于模型优化,并提供 AMP、梯度累积等附加功能
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    # 训练配置,用于指定训练周期、验证间隔等信息
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    # 验证数据加载器,需要满足 PyTorch 数据加载器协议
    val_dataloader=val_dataloader,
    # 验证配置,用于指定验证所需要的额外参数
    val_cfg=dict(),
    # 用于验证的评测器,这里使用默认评测器,并评测指标
    val_evaluator=dict(type=Accuracy),
)
开始训练
runner.train()

了解更多

入门教程
进阶教程
示例
架构设计
迁移指南

贡献指南

我们感谢所有的贡献者为改进和提升 MMEngine 所作出的努力。请参考贡献指南来了解参与项目贡献的相关指引。

开源许可证

该项目采用 Apache 2.0 license 开源许可证。

OpenMMLab 的其他项目

  • MIM: MIM 是 OpenMMLab 项目、算法、模型的统一入口
  • MMCV: OpenMMLab 计算机视觉基础库
  • MMEval: 统一开放的跨框架算法评测库
  • MMClassification: OpenMMLab 图像分类工具箱
  • MMDetection: OpenMMLab 目标检测工具箱
  • MMDetection3D: OpenMMLab 新一代通用 3D 目标检测平台
  • MMRotate: OpenMMLab 旋转框检测工具箱与测试基准
  • MMYOLO: OpenMMLab YOLO 系列工具箱与测试基准
  • MMSegmentation: OpenMMLab 语义分割工具箱
  • MMOCR: OpenMMLab 全流程文字检测识别理解工具包
  • MMPose: OpenMMLab 姿态估计工具箱
  • MMHuman3D: OpenMMLab 人体参数化模型工具箱与测试基准
  • MMSelfSup: OpenMMLab 自监督学习工具箱与测试基准
  • MMRazor: OpenMMLab 模型压缩工具箱与测试基准
  • MMFewShot: OpenMMLab 少样本学习工具箱与测试基准
  • MMAction2: OpenMMLab 新一代视频理解工具箱
  • MMTracking: OpenMMLab 一体化视频目标感知平台
  • MMFlow: OpenMMLab 光流估计工具箱与测试基准
  • MMEditing: OpenMMLab 图像视频编辑工具箱
  • MMGeneration: OpenMMLab 图片视频生成模型工具箱
  • MMDeploy: OpenMMLab 模型部署框架

欢迎加入 OpenMMLab 社区

扫描下方的二维码可关注 OpenMMLab 团队的 知乎官方账号,加入 OpenMMLab 团队的 官方交流 QQ 群,或通过添加微信“Open小喵Lab”加入官方交流微信群。

我们会在 OpenMMLab 社区为大家

  • 📢 分享 AI 框架的前沿核心技术
  • 💻 解读 PyTorch 常用模块源码
  • 📰 发布 OpenMMLab 的相关新闻
  • 🚀 介绍 OpenMMLab 开发的前沿算法
  • 🏃 获取更高效的问题答疑和意见反馈
  • 🔥 提供与各行各业开发者充分交流的平台

干货满满 📘,等你来撩 💗,OpenMMLab 社区期待您的加入 👬