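ShiftViT: Using a Simple, Efficient Shift Operation to Test Whether Attention Is Necessary
Date: 2025-07-16
The attention mechanism is regarded as the key to the success of vision Transformers, but this work questions whether it is necessary. ShiftViT is built by replacing the attention layers with a Shift operation that has zero parameters and zero computation. On classification, detection, and segmentation it performs on par with, or even better than, Swin Transformer. This suggests that attention may not be the key to ViT's success and that future work should look at the remaining parts of the architecture.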
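Abstract
The attention mechanism is widely believed to be the key to the success of vision Transformers (ViTs), because it provides a flexible and powerful way to model spatial relationships. But is attention truly an indispensable part of ViT? Can it be replaced by something else? To demystify the role of attention, we simplify it to an extremely simple case: ZERO FLOP and ZERO parameter. Concretely, we revisit the Shift operation. It contains no parameters and no arithmetic computation. Its only action is to exchange a small portion of the channels between neighboring features. Based on this simple operation, we construct a new backbone network, ShiftViT, in which the attention layers are replaced by Shift operations. Surprisingly, ShiftViT works quite well on several mainstream tasks, e.g., classification, detection, and segmentation. Its performance is on par with, or even better than, the strong Swin Transformer baseline. These results suggest that the attention mechanism may not be the vital factor that makes ViT successful. It can even be replaced by a zero-parameter operation. Future work should pay more attention to the remaining parts of ViT.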
1. ShiftViT
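To test whether the attention mechanism in the Transformer is necessary, ShiftViT replaces it with a simple and efficient Shift operation. The Shift operation moves a fraction of the channels by one position in each of four directions (left, right, up, and down) and leaves the remaining channels untouched. It is very similar to S2MLP; the difference is that S2MLP applies the shift to all channels, whereas ShiftViT shifts only a portion of them: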
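$$
\begin{aligned}
\hat{z}[0:H,\ 1:W,\ 0:\gamma C] &\leftarrow z[0:H,\ 0:W-1,\ 0:\gamma C] \\
\hat{z}[0:H,\ 0:W-1,\ \gamma C:2\gamma C] &\leftarrow z[0:H,\ 1:W,\ \gamma C:2\gamma C] \\
\hat{z}[0:H-1,\ 0:W,\ 2\gamma C:3\gamma C] &\leftarrow z[1:H,\ 0:W,\ 2\gamma C:3\gamma C] \\
\hat{z}[1:H,\ 0:W,\ 3\gamma C:4\gamma C] &\leftarrow z[0:H-1,\ 0:W,\ 3\gamma C:4\gamma C] \\
\hat{z}[0:H,\ 0:W,\ 4\gamma C:C] &\leftarrow z[0:H,\ 0:W,\ 4\gamma C:C]
\end{aligned}
$$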
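The update rule above can be sketched in a few lines of NumPy. This is a minimal illustration, not the article's Paddle layer: the function name `shift_features`, the `(H, W, C)` layout, and the zero fill for positions that receive no value are all assumptions made for this example. The ratio γ is expressed as `1 / n_div`, so each shifted group has `C // n_div` channels.

```python
import numpy as np


def shift_features(z: np.ndarray, n_div: int = 12) -> np.ndarray:
    """Out-of-place shift of an (H, W, C) feature map (hypothetical helper).

    Four groups of C // n_div channels are moved by one position each;
    all remaining channels are copied unchanged. Border positions that
    receive no value stay at zero (an assumption of this sketch).
    """
    H, W, C = z.shape
    g = C // n_div  # channels per shifted group, i.e. gamma * C
    out = np.zeros_like(z)
    out[:, 1:, 0 * g:1 * g] = z[:, :-1, 0 * g:1 * g]  # shift right along W
    out[:, :-1, 1 * g:2 * g] = z[:, 1:, 1 * g:2 * g]  # shift left along W
    out[:-1, :, 2 * g:3 * g] = z[1:, :, 2 * g:3 * g]  # shift up along H
    out[1:, :, 3 * g:4 * g] = z[:-1, :, 3 * g:4 * g]  # shift down along H
    out[:, :, 4 * g:] = z[:, :, 4 * g:]               # no shift
    return out


# Toy input: 3x3 spatial grid, 12 channels, so each group has one channel.
z = np.arange(3 * 3 * 12, dtype=np.float32).reshape(3, 3, 12)
z_hat = shift_features(z, n_div=12)
```

No multiplication or addition touches the feature values; the operation is pure memory movement, which is why it is described as zero-FLOP and zero-parameter.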
2. Code Reproduction
2.1 Install and import the required libraries
In [?]:
    !pip install einops-0.3.0-py3-none-any.whl

In [?]:
    %matplotlib inline
    import os

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import figure

    import paddle
    import paddle.nn.functional as F
    import paddle.vision.transforms as transforms
    from paddle import nn
    from paddle.io import Dataset, DataLoader
    from paddle.vision.datasets import Cifar10
    from paddle.vision.transforms import Transpose

    from einops import rearrange
    from einops.layers.paddle import Rearrange, Reduce
2.2 Create the dataset
In [16]:
    # Training-time augmentation followed by ImageNet mean/std normalization
    train_tfm = transforms.Compose([
        transforms.Resize((230, 230)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    # Evaluation: resize and normalize only
    test_tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

In [17]:
    paddle.vision.set_image_backend('cv2')
    # Use the Cifar10 dataset
    train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform=train_tfm)
    val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test', transform=test_tfm)
    print("train_dataset: %d" % len(train_dataset))
    print("val_dataset: %d" % len(val_dataset))

Output:
    train_dataset: 50000
    val_dataset: 10000

In [18]:
    batch_size = 128

In [19]:
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)
2.3 Building the model
2.3.1 Label smoothing
In [8]:
    class LabelSmoothingCrossEntropy(nn.Layer):
        def __init__(self, smoothing=0.1):
            super().__init__()
            self.smoothing = smoothing

        def forward(self, pred, target):
            confidence = 1. - self.smoothing
            log_probs = F.log_softmax(pred, axis=-1)
            # Pick the log-probability of the true class for every sample
            idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
            nll_loss = paddle.gather_nd(-log_probs, index=idx)
            # Uniform component: mean negative log-probability over all classes
            smooth_loss = paddle.mean(-log_probs, axis=-1)
            loss = confidence * nll_loss + self.smoothing * smooth_loss
            return loss.mean()
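To make the loss above concrete, the same arithmetic can be worked through for a single sample in plain Python. The helper `label_smoothing_ce` and the example logits are made up for illustration. The sketch also checks that mixing the two terms is the same as cross-entropy against a soft target that puts `1 - smoothing + smoothing / K` on the true class and `smoothing / K` on every other class.

```python
import math


def label_smoothing_ce(logits, target, smoothing=0.1):
    """Label-smoothed cross-entropy for one sample (hypothetical helper)."""
    m = max(logits)  # subtract the max for a numerically stable log-softmax
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    nll = -log_probs[target]                    # standard cross-entropy term
    smooth = -sum(log_probs) / len(log_probs)   # mean over all classes
    return (1.0 - smoothing) * nll + smoothing * smooth


def soft_target_ce(logits, target, smoothing=0.1):
    """Equivalent form: cross-entropy against the smoothed label distribution."""
    k = len(logits)
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    q = [smoothing / k + (1.0 - smoothing) * (i == target) for i in range(k)]
    return -sum(qi * (v - log_z) for qi, v in zip(q, logits))


logits = [2.0, 0.5, -1.0]  # illustrative values only
plain = label_smoothing_ce(logits, target=0, smoothing=0.0)
smoothed = label_smoothing_ce(logits, target=0, smoothing=0.1)
```

When the true class already has the largest logit, the smoothed loss is larger than the plain one, which is how label smoothing discourages over-confident predictions.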
2.3.2 DropPath
In [8]:
    def drop_path(x, drop_prob=0.0, training=False):
        """
        Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
        """
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = paddle.to_tensor(1 - drop_prob)
        # One random value per sample, broadcast over all remaining dimensions
        shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
        random_tensor = paddle.floor(random_tensor)  # binarize
        # Rescale kept samples so the expected output equals the input
        output = x.divide(keep_prob) * random_tensor
        return output


    class DropPath(nn.Layer):
        def __init__(self, drop_prob=None):
            super(DropPath, self).__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            return drop_path(x, self.drop_prob, self.training)
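The reason `drop_path` divides by `keep_prob` can be checked numerically. The sketch below is a plain-Python simulation, not part of the notebook. The helper `drop_path_scale` is a made-up name that returns the multiplier one sample would receive: 0 when its residual branch is dropped, `1 / keep_prob` otherwise.

```python
import random


def drop_path_scale(drop_prob: float, rng: random.Random) -> float:
    """Per-sample multiplier applied by stochastic depth (hypothetical helper)."""
    keep_prob = 1.0 - drop_prob
    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0
    mask = 1.0 if keep_prob + rng.random() >= 1.0 else 0.0
    return mask / keep_prob


rng = random.Random(0)  # fixed seed so the simulation is repeatable
scales = [drop_path_scale(0.2, rng) for _ in range(100_000)]
mean_scale = sum(scales) / len(scales)  # close to 1.0
```

Because the average multiplier stays near 1, the branch output has the same expected value during training as at inference, where nothing is dropped.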
2.3.3 Building the ShiftViT model
In [9]:
    class MLP(nn.Layer):
        """Channel MLP implemented with 1x1 convolutions on (B, C, H, W) tensors."""

        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Conv2D(in_features, hidden_features, 1)
            self.act = act_layer()
            self.fc2 = nn.Conv2D(hidden_features, out_features, 1)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x

In [10]:
    class Shift(nn.Layer):
        def __init__(self, n_div):
            super().__init__()
            self.n_div = n_div

        def forward(self, x):
            B, C, H, W = x.shape
            g = C // self.n_div  # channels per shifted group

            # out = paddle.zeros_like(x)
            # The shift is written into x itself, so border rows/columns that
            # receive no value keep their original contents.
            x[:, g * 0:g * 1, :, :-1] = x[:, g * 0:g * 1, :, 1:]  # shift left
            x[:, g * 1:g * 2, :, 1:] = x[:, g * 1:g * 2, :, :-1]  # shift right
            x[:, g * 2:g * 3, :-1, :] = x[:, g * 2:g * 3, 1:, :]  # shift up
            x[:, g * 3:g * 4, 1:, :] = x[:, g * 3:g * 4, :-1, :]  # shift down

            x[:, g * 4:, :, :] = x[:, g * 4:, :, :]  # no shift
            return x

In [11]:
    class ShiftViTBlock(nn.Layer):
        def __init__(self, dim, n_div=12, mlp_ratio=4., drop=0., drop_path=0.,
                     act_layer=nn.GELU, norm_layer=nn.LayerNorm, input_resolution=None):
            super().__init__()
            self.input_resolution = input_resolution
            self.mlp_ratio = mlp_ratio

            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
            self.n_div = n_div
            self.shift = Shift(n_div)

        def forward(self, x):
            # The Shift operation takes the place of the attention layer
            x = self.shift(x)
            shortcut = x
            # LayerNorm expects channels last, so transpose to (B, H, W, C) and back
            x = shortcut + self.drop_path(self.mlp(self.norm(x.transpose([0, 2, 3, 1])).transpose([0, 3, 1, 2])))
            return x

In [12]:
    class BasicLayer(nn.Layer):
        def __init__(self, dim, input_resolution, depth, n_div=12, mlp_ratio=4., drop=0.,
                     drop_path=None, norm_layer=None, downsample=True, act_layer=nn.GELU):
            super(BasicLayer, self).__init__()
            self.dim = dim
            self.input_resolution = input_resolution
            self.depth = depth

            # build blocks
            self.blocks = nn.LayerList([
                ShiftViTBlock(dim=dim,
                              n_div=n_div,
                              mlp_ratio=mlp_ratio,
                              drop=drop,
                              drop_path=drop_path[i],
                              norm_layer=norm_layer,
                              act_layer=act_layer,
                              input_resolution=input_resolution)
                for i in range(depth)
            ])

            # patch merging layer
            if downsample:
                self.downsample = nn.Sequential(
                    nn.GroupNorm(num_groups=1, num_channels=dim),
                    nn.Conv2D(dim, dim * 2, kernel_size=2, stride=2, bias_attr=False)
                )
            else:
                self.downsample = None

        def forward(self, x):
            for blk in self.blocks:
                x = blk(x)
            if self.downsample is not None:
                x = self.downsample(x)
            return x

In [13]:
    class ShiftViT(nn.Layer):
        def __init__(self, n_div=12, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                     embed_dim=96, depths=(2, 2, 6, 2), mlp_ratio=2., drop_rate=0.,
                     drop_path_rate=0.1, patch_norm=True, **kwargs):
            super().__init__()
            norm_layer = nn.LayerNorm
            act_layer = nn.GELU

            self.num_classes = num_classes
            self.num_layers = len(depths)
            self.embed_dim = embed_dim
            self.patch_norm = patch_norm
            self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
            self.mlp_ratio = mlp_ratio

            # split image into non-overlapping patches
            self.patch_embed = nn.Sequential(
                nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size),
                Rearrange('b c h w->b h w c'),
                nn.LayerNorm(embed_dim) if self.patch_norm else nn.Identity(),
                Rearrange('b h w c->b c h w')
            )
            # num_patches = self.patch_embed.num_patches
            patches_resolution = [img_size // patch_size, img_size // patch_size]
            self.patches_resolution = patches_resolution

            self.pos_drop = nn.Dropout(p=drop_rate)

            # stochastic depth decay rule
            dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, sum(depths))]

            # build layers
            self.layers = nn.LayerList()
            for i_layer in range(self.num_layers):
                layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                                   n_div=n_div,
                                   input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                     patches_resolution[1] // (2 ** i_layer)),
                                   depth=depths[i_layer],
                                   mlp_ratio=self.mlp_ratio,
                                   drop=drop_rate,
                                   drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                                   norm_layer=norm_layer,
                                   downsample=(i_layer < self.num_layers - 1),
                                   act_layer=act_layer)
                self.layers.append(layer)

            self.norm = norm_layer(self.num_features)
            self.avgpool = nn.AdaptiveAvgPool2D(1)
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

            self.apply(self._init_weights)

        def _init_weights(self, m):
            tn = nn.initializer.TruncatedNormal(std=.02)
            zeros = nn.initializer.Constant(0.)
            ones = nn.initializer.Constant(1.)
            if isinstance(m, nn.Linear):
                tn(m.weight)
                if m.bias is not None:
                    zeros(m.bias)
            elif isinstance(m, (nn.Conv1D, nn.Conv2D)):
                tn(m.weight)
                if m.bias is not None:
                    zeros(m.bias)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                zeros(m.bias)
                ones(m.weight)

        def forward_features(self, x):
            x = self.patch_embed(x)
            x = self.pos_drop(x)

            for layer in self.layers:
                x = layer(x)

            # Final LayerNorm over channels, then global average pooling
            x = self.norm(x.transpose([0, 2, 3, 1])).transpose([0, 3, 1, 2])
            x = self.avgpool(x)
            x = paddle.flatten(x, 1)
            return x

        def forward(self, x):
            x = self.forward_features(x)
            x = self.head(x)
            return x
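The constructor above derives each stage's channel width, feature-map resolution, and drop-path rates from a handful of arguments. The plain-Python sketch below repeats that bookkeeping without Paddle, to show what a given configuration produces. The function `stage_layout` is a made-up helper; its defaults are the `depths=(3, 4, 9, 3)` configuration trained later in the article.

```python
def stage_layout(img_size=224, patch_size=4, embed_dim=96,
                 depths=(3, 4, 9, 3), drop_path_rate=0.1):
    """Per-stage width, resolution and drop-path rates (hypothetical helper)."""
    resolution = img_size // patch_size  # side length after patch embedding
    total = sum(depths)
    # Linear stochastic-depth rule: rates grow evenly from 0 to drop_path_rate
    dpr = [drop_path_rate * i / max(total - 1, 1) for i in range(total)]
    stages, start = [], 0
    for i, depth in enumerate(depths):
        stages.append({
            "dim": embed_dim * 2 ** i,            # channels double per stage
            "resolution": resolution // 2 ** i,   # side length halves per stage
            "drop_path": dpr[start:start + depth],
        })
        start += depth
    return stages


stages = stage_layout()
for s in stages:
    print(s["dim"], s["resolution"], len(s["drop_path"]))
```

With these defaults the last stage has 768 channels on a 7x7 map, which matches `num_features = embed_dim * 2 ** (num_layers - 1)` in the model.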
2.3.4 Model parameters
In [?]:
    # Shift-T
    model = ShiftViT(n_div=12, embed_dim=96, depths=(6, 8, 18, 6), num_classes=10)
    paddle.summary(model, (1, 3, 224, 224))

In [?]:
    # Shift-S
    model = ShiftViT(n_div=12, embed_dim=96, depths=(10, 18, 36, 10), num_classes=10)
    paddle.summary(model, (1, 3, 224, 224))

In [?]:
    # Shift-B
    model = ShiftViT(n_div=16, embed_dim=128, depths=(10, 18, 36, 10), num_classes=10)
    paddle.summary(model, (1, 3, 224, 224))

In [?]:
    # Shift-ours
    model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
    paddle.summary(model, (1, 3, 224, 224))
2.4 Training
In [19]:
    learning_rate = 0.001
    n_epochs = 100
    paddle.seed(42)
    np.random.seed(42)

In [?]:
    work_path = 'work/model'
    os.makedirs(work_path, exist_ok=True)  # make sure the checkpoint directory exists

    # Shift-ours
    model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)

    criterion = LabelSmoothingCrossEntropy()

    # The scheduler is stepped once per batch, so T_max is the total number of training steps
    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate,
                                                         T_max=50000 // batch_size * n_epochs,
                                                         verbose=False)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

    gate = 0.0
    threshold = 0.0
    best_acc = 0.0
    val_acc = 0.0
    loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}  # for recording loss
    acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}  # for recording accuracy

    loss_iter = 0
    acc_iter = 0

    for epoch in range(n_epochs):
        # ---------- Training ----------
        model.train()
        train_num = 0.0
        train_loss = 0.0

        val_num = 0.0
        val_loss = 0.0
        accuracy_manager = paddle.metric.Accuracy()
        val_accuracy_manager = paddle.metric.Accuracy()
        print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
        for batch_id, data in enumerate(train_loader):
            x_data, y_data = data
            labels = paddle.unsqueeze(y_data, axis=1)

            logits = model(x_data)

            loss = criterion(logits, y_data)

            acc = accuracy_manager.compute(logits, labels)
            accuracy_manager.update(acc)
            # Record the training loss every 10 batches
            if batch_id % 10 == 0:
                loss_record['train']['loss'].append(loss.numpy())
                loss_record['train']['iter'].append(loss_iter)
                loss_iter += 1

            loss.backward()

            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()

            train_loss += loss.detach()  # accumulate the value only, not the graph
            train_num += len(y_data)

        total_train_loss = (train_loss / train_num) * batch_size
        train_acc = accuracy_manager.accumulate()
        acc_record['train']['acc'].append(train_acc)
        acc_record['train']['iter'].append(acc_iter)
        acc_iter += 1
        # Print the information.
        print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc * 100))

        # ---------- Validation ----------
        model.eval()

        for batch_id, data in enumerate(val_loader):
            x_data, y_data = data
            labels = paddle.unsqueeze(y_data, axis=1)
            with paddle.no_grad():
                logits = model(x_data)

            loss = criterion(logits, y_data)

            acc = val_accuracy_manager.compute(logits, labels)
            val_accuracy_manager.update(acc)

            val_loss += loss
            val_num += len(y_data)

        total_val_loss = (val_loss / val_num) * batch_size
        loss_record['val']['loss'].append(total_val_loss.numpy())
        loss_record['val']['iter'].append(loss_iter)
        val_acc = val_accuracy_manager.accumulate()
        acc_record['val']['acc'].append(val_acc)
        acc_record['val']['iter'].append(acc_iter)

        print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc * 100))

        # ===================save====================
        if val_acc > best_acc:
            best_acc = val_acc
            paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
            paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

    print(best_acc)
    paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
    paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))
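The learning-rate schedule used in the training cell can be previewed without running the training. The sketch below uses the standard closed-form cosine-annealing expression with a minimum rate of 0. It is assumed here that `CosineAnnealingDecay` follows this curve when stepped once per batch; the function `cosine_lr` is a hypothetical stand-in, not a Paddle API.

```python
import math

batch_size, n_epochs, base_lr = 128, 100, 0.001
# With drop_last=True each epoch has 50000 // 128 = 390 batches
t_max = 50000 // batch_size * n_epochs


def cosine_lr(step: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr down to eta_min (assumed form)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


for step in (0, t_max // 4, t_max // 2, t_max):
    print(step, round(cosine_lr(step), 6))
```

Since the training loader drops the last incomplete batch, the number of scheduler steps over 100 epochs equals `T_max`, so the rate reaches its minimum at the end of the final epoch.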
2.5 Result analysis
In [21]:
    def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
        ''' Plot learning curve of your CNN '''
        # y-axis limits with a 10% margin around the recorded values
        maxtrain = max(map(float, record['train'][title]))
        maxval = max(map(float, record['val'][title]))
        ymax = max(maxtrain, maxval) * 1.1
        mintrain = min(map(float, record['train'][title]))
        minval = min(map(float, record['val'][title]))
        ymin = min(mintrain, minval) * 0.9

        total_steps = len(record['train'][title])
        x_1 = list(map(int, record['train']['iter']))
        x_2 = list(map(int, record['val']['iter']))
        figure(figsize=(10, 6))
        plt.plot(x_1, record['train'][title], c='tab:red', label='train')
        plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
        plt.ylim(ymin, ymax)
        plt.xlabel('Training steps')
        plt.ylabel(ylabel)
        plt.title('Learning curve of {}'.format(title))
        plt.legend()
        plt.show()
2.5.1 Loss and accuracy curves
In [22]:
    plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

Output:
    <Figure size 1000x600 with 1 Axes>

In [23]:
    plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

Output:
    <Figure size 1000x600 with 1 Axes>

In [24]:
    import time

    work_path = 'work/model'
    model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
    model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
    model.set_state_dict(model_state_dict)
    model.eval()
    aa = time.time()
    for batch_id, data in enumerate(val_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
            logits = model(x_data)
    bb = time.time()
    # Images processed per second over the whole validation set (data loading included)
    print("Throughput:{}".format(int(len(val_dataset) // (bb - aa))))

Output:
    Throughput:794
2.5.2 Comparing predictions with ground-truth labels
In [25]:
    def get_cifar10_labels(labels):
        """Return the text labels of the CIFAR10 dataset."""
        text_labels = [
            'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
            'horse', 'ship', 'truck']
        return [text_labels[int(i)] for i in labels]

In [26]:
    def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
        """Plot a list of images."""
        figsize = (num_cols * scale, num_rows * scale)
        _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
        axes = axes.flatten()
        for i, (ax, img) in enumerate(zip(axes, imgs)):
            if paddle.is_tensor(img):
                ax.imshow(img.numpy())
            else:
                ax.imshow(img)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            if pred or gt:
                ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
        return axes

In [27]:
    work_path = 'work/model'
    X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
    model = ShiftViT(n_div=12, embed_dim=96, depths=(3, 4, 9, 3), num_classes=10)
    model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
    model.set_state_dict(model_state_dict)
    model.eval()
    logits = model(X)
    y_pred = paddle.argmax(logits, -1)
    X = paddle.transpose(X, [0, 2, 3, 1])  # (B, C, H, W) -> (B, H, W, C) for imshow
    axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
    plt.show()

Output (the warning below is printed once per image, 18 times in total):
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
    <Figure size 2700x150 with 18 Axes>
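The clipping warnings appear because the images passed to `imshow` are still normalized with the ImageNet mean and standard deviation, so many values fall outside [0, 1]. One way to avoid this is to undo the normalization before plotting. The NumPy sketch below is an assumption-labeled illustration: `denormalize` is a made-up helper, and it expects a channels-last `(H, W, 3)` array such as `X[i].numpy()` after the transpose.

```python
import numpy as np

# Same statistics as transforms.Normalize in the dataset section
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def denormalize(img_hwc: np.ndarray) -> np.ndarray:
    """Map a normalized (H, W, 3) image back to [0, 1] (hypothetical helper)."""
    return np.clip(img_hwc * STD + MEAN, 0.0, 1.0)


# Round trip on a random stand-in image (not real CIFAR10 data)
rng = np.random.default_rng(0)
original = rng.random((4, 4, 3), dtype=np.float32)
normalized = (original - MEAN) / STD
restored = denormalize(normalized)
```

Passing `denormalize(...)` output to `ax.imshow` should show the images with their original colors and without the warning.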
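3. Comparison experiment
Note: the Swin code comes from the article 浅析 Swin Transformer ("A Brief Analysis of Swin Transformer"); the experimental results are in main-Copy2.ipynb.
Summary
This article uses a simple Shift operation to show that attention in the Vision Transformer is not indispensable. With a comparable number of parameters, ShiftViT's accuracy is 0.06794 higher than Swin's. On small datasets such as CIFAR10, ShiftViT has a clear advantage over Swin; on large datasets the two perform about the same.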