使用PaddlePaddle2.0高层API完成基于VGG16的图像分类任务
时间:2025-07-23 | 作者: | 阅读:0本文介绍如何用PaddlePaddle2.0高层API,基于VGG16完成Cifar10图像分类任务。包括利用高层API简化模型组网与训练,加载VGG16网络并查看结构,加载Cifar10数据集及数据增强,还讲解了用高层API进行模型训练、验证与测试的过程,最后提及使用时的注意事项。
使用PaddlePaddle2.0高层API完成基于VGG16的图像分类任务
本示例教程将会演示如何使用飞桨的卷积神经网络来完成目标检测任务。这是一个较为简单的示例,将会使用飞桨框架内置模型VGG16网络完成Cifar10数据集的图像分类任务。
一、PaddlePaddle2.0新亮点——高层API助力开发者快速上手深度学习
飞桨致力于让深度学习技术的创新与应用更简单
1.模型组网更简单
对于新手来说,完全可以省去以往复杂的组网代码,一行代码便可以完成组网。
目前PaddlePaddle2.0-rc1的内置模型有:
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet'
使用一行代码便可以加载:
ModelName = paddle.vision.models.ModelName()登录后复制 ? ? ? ?
(将ModelName替换成上面的模型名称即可,模型名称后面别忘了加括号!!!)
2.模型训练更简单
PaddlePaddle2.0-rc1增加了paddle.Model高层API,大部分任务可以使用此API用于简化训练、评估、预测类代码开发。
使用两句代码便可以训练模型:
# 训练前准备ModelName.prepare( paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()), paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy(topk=(1, 2)) )# 启动训练ModelName.fit(train_dataset, epochs=2, batch_size=64, log_freq=200)登录后复制 ? ?
二、使用飞桨快速加载VGG网络并查看模型结构
Very Deep Convolutional Networks For Large-Scale Image Recognition 论文地址:https://arxiv.org/pdf/1409.1556.pdf
1.查看飞桨框架内置模型
In [?]import paddleprint('飞桨框架内置模型:', paddle.vision.models.__all__)登录后复制 ? ? ? ?
飞桨框架内置模型: ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'VGG', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'MobileNetV1', 'mobilenet_v1', 'MobileNetV2', 'mobilenet_v2', 'LeNet']登录后复制 ? ? ? ?
2.一行代码加载VGG16
In [?]vgg16 = paddle.vision.models.vgg16()登录后复制 ? ?
3.查看VGG16的网络结构及参数
In [?]paddle.summary(vgg16, (64, 3, 32, 32))登录后复制 ? ? ? ?
------------------------------------------------------------------------------- Layer (type) Input Shape Output Shape Param # =============================================================================== Conv2D-1 [[64, 3, 32, 32]] [64, 64, 32, 32] 1,792 ReLU-1 [[64, 64, 32, 32]] [64, 64, 32, 32] 0 Conv2D-2 [[64, 64, 32, 32]] [64, 64, 32, 32] 36,928 ReLU-2 [[64, 64, 32, 32]] [64, 64, 32, 32] 0 MaxPool2D-1 [[64, 64, 32, 32]] [64, 64, 16, 16] 0 Conv2D-3 [[64, 64, 16, 16]] [64, 128, 16, 16] 73,856 ReLU-3 [[64, 128, 16, 16]] [64, 128, 16, 16] 0 Conv2D-4 [[64, 128, 16, 16]] [64, 128, 16, 16] 147,584 ReLU-4 [[64, 128, 16, 16]] [64, 128, 16, 16] 0 MaxPool2D-2 [[64, 128, 16, 16]] [64, 128, 8, 8] 0 Conv2D-5 [[64, 128, 8, 8]] [64, 256, 8, 8] 295,168 ReLU-5 [[64, 256, 8, 8]] [64, 256, 8, 8] 0 Conv2D-6 [[64, 256, 8, 8]] [64, 256, 8, 8] 590,080 ReLU-6 [[64, 256, 8, 8]] [64, 256, 8, 8] 0 Conv2D-7 [[64, 256, 8, 8]] [64, 256, 8, 8] 590,080 ReLU-7 [[64, 256, 8, 8]] [64, 256, 8, 8] 0 MaxPool2D-3 [[64, 256, 8, 8]] [64, 256, 4, 4] 0 Conv2D-8 [[64, 256, 4, 4]] [64, 512, 4, 4] 1,180,160 ReLU-8 [[64, 512, 4, 4]] [64, 512, 4, 4] 0 Conv2D-9 [[64, 512, 4, 4]] [64, 512, 4, 4] 2,359,808 ReLU-9 [[64, 512, 4, 4]] [64, 512, 4, 4] 0 Conv2D-10 [[64, 512, 4, 4]] [64, 512, 4, 4] 2,359,808 ReLU-10 [[64, 512, 4, 4]] [64, 512, 4, 4] 0 MaxPool2D-4 [[64, 512, 4, 4]] [64, 512, 2, 2] 0 Conv2D-11 [[64, 512, 2, 2]] [64, 512, 2, 2] 2,359,808 ReLU-11 [[64, 512, 2, 2]] [64, 512, 2, 2] 0 Conv2D-12 [[64, 512, 2, 2]] [64, 512, 2, 2] 2,359,808 ReLU-12 [[64, 512, 2, 2]] [64, 512, 2, 2] 0 Conv2D-13 [[64, 512, 2, 2]] [64, 512, 2, 2] 2,359,808 ReLU-13 [[64, 512, 2, 2]] [64, 512, 2, 2] 0 MaxPool2D-5 [[64, 512, 2, 2]] [64, 512, 1, 1] 0 AdaptiveAvgPool2D-1 [[64, 512, 1, 1]] [64, 512, 7, 7] 0 Linear-1 [[64, 25088]] [64, 4096] 102,764,544 ReLU-14 [[64, 4096]] [64, 4096] 0 Dropout-1 [[64, 4096]] [64, 4096] 0 Linear-2 [[64, 4096]] [64, 4096] 16,781,312 ReLU-15 [[64, 4096]] [64, 4096] 0 Dropout-2 [[64, 4096]] [64, 4096] 0 Linear-3 [[64, 4096]] [64, 1000] 4,097,000 ===============================================================================Total params: 138,357,544Trainable params: 138,357,544Non-trainable params: 0-------------------------------------------------------------------------------Input size (MB): 0.75Forward/backward pass size (MB): 309.99Params size (MB): 527.79Estimated Total Size (MB): 838.53-------------------------------------------------------------------------------登录后复制 ? ? ? ?
{'total_params': 138357544, 'trainable_params': 138357544}登录后复制 ? ? ? ? ? ? ? ?
三、使用飞桨框架API加载数据集
飞桨框架将一些我们常用到的数据集作为领域API对用户进行开放,对应API所在目录为paddle.vision.datasets与paddle.text.datasets
目前已经收录的数据集有:
视觉相关数据集: ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']
自然语言相关数据集: ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16']
1.快速加载数据集
使用一行代码即可加载数据集到本机缓存目录~/.cache/paddle/dataset:
paddle.vision.datasets.DataSetName()
将DataSetName替换成上述数据集名称即可,别忘了名称后面跟一个小括号!
In [?]from paddle.vision.transforms import ToTensor# 训练数据集 用ToTensor将数据格式转为Tensortrain_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=ToTensor())# 验证数据集val_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=ToTensor())登录后复制 ? ?
2.对训练数据做数据增强
训练过程中有时会遇到过拟合的问题,其中一个解决方法就是对训练数据做增强,对数据进行处理得到不同的图像,从而泛化数据集。
查看飞桨框架提供的数据增强方法:
In [?]import paddleprint('数据处理方法:', paddle.vision.transforms.__all__)登录后复制 ? ? ? ?
数据处理方法: ['BaseTransform', 'Compose', 'Resize', 'RandomResizedCrop', 'CenterCrop', 'RandomHorizontalFlip', 'RandomVerticalFlip', 'Transpose', 'Normalize', 'BrightnessTransform', 'SaturationTransform', 'ContrastTransform', 'HueTransform', 'ColorJitter', 'RandomCrop', 'Pad', 'RandomRotation', 'Grayscale', 'ToTensor', 'to_tensor', 'hflip', 'vflip', 'resize', 'pad', 'rotate', 'to_grayscale', 'crop', 'center_crop', 'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'normalize']登录后复制 ? ? ? ?In [?]
from paddle.vision.transforms import Compose, Resize, ColorJitter, ToTensor, RandomHorizontalFlip, RandomVerticalFlip, RandomRotationimport numpy as npfrom PIL import Image# 定义想要使用那些数据增强方式,这里用到了随机调整亮度、对比度和饱和度、图像翻转等transform = Compose([ColorJitter(), RandomHorizontalFlip(), ToTensor()])# 通过transform参数传递定义好的数据增项方法即可完成对自带数据集的应用train_dataset = paddle.vision.datasets.Cifar10(mode='train', transform=transform)# 验证数据集val_dataset = paddle.vision.datasets.Cifar10(mode='test', transform=transform)登录后复制 ? ?
这里需要注意的坑是,一定要把ToTensor()放在最后,否则会报错
检查数据集:
In [?]train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)for batch_id, data in enumerate(train_loader()): x_data = data[0] y_data = data[1] breakprint(x_data.numpy().shape)print(y_data.numpy().shape)登录后复制 ? ? ? ?
(64, 3, 32, 32)(64,)登录后复制 ? ? ? ?
四、使用高层API进行模型训练、验证与测试
飞桨框架提供了两种训练与预测的方法:
- 一种是用paddle.Model对模型进行封装,通过高层API如Model.fit()、Model.evaluate()、Model.predict()等完成模型的训练与预测;
- 另一种就是基于基础API常规的训练方式。
使用高层API只需要改动少量参数即可完成模型训练,对新手小白真的特别友好!
1.调用fit()接口来启动训练过程
In [26]import paddlefrom paddle.vision.transforms import ToTensorfrom paddle.vision.models import vgg16# build modelmodel = vgg16()# build vgg16 model with batch_normmodel = vgg16(batch_norm=True)# 使用高层API——paddle.Model对模型进行封装model = paddle.Model(model)# 为模型训练做准备,设置优化器,损失函数和精度计算方式model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()), loss=paddle.nn.CrossEntropyLoss(), metrics=paddle.metric.Accuracy())# 启动模型训练,指定训练数据集,设置训练轮次,设置每次数据集计算的批次大小,设置日志格式model.fit(train_dataset, epochs=10, batch_size=256, save_dir=”vgg16/“, save_freq=10, verbose=1)登录后复制 ? ? ? ?
The loss value printed in the log is the current step, and the metric is the average value of previous step.Epoch 1/10step 196/196 [==============================] - loss: 2.3752 - acc: 0.1054 - 202ms/step save checkpoint at /home/aistudio/vgg16/0Epoch 2/10step 196/196 [==============================] - loss: 2.2819 - acc: 0.1114 - 203ms/step Epoch 3/10step 196/196 [==============================] - loss: 2.2614 - acc: 0.1306 - 201ms/step Epoch 4/10step 196/196 [==============================] - loss: 2.2229 - acc: 0.1822 - 195ms/step Epoch 5/10step 196/196 [==============================] - loss: 1.9132 - acc: 0.1846 - 198ms/step Epoch 6/10step 196/196 [==============================] - loss: 1.7738 - acc: 0.2000 - 194ms/step Epoch 7/10step 196/196 [==============================] - loss: 1.8450 - acc: 0.2286 - 195ms/step Epoch 8/10step 196/196 [==============================] - loss: 1.5770 - acc: 0.2782 - 198ms/step Epoch 9/10step 196/196 [==============================] - loss: 1.5743 - acc: 0.3446 - 195ms/step Epoch 10/10step 196/196 [==============================] - loss: 1.4283 - acc: 0.4043 - 201ms/step save checkpoint at /home/aistudio/vgg16/final登录后复制 ? ? ? ?
看到loss在明显下降、acc在明显上升,说明模型效果还不错,剩下需要慢慢调参优化
2.调用evaluate()在测试集上对模型进行验证
对于训练好的模型进行评估操作可以使用evaluate接口来实现,事先定义好用于评估使用的数据集后,可以简单的调用evaluate接口即可完成模型评估操作,结束后根据prepare中loss和metric的定义来进行相关评估结果计算返回。
In [27]# 用 evaluate 在测试集上对模型进行验证eval_result = model.evaluate(val_dataset, verbose=1)登录后复制 ? ? ? ?
Eval begin...The loss value printed in the log is the current batch, and the metric is the average value of previous step.step 10000/10000 [==============================] - loss: 0.2509 - acc: 0.4379 - 10ms/step Eval samples: 10000登录后复制 ? ? ? ?
3.调用predict()接口进行模型测试
高层API中提供了predict接口来方便用户对训练好的模型进行预测验证,只需要基于训练好的模型将需要进行预测测试的数据放到接口中进行计算即可,接口会将经过模型计算得到的预测结果进行返回。
In [28]# 用 predict 在测试集上对模型进行测试test_result = model.predict(val_dataset)登录后复制 ? ? ? ?
Predict begin...step 10000/10000 [==============================] - 10ms/step Predict samples: 10000登录后复制 ? ? ? ?
五、总结与展望——PaddlePaddle2.0 rc1入手指南
给大家总结一下我在使用PaddlePaddle2.0 rc1时遇到的坑,希望大家可以避免:
- 1.这个项目本来是使用VOC2012数据集进行目标检测任务的训练的,奈何VOC2012数据集的下载速度实在是太慢了,所以我果断放弃,希望后期可以找到解决办法
- 2.最好结合PaddlePaddle2.0的文档和GitHub上的源码来使用,特别是新手小白,不然出现一些报错可能会很难解决,可以多去GitHub上提issue
- 3.使用数据增强 Compose()方法时,切记!一定要把ToTensor()放在最后,这个问题看看源码就能解决
福利游戏
相关文章
更多-
- 电脑音箱有电流声,该怎么消除?
- 时间:2025-07-23
-
- exr 格式图片在影视后期中常用吗 与 hdr 有何不同
- 时间:2025-07-23
-
- 电脑的键盘输入时出现重复字符,如何解决?
- 时间:2025-07-23
-
- 一文搞懂Paddle2.0中的优化器
- 时间:2025-07-23
-
- 基于Ghost Module的生活垃圾智能分类算法
- 时间:2025-07-23
-
- 第29周新势力车型销量TOP10公布:小米SU7有挑战者了
- 时间:2025-07-23
-
- 从零实现深度学习框架 基础框架的构建
- 时间:2025-07-23
-
- 基于PaddlePaddle2.0-构建残差网络模型
- 时间:2025-07-23
大家都在玩
热门话题
大家都在看
更多-
- 腾讯客服回应微信实时对讲功能:已下线 暂无重新上线计划
- 时间:2025-07-23
-
- GAT币投资指南:深度分析未来潜力
- 时间:2025-07-23
-
- 网友爆料尊界S800自动泊车撞了:车主就在旁边看着 承担全责
- 时间:2025-07-23
-
- 3万级纯电代步小车!全新奔腾小马官图发布:7月27日正式上市
- 时间:2025-07-23
-
- 妖怪金手指石矶娘娘图鉴及对应克制神将
- 时间:2025-07-23
-
- 比特币交易所排行:全球顶级平台及选择指南
- 时间:2025-07-23
-
- 一高速出现断头路却无提醒:引流线导向隔离墙 汽车险些撞上
- 时间:2025-07-23
-
- 国内首个!夸克健康大模型通过12门主任医师考试
- 时间:2025-07-23