位置:首页 > 新闻资讯 > 【论文复现】CSRA-Paddle: 残差注意力机制模型

【论文复现】CSRA-Paddle: 残差注意力机制模型

时间:2025-07-29  |  作者:  |  阅读:0

本文介绍基于PaddlePaddle复现ICCV 2021论文的CSRA-Paddle项目。该项目通过类特定残余注意力模块(CSRA),结合类别无关平均池化特征与类特定空间注意力特征,提升多标签识别效果。在Pascal VOC 2007数据集上,Resnet101+CSRA模型复现精度达94.7 mAP,提供了完整的数据集准备、训练、验证及推理流程。

【论文复现】CSRA-Paddle: 残差注意力机制模型_wishdown.com

CSRA-Paddle: 残差注意力机制模型

1.1 简介

本项目基于PaddlePaddle?复现了ICCV 2021?上发表的论文:Residual Attention: A Simple But Effective Method for Multi-Label Recoginition

【论文复现】CSRA-Paddle: 残差注意力机制模型_wishdown.com

? ? ? ?

为了有效地捕捉来自不同类别的对象所占据的不同空间区域,这篇文章提出了一个非常简单的模块,称为类特定的残余注意力(CSRA)。 CSRA通过提出一个简单的空间注意力分数为每个类别生成特定于类的特征,然后将其与与类别无关的平均池化特征相结合。CSRA 在多标签识别上取得了 state-of-the-art 的结果,同时相比于其他方法简单得多。

本项目基于PaddlePaddle框架复现了CSRA,并在Pascal VOC数据集上进行了实验。

论文:

  • [1] Zhu, K. , and J. Wu .?Residual Attention: A Simple But Effective Method for Multi-Label Recoginition. ICCV, 2021.

项目参考:

  • https://github.com/Kevinz-code/CSRA

上述CSRA的核心代码块:

class CSRA(nn.Layer): # one basic block def __init__(self, input_dim, num_classes, T, lam): super(CSRA, self).__init__() self.T = T # temperature self.lam = lam # Lambda self.head = nn.Conv2D(input_dim, num_classes, 1, bias_attr=False) self.softmax = nn.Softmax(axis=2) def forward(self, x): # x (B d H W) # normalize classifier # score (B C HxW) score = self.head(x) / paddle.norm(self.head.weight, axis=1, keepdim=True).transpose((1, 0, 2, 3)) score = score.flatten(2) base_logit = paddle.mean(score, axis=2) if self.T == 99: # max-pooling att_logit = paddle.max(score, axis=2)[0] else: score_soft = self.softmax(score * self.T) att_logit = paddle.sum(score * score_soft, axis=2) return base_logit + self.lam * att_logit登录后复制

? ? ? ?

可以参阅论文进行理解。

1.2 复现精度

原文在Pascal VOC 2007 val数据集的测试效果如下表

【论文复现】CSRA-Paddle: 残差注意力机制模型_wishdown.com

? ? ? ?

本项目在Pascal VOC 2007 val数据集的测试效果如下表。

FrameNetWorkepochsoptlrresolutionbatch_sizedatasetcardmAP本项目PaddleResnet101+CSRA30SGD0.01448x44816VOC20071xV10094.7

可见,本项目成功用PaddlePaddle复现了论文结果(Resnet101+CSRA: 94.7)。

1.3 数据集

数据集网站:Pascal VOC

AiStudio上的数据集:pascal-voc

数据集介绍:

Pascal 的全称是 Pattern Analysis, Statical Modeling and Computational Learning。 PASCAL VOC 挑战赛是视觉对象的分类识别和检测的一个基准测试,提供了检测算法和学习性能的标准图像注释数据集和标准的评估系统。从2005年至今,该组织每年都会提供一系列类别的、带标签的图片,挑战者通过设计各种精妙的算法,仅根据分析图片内容来将其分类,最终通过准确率、召回率、效率来一决高下。

Pascal VOC(2005~2012)竞赛的目标主要是进行图像的目标识别,其提供的数据集包含20类的物体。每张图片都有标注,标注的物体包括人、动物(如猫、狗、岛等)、交通工具(如车、船飞机等)、家具(如椅子、桌子、沙发等)在内的20个类别。每个图像平均有2.4个目标。

VOC2007:中包含9963张标注过的图片, 由train/val/test三部分组成, 共标注出24,640个物体。

  • 本项目使用的数据集结构:

PATH/Dataset/|-- VOCdevkit/|---- VOC2007/|------ JPEGImages/|------ Annotations/|------ ImageSets/登录后复制

