GazeCapsNet详解:轻量级胶囊网络实现实时视线估计

引言:为什么需要轻量级视线估计

视线估计在车载场景的应用面临三个核心挑战:

挑战 传统方法问题 GazeCapsNet解决方案
算力限制 模型参数量大(>50M) 仅11.7M参数
实时性 推理延迟高(>50ms) 20ms延迟
鲁棒性 头位/光照变化敏感 Capsule保留空间关系

GazeCapsNet(2025年发表于Sensors)通过胶囊网络 + Self-Attention Routing实现了精度与效率的平衡。


一、胶囊网络基础

1.1 传统CNN的局限性

问题:CNN的池化操作丢失空间信息

1
2
3
4
传统CNN:
输入图像 → Conv → Pooling → Conv → Pooling → FC → 输出

空间信息丢失(位置不变性)

对视线估计的影响

  • 眼睛的位置信息丢失(上下左右)
  • 只能依赖特征组合,无法建模空间关系
  • 头位变化时精度下降明显

1.2 胶囊网络的核心思想

Capsule = 一组神经元

  • 输出是一个向量,而非标量
  • 向量的长度表示特征存在的概率
  • 向量的方向表示特征的姿态参数
1
2
传统神经元:y = f(Σwx + b)        → 标量
胶囊神经元:v = squash(Σc·û) → 向量

Squash函数

1
2
3
4
5
6
def squash(s):
"""
将向量压缩到[0,1]范围,长度表示概率
"""
norm = torch.norm(s, dim=-1, keepdim=True)
return (norm ** 2 / (1 + norm ** 2)) * (s / norm)

1.3 胶囊网络的优势

特性 CNN Capsule Network
输出 标量(激活值) 向量(姿态+概率)
空间关系 隐式学习 显式建模
视角变化 需要数据增强 等变表示
参数量 少(动态路由)

二、Self-Attention Routing(SAR)

2.1 传统路由的问题

动态路由(Dynamic Routing)

1
2
3
4
5
迭代过程(通常3次):
1. 预测:û_j|i = W_ij · u_i
2. 耦合:c_ij = softmax(b_ij)
3. 聚合:s_j = Σ c_ij · û_j|i
4. 更新:b_ij += û_j|i · v_j

问题

  • 迭代计算耗时长
  • 超参数(迭代次数)敏感
  • 难以并行化

2.2 Self-Attention Routing原理

核心思想:用注意力机制替代迭代路由

1
2
3
4
SAR(单次前向):
1. 预测:û_j|i = W_ij · u_i
2. 注意力:A_ij = softmax(û_j|i · û_j|i^T)
3. 聚合:v_j = squash(Σ A_ij · û_j|i)

数学表达

$$A_{ij} = \frac{\exp(\hat{u}{j|i} \cdot \hat{u}{j|i}^T)}{\sum_k \exp(\hat{u}{k|i} \cdot \hat{u}{k|i}^T)}$$

$$v_j = \text{squash}\left(\sum_i A_{ij} \cdot \hat{u}_{j|i}\right)$$

2.3 SAR代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionRouting(nn.Module):
def __init__(self, in_caps, out_caps, in_dim, out_dim):
super().__init__()
self.in_caps = in_caps
self.out_caps = out_caps

# 变换矩阵
self.W = nn.Parameter(torch.randn(in_caps, out_caps, in_dim, out_dim))

def forward(self, u):
"""
u: [batch, in_caps, in_dim] - 输入胶囊
返回: [batch, out_caps, out_dim] - 输出胶囊
"""
batch_size = u.size(0)

# 1. 预测变换
# u: [batch, in_caps, in_dim]
# W: [in_caps, out_caps, in_dim, out_dim]
u_hat = torch.einsum('bci,ijdo->bcjo', u, self.W)
# u_hat: [batch, in_caps, out_caps, out_dim]

# 2. 计算注意力权重
# 计算预测向量之间的相似度
u_hat_flat = u_hat.view(batch_size, self.in_caps, self.out_caps, -1)
attention = F.softmax(
torch.sum(u_hat_flat ** 2, dim=-1), # [batch, in_caps, out_caps]
dim=1
)

# 3. 加权聚合
s = torch.einsum('bco,bcod->bod', attention, u_hat)

# 4. Squash激活
v = self.squash(s)

return v

@staticmethod
def squash(s):
norm = torch.norm(s, dim=-1, keepdim=True)
return (norm ** 2 / (1 + norm ** 2)) * (s / (norm + 1e-8))

2.4 SAR vs 动态路由对比

维度 动态路由 SAR
计算次数 迭代3次 单次前向
延迟 +15ms 0ms
并行性
精度 基准 相当
可训练性

三、GazeCapsNet架构详解

3.1 整体架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
输入图像 (224×224)

┌─────────────────────────────────┐
│ 人脸检测 (SCRFD) │
│ - 检测人脸区域 │
│ - 裁剪并缩放到224×224
└─────────────────────────────────┘

