Vision Transformer Lab: CIFAR-10 Training and Swin Analysis


1. Experiment Content

2. Experiment Process

2.1 Video Learning

Notes are as follows:

2.1.1 Vision Transformer (ViT)

https://blog.csdn.net/2401_89740151/article/details/154704991

2.1.2 Swin Transformer

https://blog.csdn.net/2401_89740151/article/details/154705370

2.1.3 Fundamentals Memo

https://blog.csdn.net/2401_89740151/category_13063550.html

2.2 Implementing ViT in PyTorch

The demo uses CIFAR-10 (32×32) with a small ViT, so training converges more easily.

0) Environment check and setup

# If this is a fresh Colab session, it is best to restart the runtime once before running this cell

import os, random, math, time, sys, platform
import numpy as np
import torch

# 1) Print device information
device = "cuda" if torch.cuda.is_available() else "cpu"
print("PyTorch:", torch.__version__)
print("Device:", device)
if device == "cuda":
    print("CUDA name:", torch.cuda.get_device_name(0))
    print("CUDA capability:", torch.cuda.get_device_capability(0))

# 2) Install/import torchvision (Colab usually has it; install only if missing)
try:
    import torchvision
except ImportError:
    !pip -q install torchvision

# 3) Set random seeds (for reproducibility)
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # for multi-GPU / multi-process runs
    # Stricter reproducibility (slightly slower; comment out the two lines below if not needed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

1) Configuration parameters

Note: the commonly tuned hyperparameters are grouped in one place so the scale of the experiment is easy to adjust.

from dataclasses import dataclass

@dataclass
class CFG:
    # Dataset
    dataset: str = "cifar10"
    img_size: int = 32        # CIFAR-10 is natively 32x32
    patch_size: int = 4       # each patch = 4x4 => (32/4)^2 = 64 patches
    num_classes: int = 10

    # Model architecture (small, so it trains comfortably on Colab)
    embed_dim: int = 256
    depth: int = 6
    num_heads: int = 8
    mlp_ratio: float = 4.0

    # Optimizer / training
    epochs: int = 8
    batch_size: int = 128
    lr: float = 3e-4
    weight_decay: float = 0.05
    warmup_epochs: int = 1        # simple warmup
    mixup_alpha: float = 0.0      # mixup off for now; set 0.2 if needed
    label_smoothing: float = 0.0  # label smoothing off for now; set 0.1 if needed

    # Dropout
    attn_drop: float = 0.0
    proj_drop: float = 0.0
    mlp_drop: float = 0.0
    drop_embed: float = 0.0

    # Training details
    num_workers: int = 2
    amp: bool = True       # mixed precision (recommended True on a Colab GPU)
    print_every: int = 100

cfg = CFG()
cfg

2) Dataset and data augmentation (CIFAR-10)

The training set uses light augmentation (random crop / horizontal flip).

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Training augmentation: light augmentation is enough; a small model trains more stably
train_tf = transforms.Compose([
    transforms.RandomCrop(cfg.img_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Test set: ToTensor only
test_tf = transforms.Compose([
    transforms.ToTensor(),
])

if cfg.dataset.lower() == "cifar10":
    train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
    test_set = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)
else:
    raise NotImplementedError("This demo only covers CIFAR-10; it can be extended to other datasets.")

train_loader = DataLoader(train_set, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=cfg.batch_size, shuffle=False,
                         num_workers=cfg.num_workers, pin_memory=True)

len(train_set), len(test_set)
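
As a quick optional sanity check (an addition here, not part of the original notebook), one can inspect a single batch from the loaders; the shapes below follow from the CFG values above.

# Optional: peek at one batch to confirm shapes and value range
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)                 # expected: [128, 3, 32, 32] and [128]
print(xb.min().item(), xb.max().item())   # ToTensor keeps pixel values in [0, 1]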

3) ViT components: PatchEmbed / Attention / MLP / Block

Implemented following the standard Transformer encoder.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbed(nn.Module):
    """
    Patch splitting + linear projection implemented with a Conv2d:
    - input:  [B, C, H, W]
    - output: [B, N, D], where N = (H/ps)*(W/ps), D = embed_dim
    """
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        x = self.proj(x)       # [B, D, H/ps, W/ps]
        x = x.flatten(2)       # [B, D, N]
        x = x.transpose(1, 2)  # [B, N, D]
        return x


class MLP(nn.Module):
    """
    Feed-forward network (FFN/MLP):
    - structure: Linear -> GELU -> Dropout -> Linear -> Dropout
    - widths: D -> 4D -> D (mlp_ratio is configurable)
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    """
    Multi-head self-attention:
    - input/output: [B, N, D]
    - steps: linear map to qkv -> split heads -> scaled dot-product attention -> merge heads -> output projection
    """
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x)  # [B, N, 3D]
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        out = attn @ v  # [B, heads, N, head_dim]
        out = out.transpose(1, 2).contiguous().reshape(B, N, D)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out


class TransformerEncoderBlock(nn.Module):
    """
    Standard Transformer encoder block (Pre-LN):
        x = x + MSA(LN(x))
        x = x + MLP(LN(x))
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True,
                 attn_drop=0.0, proj_drop=0.0, mlp_drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, proj_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, mlp_drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

4) ViT backbone: positional encoding / [CLS] / stacked blocks / classification head

This implements the ViT backbone. The positional encoding and the [CLS] token are the key pieces; the [CLS] output is used as the whole-image feature.

class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4.0,
                 qkv_bias=True, attn_drop=0.0, proj_drop=0.0,
                 mlp_drop=0.0, drop_embed=0.0):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token + learnable positional encoding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        self.pos_drop = nn.Dropout(drop_embed)

        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop, mlp_drop=mlp_drop
            ) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Linear layers: Xavier init is stable; bias set to 0
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        # LayerNorm: weight = 1, bias = 0
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        # pos_embed / cls_token are created with torch.zeros above and left as-is here;
        # many ViT implementations additionally apply a truncated-normal init to them

    def forward(self, x):
        B = x.size(0)
        x = self.patch_embed(x)                  # [B, N, D]
        cls = self.cls_token.expand(B, -1, -1)   # [B, 1, D]
        x = torch.cat([cls, x], dim=1)           # [B, 1+N, D]

        x = x + self.pos_embed[:, : x.size(1), :]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        cls_feat = x[:, 0]  # take [CLS]
        logits = self.head(cls_feat)
        return logits
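
Before wiring up the training loop, a short forward-pass check (an addition here, not required for training) confirms the output shape and rough parameter count of this small ViT under the CFG above:

# Optional: dummy forward pass and parameter count for the configured ViT
_vit = ViT(img_size=cfg.img_size, patch_size=cfg.patch_size, in_chans=3,
           num_classes=cfg.num_classes, embed_dim=cfg.embed_dim,
           depth=cfg.depth, num_heads=cfg.num_heads, mlp_ratio=cfg.mlp_ratio)
dummy = torch.randn(2, 3, cfg.img_size, cfg.img_size)
print(_vit(dummy).shape)                                            # torch.Size([2, 10])
print(f"{sum(p.numel() for p in _vit.parameters()) / 1e6:.2f}M parameters")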

5) Training setup: loss / optimizer / scheduler / evaluation function

Uses AdamW with cosine annealing, optional mixed precision (AMP), and provides an evaluate() helper.

import torch.optim as optim

# Label smoothing (optional)
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        """
        pred:   [B, C] logits
        target: [B]
        """
        log_probs = F.log_softmax(pred, dim=-1)
        n_classes = pred.size(-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=-1))


def build_model_and_opt(cfg):
    model = ViT(
        img_size=cfg.img_size, patch_size=cfg.patch_size, in_chans=3,
        num_classes=cfg.num_classes, embed_dim=cfg.embed_dim, depth=cfg.depth,
        num_heads=cfg.num_heads, mlp_ratio=cfg.mlp_ratio, qkv_bias=True,
        attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, mlp_drop=cfg.mlp_drop,
        drop_embed=cfg.drop_embed
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # Warmup + cosine schedule
    def lr_lambda(current_epoch):
        if current_epoch < cfg.warmup_epochs:
            return float(current_epoch + 1) / float(max(1, cfg.warmup_epochs))
        # cosine decay from 1 -> 0
        progress = (current_epoch - cfg.warmup_epochs) / max(1, cfg.epochs - cfg.warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    # Loss: optional label smoothing
    if cfg.label_smoothing > 0.0:
        criterion = LabelSmoothingCrossEntropy(cfg.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

    return model, optimizer, scheduler, criterion


@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    tot, correct, loss_sum = 0, 0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        loss_sum += loss.item() * x.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == y).sum().item()
        tot += x.size(0)
    return loss_sum / tot, correct / tot

6) Training loop

The results are shown in Figure 1.

from pathlib import Path
save_dir = Path("./checkpoints")
save_dir.mkdir(parents=True, exist_ok=True)
ckpt_path = save_dir / "vit_cifar10_small.pth"

model, optimizer, scheduler, criterion = build_model_and_opt(cfg)

scaler = torch.cuda.amp.GradScaler(enabled=(cfg.amp and device == "cuda"))

best_acc = 0.0
for epoch in range(1, cfg.epochs + 1):
    model.train()
    running_loss, seen = 0.0, 0

    start = time.time()
    for it, (x, y) in enumerate(train_loader, 1):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            with torch.cuda.amp.autocast():
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

        running_loss += loss.item() * x.size(0)
        seen += x.size(0)

        if it % cfg.print_every == 0:
            print(f"[Ep {epoch}/{cfg.epochs}] it {it:04d} loss={running_loss/seen:.4f}")

    # Validation
    val_loss, val_acc = evaluate(model, test_loader, criterion)
    scheduler.step()
    dt = time.time() - start
    print(f"Epoch {epoch:02d} | train_loss={running_loss/seen:.4f} "
          f"| val_loss={val_loss:.4f} | val_acc={val_acc*100:.2f}% | time={dt:.1f}s")

    # Save the best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({"model": model.state_dict(),
                    "cfg": cfg.__dict__,
                    "acc": best_acc}, ckpt_path)
        print(f" ↳ Saved best ckpt to: {ckpt_path} (acc={best_acc*100:.2f}%)")

7) Loading weights and quick inference

Load the saved weights and run predictions on a few images. The results are shown in Figure 2.

import torchvision
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# CIFAR-10 class names
CIFAR10_CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

# Load the best checkpoint
if ckpt_path.exists():
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    print(f"Loaded ckpt acc={ckpt.get('acc', 0)*100:.2f}%")
else:
    print("Warn: ckpt not found, using current (random) weights.")

# Take the first 16 test images for visualization & prediction
model.eval()
images, labels = next(iter(test_loader))
images, labels = images[:16].to(device), labels[:16].to(device)

with torch.no_grad():
    logits = model(images)
    probs = logits.softmax(dim=1)
    preds = probs.argmax(dim=1)

# Visualization
grid = make_grid(images.cpu(), nrow=8, padding=2)
plt.figure(figsize=(12, 4))
plt.axis("off")
plt.title("CIFAR-10 Samples")
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
plt.show()

# Print predictions
print("Predictions:")
for i in range(images.size(0)):
    print(f" #{i:02d} pred={CIFAR10_CLASSES[preds[i].item()]:>10s} "
          f"(p={probs[i, preds[i]].item():.2f}) | "
          f"true={CIFAR10_CLASSES[labels[i].item()]}")
Figure 1: training-loop output    Figure 2: inference results

2.3 Answers to Related Questions

2.3.1 In ViT, what methods can reduce the computational cost of attention? (Hint: Swin's window attention, PVT's attention)

A. Restrict the attention region / partition the tokens

  • Window attention: compute attention only inside fixed-size windows, so the cost grows roughly linearly with resolution.
  • Strip / cross-shaped / axial attention: attend only along rows/columns or within a cross-shaped local region.

B. Down-sample K/V or shorten the sequence

  • Spatial-Reduction Attention (PVT): down-sample the K/V features with a strided projection while keeping Q at full length, reducing the complexity from N² to roughly N·(N/r²) (a minimal sketch follows below).
  • Token pooling / patch merging / strided convolutions: progressively reduce the number of tokens between stages.
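
The following is a minimal sketch of the K/V down-sampling idea in the style of PVT's spatial-reduction attention; the class name, the sr_ratio argument, and the exact layer layout are illustrative assumptions, not the official PVT code.

import torch
import torch.nn as nn

class SpatialReductionAttention(nn.Module):
    """Attention whose K/V come from a spatially down-sampled feature map.
    Complexity drops from O(N^2) to roughly O(N * N / sr_ratio^2)."""
    def __init__(self, dim, num_heads=8, sr_ratio=2):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        # strided conv shrinks the token grid before K/V are formed
        self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, D = x.shape                                   # N = H * W
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        x_ = x.transpose(1, 2).reshape(B, D, H, W)          # tokens -> feature map
        x_ = self.sr(x_).reshape(B, D, -1).transpose(1, 2)  # down-sampled tokens
        x_ = self.norm(x_)
        kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]                                 # [B, heads, N', head_dim]

        attn = (q @ k.transpose(-2, -1)) * self.scale       # [B, heads, N, N']
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, D)
        return self.proj(out)

sra = SpatialReductionAttention(dim=64, num_heads=4, sr_ratio=2)
tokens = torch.randn(2, 16 * 16, 64)          # a 16x16 token grid
print(sra(tokens, H=16, W=16).shape)          # torch.Size([2, 256, 64])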

2.3.2 What idea does Swin embody, and what has it inspired in later work? (Hint: local first, then global)

Swin's core idea matches how people actually look at images: local first, then global. It starts by computing attention inside small windows to capture local detail, much like staring at one region of a picture. In the next layer it shifts the window grid, so the shifted windows straddle the boundaries of the previous layer's windows, and the model can gradually stitch local information into a global understanding.

This local-to-global hierarchical structure is very effective. It gives the Transformer CNN-like multi-scale features (high resolution for detail, low resolution for the whole image) while keeping the computation well under control.

Because of this design, many later models followed the same path, e.g. CSWin, Twins, and Swin V2, all building local representations efficiently first and then gradually expanding to the global view. The broader lesson from Swin is that a Transformer does not have to attend globally at every layer; integrating information in stages can be more efficient and more stable.
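
A minimal sketch of the two mechanisms described above, window partitioning and the cyclic shift that lets neighbouring windows exchange information across layers (the attention mask for shifted windows is omitted); this illustrates the idea rather than reproducing the official Swin code.

import torch

def window_partition(x, ws):
    """Split a [B, H, W, C] feature map into non-overlapping ws x ws windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)  # [B*num_windows, ws*ws, C]

# Attention is then computed only inside each window (quadratic in ws*ws, linear in H*W)
feat = torch.randn(2, 8, 8, 96)               # toy feature map: B=2, H=W=8, C=96
windows = window_partition(feat, ws=4)        # [8, 16, 96]

# In the next layer, Swin cyclically shifts the map by ws//2 before partitioning,
# so tokens near a window border end up in the same window as their neighbours.
shifted = torch.roll(feat, shifts=(-2, -2), dims=(1, 2))
shifted_windows = window_partition(shifted, ws=4)
print(windows.shape, shifted_windows.shape)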


2.3.3 Some networks combine CNNs and Transformers. Why are the CNN blocks usually placed in front and the Transformer blocks behind?

CNNs excel at low-level local features and Transformers excel at high-level global semantics: the former lays the foundation, the latter builds on top of it. Convolutions are particularly good at capturing local texture such as edges, corners, and colour changes, while a Transformer is good at global relations, e.g. which regions belong to the same object, or what the overall semantics of the image are.

So structurally, most designs use a CNN front end for feature extraction, preparing the ground for the Transformer, which then models the global information.

There is also a very practical reason: at the input stage the resolution is large and the number of patches is huge, so applying a Transformer directly would blow up the memory. A CNN can first down-sample and shrink the feature map, lightening the load on the Transformer behind it. A minimal sketch of this layout is given below.
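
The sketch below shows this CNN-front, Transformer-back layout, reusing the TransformerEncoderBlock from section 3; the stem widths and depth are illustrative assumptions, not taken from a specific paper.

import torch
import torch.nn as nn

class ConvStemViT(nn.Module):
    """Hybrid layout: a small CNN stem down-samples the image, then Transformer
    blocks model global relations on the much shorter token sequence."""
    def __init__(self, embed_dim=256, depth=4, num_heads=8, num_classes=10):
        super().__init__()
        # CNN front end: local features + 4x spatial down-sampling (32x32 -> 8x8)
        self.stem = nn.Sequential(
            nn.Conv2d(3, embed_dim // 2, 3, stride=2, padding=1), nn.BatchNorm2d(embed_dim // 2), nn.ReLU(),
            nn.Conv2d(embed_dim // 2, embed_dim, 3, stride=2, padding=1), nn.BatchNorm2d(embed_dim), nn.ReLU(),
        )
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.stem(x)                       # [B, D, H/4, W/4]
        x = x.flatten(2).transpose(1, 2)       # [B, N, D], N = (H/4)*(W/4)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x).mean(dim=1)           # mean-pool tokens instead of [CLS]
        return self.head(x)

hybrid = ConvStemViT()
print(hybrid(torch.randn(2, 3, 32, 32)).shape)   # torch.Size([2, 10])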


2.3.4 Read up on Restormer and consider: the basic Transformer structure is attention + FFN; what improvements does this work make to each part?

Restormer is a Transformer designed specifically for image restoration (e.g. denoising, deblurring, super-resolution). It reworks both parts of the basic block, attention and FFN, to make them better suited to images.

For attention it uses MDTA (Multi-DConv Head Transposed Attention). In short, depth-wise convolutions are applied around the attention computation, so the model captures local texture as well as global relations. In addition, attention is computed not over spatial positions (HW×HW) but over the channel dimension (C×C), which keeps the cost much lower for high-resolution images.

For the FFN it uses GDFN (Gated-DConv Feed-Forward Network), which inserts a gating mechanism and depth-wise convolutions between the two linear layers. The gate lets the network decide which information to keep and which to suppress, and the convolutions give the FFN local awareness, which suits image data. The overall architecture is a U-Net-style Transformer with an encoder, a decoder, and skip connections, which makes it well suited to restoration tasks. A simplified sketch of the channel-wise attention idea follows.
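
A minimal sketch of the transposed (channel-wise) attention idea described above, i.e. a simplified stand-in for Restormer's MDTA rather than the original module; the class name and layer choices are assumptions for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    """Attention computed across channels (C x C) instead of spatial positions (HW x HW),
    with depth-wise convs so Q/K/V still see local texture."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1)
        self.qkv_dw = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):                      # x: [B, C, H, W]
        B, C, H, W = x.shape
        q, k, v = self.qkv_dw(self.qkv(x)).chunk(3, dim=1)
        # flatten space and split heads: [B, heads, C/heads, H*W]
        q = q.reshape(B, self.num_heads, C // self.num_heads, H * W)
        k = k.reshape(B, self.num_heads, C // self.num_heads, H * W)
        v = v.reshape(B, self.num_heads, C // self.num_heads, H * W)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature   # [B, heads, C/heads, C/heads]
        out = attn.softmax(dim=-1) @ v                         # [B, heads, C/heads, H*W]
        return self.proj(out.reshape(B, C, H, W))

ca = ChannelAttention(dim=48, num_heads=4)
print(ca(torch.randn(2, 48, 64, 64)).shape)    # torch.Size([2, 48, 64, 64])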

3. Problems and Reflections

This exercise deepened my understanding of how Transformers are applied to vision. At first I only knew that ViT cuts an image into patches and runs attention over them; I had not realized how expensive that is, nor how many directions for improvement exist. Comparing Swin, PVT, and Restormer, I found that they are all tackling the same problem: making Transformers understand images better and handle high-resolution data more efficiently.

I find Swin's idea particularly inspiring. Instead of modelling globally from the start, it begins with local windows and gradually extends to the whole image. This "local first, then global" design resembles the way people look at pictures, and it shows that many effective ideas in deep learning come from common sense.