? ? ? ?

注:PATH/Dataset/为数据集的路径

快速开始

2.1 数据准备

In [?]

!unzip -q data/data4379/pascalvoc.zip -d data/data4379/登录后复制 ? ?In [1]

%cd /home/aistudio/CSRA-Paddle/!python utils/prepare/prepare_voc.py --data_path /home/aistudio/data/data4379/pascalvoc/VOCdevkit登录后复制 ? ? ? ?

/home/aistudio/CSRA-Paddlegenerating labels for VOC07 datasetgenerating final json file for VOC07 datasetVOC07 data preparing finished!data/voc07/trainval_voc07.json data/voc07/test_voc07.json登录后复制 ? ? ? ?

2.2 训练

In [?]

%cd /home/aistudio/CSRA-Paddle/!python train.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --save_dir=./checkpoint登录后复制 ? ?

2.3 验证

In [?]

%cd /home/aistudio/CSRA-Paddle/!python val.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --load_from output/epoch_11.pdparams登录后复制 ? ?

结果:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 310/310 [01:13

mAP: 0.946971

CP: 0.922363, CR: 0.876188, CF1 :0.898682

OP: 0.943647, OR: 0.890632, OF1 0.916373

2.4 预测

In [3]

%cd /home/aistudio/CSRA-Paddle/!python predict.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from output/epoch_11.pdparams --img_dir utils/demo_images登录后复制 ? ? ? ?

backbone params inited by paddle official modelW0410 16:12:18.782222 3012 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0410 16:12:18.786772 3012 device_context.cc:465] device: 0, cuDNN Version: 7.6.Loading weights from checkpoint_94.697/epoch_11.pdparams/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/math_op_patch.py:253: UserWarning: The dtype of left and right variables are not the same, left dtype is paddle.float32, but right dtype is paddle.int64, the right dtype will convert to paddle.float32 format(lhs_dtype, rhs_dtype, lhs_dtype))utils/demo_images/000002.jpg prediction: train,utils/demo_images/000007.jpg prediction: car,utils/demo_images/000004.jpg prediction: car,utils/demo_images/000009.jpg prediction: horse,person,utils/demo_images/000001.jpg prediction: dog,person,utils/demo_images/000006.jpg prediction: chair,登录后复制 ? ? ? ?

2.5 TIPC

注意:本部分为论文复现赛内容,只是为了验证整个项目的训练推理的正确性。学习目的可以不进行这部分的运行,即这部分非项目必要部分。

首先安装auto_log,需要进行安装,安装方式如下: auto_log的详细介绍参考https://github.com/LDOUBLEV/AutoLog。

git clone https://github.com/LDOUBLEV/AutoLogcd AutoLog/pip3 install -r requirements.txtpython3 setup.py bdist_wheelpip3 install ./dist/auto_log-1.2.0-py3-none-any.whl登录后复制 ? ? ? ?

进行TIPC:在命令行执行

bash test_tipc/prepare.sh test_tipc/configs/CSRARes101/train_infer_python.txt 'lite_train_lite_infer'bash test_tipc/test_train_inference_python.sh test_tipc/configs/CSRARes101/train_infer_python.txt 'lite_train_lite_infer'登录后复制 ? ? ? ?

注意:由于代码中每次训练需要生成数据集的标签json文件,进行tipc会覆盖原来data目录下的json文件,所以进行tipc后要进行完整训练的话。需要重新为完整数据集生成json文件,也就是重新执行数据准备的步骤

2.6 模型导出与推理

In [?]

!python export_model.py --model resnet101 --num_heads 1 --lam 0.1 --img_size=448 --model_path=./output/epoch_11.pdparams --save_dir=./output登录后复制 ? ?In [3]

!python infer.py --use_gpu=True --model_file=output/model.pdmodel --input_file=utils/demo_images --params_file=output/model.pdiparams登录后复制 ? ? ? ?

