位置:首页 > 新闻资讯 > ResNet_wide for CIFAR10

ResNet_wide for CIFAR10

时间:2025-07-28  |  作者:  |  阅读:0

本项目基于PaddlePaddle复现Wide Resnet,它是ResNet变体,改进了shortcut,采用更宽卷积并加dropout层。在CIFAR10测试集精度达96.6%,提供单卡和多卡训练方式,代码含模型、训练、评估等文件,依赖PaddlePaddle≥2.0.0,支持GPU和CPU运行。

一、简介

本项目基于paddlepaddle框架复现Wide Resnet,他是resnet的一种变体,主要区别在于对resnet的shortcut进行了改进,使用更“宽”的卷积以及加上了dropout层。

论文:

  • [1] Zagoruyko S , Komodakis N . Wide Residual Networks[J]. 2016.
  • 链接:Wide Residual Networks

参考项目:

  • https://github.com/xternalz/WideResNet-pytorch
  • https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py

二、复现精度

该列指标在cifar10的测试集测试

train from scratch细节:

epochoptbatch_sizedatasetmemorycardprecision1400SGD128CIFAR1016G10.9660

模型下载?模型地址:aistudio

三、数据集

CIFAR10数据集。

  • 数据集大小:
    • 训练集:50000张
    • 测试集:10000张
    • 尺寸:32 * 32
  • 数据格式:分类数据集

四、环境依赖

  • 硬件:GPU、CPU

  • 框架:

    • PaddlePaddle >= 2.0.0

五、快速开始

step1: clone

# clone this repogit clone https://github.com/PaddlePaddle/Contrib.gitcd wide_resnetexport PYTHONPATH=./登录后复制

安装依赖

python3 -m pip install -r requirements.txt登录后复制

step2: 训练

python3 train.py登录后复制登录后复制

如果你想分布式训练并使用多卡:

python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3' train.py登录后复制登录后复制

此时的输出为:

Epoch 0: PiecewiseDecay set learning rate to 0.05.iter:0 loss:2.4832iter:10 loss:2.3544iter:20 loss:2.3087iter:30 loss:2.2509iter:40 loss:2.2450登录后复制

step3: 测试

python3 eval.py登录后复制登录后复制

此时的输出为:

acc:9660 total:10000 ratio:0.966登录后复制登录后复制

六、代码结构与详细说明

6.1 代码结构

│ wide_resnet.py # 模型文件│ eval.py # 评估│ README.md # 英文readme│ README_cn.md # 中文readme│ requirement.txt # 依赖│ train.py # 训练登录后复制

6.2 参数说明

6.3 训练流程

单机训练

python3 train.py登录后复制登录后复制

多机训练

python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3' train.py登录后复制登录后复制

此时,程序会将每个进程的输出log导入到./debug路径下:

.├── debug│ ├── workerlog.0│ ├── workerlog.1│ ├── workerlog.2│ └── workerlog.3├── README.md└── train.py登录后复制

训练输出

执行训练开始后,将得到类似如下的输出。每一轮batch训练将会打印当前epoch、step以及loss值。

Epoch 0: PiecewiseDecay set learning rate to 0.05.iter:0 loss:2.4832iter:10 loss:2.3544iter:20 loss:2.3087iter:30 loss:2.2509iter:40 loss:2.2450登录后复制

6.4 评估流程

python3 eval.py登录后复制登录后复制

此时的输出为:

acc:9660 total:10000 ratio:0.966登录后复制登录后复制

七、模型信息

关于模型的其他信息,可以参考下表:

信息说明发布者徐铭远时间2021.08框架版本>=Paddle 2.0.2应用场景图像分类支持硬件GPU、CPU下载链接预训练模型In [?]

# 以下为在aistudio上直接运行登录后复制In [4]

# 训练!python3 train.py登录后复制

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations def convert_to_list(value, n, name, dtype=np.int):W0808 16:41:54.148313 32483 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0808 16:41:54.152312 32483 device_context.cc:372] device: 0, cuDNN Version: 7.6.Epoch 0: PiecewiseDecay set learning rate to 0.05.iter:0 loss:2.4279iter:10 loss:2.3434^C登录后复制In [?]

# 评估!python3 eval.py登录后复制

/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:26: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations def convert_to_list(value, n, name, dtype=np.int):W0808 16:37:21.490298 32096 device_context.cc:362] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0808 16:37:21.494925 32096 device_context.cc:372] device: 0, cuDNN Version: 7.6.acc:9660 total:10000 ratio:0.966登录后复制

来源:https://www.php.cn/faq/1428412.html
免责声明:文中图文均来自网络,如有侵权请联系删除,心愿游戏发布此文仅为传递信息,不代表心愿游戏认同其观点或证实其描述。

相关文章

更多

精选合集

更多

大家都在玩

热门话题

大家都在看

更多