1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| import torch import torch.nn as nn
class FusionNetwork(nn.Module):
    """Fuse three scalar input signals into a single score in [0, 1].

    Each input is embedded by its own linear encoder; the three embeddings
    are concatenated along the feature (last) axis, passed through a small
    MLP with dropout, and mapped to one sigmoid-activated output.

    NOTE(review): the parameter names suggest eye-closure (PERCLOS), mouth
    aspect ratio (MAR) and head-nodding features for fatigue estimation, but
    the expected value ranges are not visible here -- confirm with the caller.

    Args:
        hidden_dim: Width of each per-input embedding and of the fused
            feature vector fed to the classifier.
        fusion_dim: Width of the first fusion layer.
        dropout: Dropout probability applied after the first fusion layer.
    """

    def __init__(
        self,
        *,
        hidden_dim: int = 32,
        fusion_dim: int = 64,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        # Attribute names and creation order are kept unchanged so that, with
        # the default arguments, state_dict keys and parameter shapes match
        # the previous hard-coded layout.
        self.perclos_encoder = nn.Linear(1, hidden_dim)
        self.mar_encoder = nn.Linear(1, hidden_dim)
        self.head_encoder = nn.Linear(1, hidden_dim)

        self.fusion = nn.Sequential(
            nn.Linear(3 * hidden_dim, fusion_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(fusion_dim, hidden_dim),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        perclos: torch.Tensor,
        mar: torch.Tensor,
        head_nodding: torch.Tensor,
    ) -> torch.Tensor:
        """Run the forward pass.

        Args:
            perclos: Tensor whose last dimension has size 1.
            mar: Tensor with the same shape as ``perclos``.
            head_nodding: Tensor with the same shape as ``perclos``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of size 1, holding sigmoid outputs in [0, 1].
        """
        f_perclos = self.perclos_encoder(perclos)
        f_mar = self.mar_encoder(mar)
        f_head = self.head_encoder(head_nodding)

        # Concatenate on the last axis instead of axis 1: identical for the
        # usual (batch, 1) inputs, but also correct for unbatched (1,) and
        # (batch, seq, 1) inputs, which nn.Linear already accepts.
        fused = torch.cat([f_perclos, f_mar, f_head], dim=-1)
        features = self.fusion(fused)

        fatigue_level = torch.sigmoid(self.classifier(features))
        return fatigue_level
|