┌─────────────────────────────────┐
│ 双流特征提取 │
│ ├── MobileNet v2 → 低级特征 │
│ └── ResNet-18 → 高级特征 │
│ 拼接 → combined_features │
└─────────────────────────────────┘

┌─────────────────────────────────┐
│ Primary Capsules │
│ - 将特征图转换为初级胶囊 │
│ - 每个胶囊8维向量 │
└─────────────────────────────────┘

┌─────────────────────────────────┐
│ Self-Attention Routing │
│ - 注意力加权的胶囊路由 │
│ - 生成Gaze Capsules (16维) │
└─────────────────────────────────┘

┌─────────────────────────────────┐
│ Gaze Regression Head │
│ - 输出3D视线向量 (pitch, yaw, roll)│
└─────────────────────────────────┘

3.2 双流特征提取

为什么用MobileNet v2 + ResNet-18?

网络 特点 贡献
MobileNet v2 轻量、快速 低级纹理特征
ResNet-18 残差连接 高级语义特征
融合 互补 平衡精度与速度

实现代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torchvision.models as models

class DualStreamExtractor(nn.Module):
def __init__(self):
super().__init__()
# 预训练骨干网络
mobilenet = models.mobilenet_v2(pretrained=True)
resnet = models.resnet18(pretrained=True)

# 移除分类头
self.mobile_features = mobilenet.features
self.resnet_features = nn.Sequential(*list(resnet.children())[:-2])

# 特征融合
self.fusion = nn.Conv2d(1280 + 512, 512, kernel_size=1)

def forward(self, x):
# 双流提取
f1 = self.mobile_features(x) # [B, 1280, 7, 7]
f2 = self.resnet_features(x) # [B, 512, 7, 7]

# 拼接融合
f = torch.cat([f1, f2], dim=1) # [B, 1792, 7, 7]
f = self.fusion(f) # [B, 512, 7, 7]

return f

3.3 Primary Capsules

从特征图到胶囊

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class PrimaryCapsules(nn.Module):
def __init__(self, in_channels=512, out_channels=32, dim=8):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels * dim,
kernel_size=3, stride=2, padding=1)
self.out_channels = out_channels
self.dim = dim

def forward(self, x):
# x: [B, 512, 7, 7]
out = self.conv(x) # [B, 32*8, 4, 4]

# 重塑为胶囊形式
batch_size = out.size(0)
out = out.view(batch_size, self.out_channels, -1, self.dim)
# [B, 32, 16, 8] - 32种胶囊,每种16个位置,每个8维

# Squash激活
out = self.squash(out)

return out.view(batch_size, -1, self.dim) # [B, 512, 8]

@staticmethod
def squash(s):
norm = torch.norm(s, dim=-1, keepdim=True)
return (norm ** 2 / (1 + norm ** 2)) * (s / (norm + 1e-8))

3.4 Gaze Regression Head

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class GazeRegression(nn.Module):
def __init__(self, in_dim=16):
super().__init__()
self.fc = nn.Linear(in_dim, 3) # 3D gaze vector

def forward(self, v):
"""
v: [B, out_caps, dim] - Gaze capsules
返回: [B, 3] - 3D gaze vector (pitch, yaw, roll)
"""
# 使用最长胶囊作为代表
norms = torch.norm(v, dim=-1) # [B, out_caps]
idx = torch.argmax(norms, dim=1) # [B]

# 提取代表胶囊
batch_size = v.size(0)
gaze_capsule = v[torch.arange(batch_size), idx] # [B, dim]

# 回归3D向量
gaze = self.fc(gaze_capsule) # [B, 3]
gaze = F.normalize(gaze, dim=-1) # 单位向量

return gaze

四、损失函数设计

4.1 Angular Loss

问题:MSE损失对角度误差不敏感

解决方案:直接优化角误差

$$L_{angular} = \arccos\left(\frac{g_{pred} \cdot g_{true}}{|g_{pred}| \cdot |g_{true}|}\right)$$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def angular_loss(g_pred, g_true):
"""
g_pred: [B, 3] - 预测的3D视线向量
g_true: [B, 3] - 真实的3D视线向量
返回: 角误差(度)
"""
# 归一化
g_pred = F.normalize(g_pred, dim=-1)
g_true = F.normalize(g_true, dim=-1)

# 计算余弦相似度
cos_sim = torch.sum(g_pred * g_true, dim=-1)
cos_sim = torch.clamp(cos_sim, -1.0, 1.0) # 数值稳定性

# 转换为角度
angle = torch.acos(cos_sim) * 180 / torch.pi

return angle.mean()

4.2 多任务损失

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def total_loss(gaze_pred, gaze_true, zone_pred, zone_true):
# 角度损失
L_angular = angular_loss(gaze_pred, gaze_true)

# 分类损失(视线区域)
L_class = F.cross_entropy(zone_pred, zone_true)

# 路由损失(胶囊稀疏性)
L_routing = torch.mean(torch.norm(capsule_output, dim=-1))

# 加权组合
loss = L_angular + 0.1 * L_class + 0.01 * L_routing

return loss

五、实验结果与分析

5.1 数据集性能

Gaze360数据集