Inference model(CSRARes101)...W0410 20:56:50.359391 12322 analysis_predictor.cc:795] The one-time configuration of analysis predictor failed, which may be due to native predictor called first and its configurations taken effect.--- Running analysis [ir_graph_build_pass]--- Running analysis [ir_graph_clean_pass]--- Running analysis [ir_analysis_pass]--- Running IR pass [is_test_pass]--- Running IR pass [simplify_with_basic_ops_pass]--- Running IR pass [conv_affine_channel_fuse_pass]--- Running IR pass [conv_eltwiseadd_affine_channel_fuse_pass]--- Running IR pass [conv_bn_fuse_pass]I0410 20:56:50.920820 12322 fuse_pass_base.cc:57] --- detected 104 subgraphs--- Running IR pass [conv_eltwiseadd_bn_fuse_pass]--- Running IR pass [embedding_eltwise_layernorm_fuse_pass]--- Running IR pass [multihead_matmul_fuse_pass_v2]--- Running IR pass [squeeze2_matmul_fuse_pass]--- Running IR pass [reshape2_matmul_fuse_pass]--- Running IR pass [flatten2_matmul_fuse_pass]--- Running IR pass [map_matmul_v2_to_mul_pass]--- Running IR pass [map_matmul_v2_to_matmul_pass]--- Running IR pass [map_matmul_to_mul_pass]--- Running IR pass [fc_fuse_pass]--- Running IR pass [fc_elementwise_layernorm_fuse_pass]--- Running IR pass [conv_elementwise_add_act_fuse_pass]--- Running IR pass [conv_elementwise_add2_act_fuse_pass]--- Running IR pass [conv_elementwise_add_fuse_pass]--- Running IR pass [transpose_flatten_concat_fuse_pass]--- Running IR pass [runtime_context_cache_pass]--- Running analysis [ir_params_sync_among_devices_pass]I0410 20:56:51.119207 12322 ir_params_sync_among_devices_pass.cc:45] Sync params from CPU to GPU--- Running analysis [adjust_cudnn_workspace_size_pass]--- Running analysis [inference_op_replace_pass]--- Running analysis [memory_optimize_pass]I0410 20:56:52.790841 12322 memory_optimize_pass.cc:216] Cluster name : relu_18.tmp_0 size: 6422528I0410 20:56:52.790884 12322 memory_optimize_pass.cc:216] Cluster name : x size: 2408448I0410 20:56:52.790887 12322 memory_optimize_pass.cc:216] Cluster name : tmp_2 size: 12845056I0410 20:56:52.790899 12322 memory_optimize_pass.cc:216] Cluster name : relu_3.tmp_0 size: 12845056I0410 20:56:52.790905 12322 memory_optimize_pass.cc:216] Cluster name : relu_9.tmp_0 size: 12845056--- Running analysis [ir_graph_to_program_pass]I0410 20:56:52.913156 12322 analysis_predictor.cc:714] ======= optimize end =======I0410 20:56:52.924579 12322 naive_executor.cc:98] --- skip [feed], feed -> xI0410 20:56:52.928333 12322 naive_executor.cc:98] --- skip [tmp_38], fetch -> fetchW0410 20:56:52.950525 12322 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1W0410 20:56:52.954545 12322 device_context.cc:465] device: 0, cuDNN Version: 7.6.utils/demo_images/000002.jpgprediction: train,utils/demo_images/000007.jpgprediction: car,utils/demo_images/000004.jpgprediction: car,utils/demo_images/000009.jpgprediction: horse,person,utils/demo_images/000001.jpgprediction: dog,person,utils/demo_images/000006.jpgprediction: chair,登录后复制 ? ? ? ?

导出的模型推理结果与动态图预测结果一致。

复现心得与相关信息

复现心得

多标签图像识别是一项具有挑战性的实用计算机视觉任务。然而,该领域的进展往往具有方法复杂、计算量大、缺乏直观解释的特点。而这篇论文则从很简单的结构设计出发,仅用几行代码,在许多不同的预训练模型和数据集上实现一致的改进,而无需任何额外的训练。CSRA 既易于实现又易于计算,还具有直观的解释。

非常值得读者在图像分类方面的进阶学习!

本次复现也是我在图像分类领域的第一次复现,同时也是第一次完成TIPC任务,学习到了TIPC的内涵,可以帮助别人更快的验证你的模型。

复现的经验分享可以从两个方面来讲:第一步是熟悉论文的核心思想和参考代码的基本结构和核心代码,对复现的难度等有一个大概的把握。第二个是快速的代码对齐。这部分主要是需要熟悉不同框架与Paddle的api函数的功能,不熟悉也没关系,可以通过查阅官网的手册和利用X2Paddle提供的对齐文档进行快速上对齐。

相关信息

信息描述作者xbchen日期2022年4月框架版本PaddlePaddle==2.2.1应用场景图像分类硬件支持GPU、CPU

福利游戏

相关文章

更多

精选合集

更多

大家都在玩

热门话题

大家都在看

更多