方法 MAE(角误差) 参数量 推理时间
FullFace 6.53° 196.6M 50ms
RT-GENE 6.02° 45.2M 40ms
GazeTR-Pure 5.33° 23.1M 45ms
GazeCaps 5.10° 14.2M 25ms
GazeCapsNet 5.10° 11.7M 20ms

MPIIFaceGaze数据集

方法 MAE
FullFace 4.95°
Dilated-Net 4.78°
GazeTR-Hybrid 4.06°
GazeCapsNet 4.06°

5.2 消融实验

组件 MAE 说明
完整模型 5.10° 基准
- ResNet-18 5.68° +0.58°
- MobileNet v2 5.45° +0.35°
- SAR(用动态路由) 5.23° +0.13°,但延迟+15ms
- 胶囊(用FC) 5.89° +0.79°

结论

  • ResNet-18贡献最大:高级特征对精度影响显著
  • SAR贡献适中:主要优势在速度
  • 胶囊机制重要:空间建模能力不可替代

5.3 跨域泛化

训练→测试跨数据集

训练数据 测试数据 MAE
ETH-XGaze Gaze360 6.82°
Gaze360 ETH-XGaze 5.94°
ETH-XGaze MPIIFaceGaze 5.23°
Gaze360 MPIIFaceGaze 4.87°

结论:Gaze360训练的模型泛化能力更强(野外数据多样性高)。


六、嵌入式部署

6.1 模型量化

FP16量化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

# 加载模型
model = GazeCapsNet.load_from_checkpoint('checkpoint.pth')
model.eval()

# FP16量化
model_half = model.half()

# 转换输入
dummy_input = torch.randn(1, 3, 224, 224).half()

# 测试精度
with torch.no_grad():
output = model_half(dummy_input)

INT8量化(需校准)

1
2
3
4
5
6
7
8
9
10
11
12
import torch.quantization as quant

# 准备量化
model.qconfig = quant.get_default_qconfig('qnnpack')
quant.prepare(model, inplace=True)

# 校准(用100张图片)
for image in calibration_images:
model(image)

# 转换为INT8
quant.convert(model, inplace=True)

6.2 TensorRT加速

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch.onnx
import tensorrt as trt

# 导出ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, 'gazecapsnet.onnx')

# TensorRT优化
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open('gazecapsnet.onnx', 'rb') as f:
parser.parse(f.read())

config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1GB
config.set_flag(trt.BuilderFlag.FP16)

engine = builder.build_serialized_network(network, config)

6.3 性能对比

NVIDIA Jetson Orin Nano

精度 延迟 功耗 精度损失
FP32 35ms 8W
FP16 20ms 6W 0.1°
INT8 12ms 5W 0.5°

Qualcomm 8255(QNN)

精度 延迟 功耗
FP32 28ms 4W
FP16 18ms 3W
INT8 10ms 2.5W

七、代码开源

完整代码https://github.com/yakhyo/gaze-estimation

快速使用

1
2
3
4
5
6
7
8
9
10
11
12
from gazecapsnet import GazeCapsNet

# 加载预训练模型
model = GazeCapsNet.from_pretrained('gazecapsnet_ethxgaze.pth')

# 推理
image = cv2.imread('driver.jpg')
gaze_vector = model.predict(image) # [pitch, yaw, roll]

# 视线角度
pitch, yaw, roll = gaze_vector
print(f"视线方向: pitch={pitch:.1f}°, yaw={yaw:.1f}°, roll={roll:.1f}°")

八、总结

8.1 GazeCapsNet核心贡献

  1. Self-Attention Routing:替代迭代路由,延迟降低20%
  2. 双流特征提取:MobileNet v2 + ResNet-18平衡精度与速度
  3. 端到端设计:无需人脸关键点,直接从图像预测
  4. 轻量级:11.7M参数,20ms延迟,适合边缘部署

8.2 适用场景

场景 推荐 说明
车载DMS ✅ 强烈推荐 精度、速度、部署成本平衡
VR/AR ✅ 推荐 实时性好
科研实验 ⚠️ 可选 精度不如专用硬件
移动设备 ✅ 推荐 轻量级

8.3 局限性

  1. 墨镜场景:IR图像瞳孔不可见
  2. 极端头位:>45°时精度下降
  3. 多人种:训练数据以欧美为主,亚洲人精度略低

参考文献

  1. Muksimova, S., et al. “GazeCapsNet: A Lightweight Gaze Estimation Framework.” Sensors, 2025.
  2. Sabour, S., et al. “Dynamic Routing Between Capsules.” NeurIPS, 2017.
  3. Zhang, X., et al. “ETH-XGaze: A Large Scale Dataset.” ECCV, 2020.

本文是IMS视线估计算法系列文章之一,下一篇:GazeTR Transformer详解


GazeCapsNet详解:轻量级胶囊网络实现实时视线估计
https://dapalm.com/2026/03/13/2026-03-13-GazeCapsNet详解-轻量级胶囊网络实现实时视线估计/
作者
Mars
发布于
2026年3月13日
许可协议