1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
import torch
import torch.nn as nn
class VLA_Model(nn.Module):
    """Vision-Language-Action model: fuses vision and language features via
    cross-attention and regresses an action vector.

    Pipeline: each modality is projected to a shared 512-d space by a 2-layer
    MLP, fused with multi-head attention (vision as query, language as
    key/value), then decoded to ``action_dim`` outputs by an MLP head.
    """

    def __init__(self, vision_dim, lang_dim, action_dim):
        """
        Args:
            vision_dim: Size of the flat vision feature vector per element.
            lang_dim: Size of the flat language feature vector per element.
            action_dim: Size of the predicted action vector.
        """
        super().__init__()
        # Per-modality encoders: project raw features to a shared 512-d space.
        self.vision_encoder = nn.Sequential(
            nn.Linear(vision_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
        )
        self.lang_encoder = nn.Sequential(
            nn.Linear(lang_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, 512),
        )
        # NOTE(review): batch_first defaults to False, so 3-D inputs are
        # interpreted as (seq, batch, embed) and 2-D inputs run in unbatched
        # (seq, embed) mode. If callers pass (batch, seq, dim) tensors this
        # silently mixes batch and sequence axes — confirm caller shapes
        # before flipping batch_first (that would change behavior).
        self.fusion = nn.MultiheadAttention(512, num_heads=8)
        # Decode the fused representation into the action vector.
        self.action_head = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, action_dim),
        )

    def forward(self, vision_input, lang_input):
        """Predict actions from paired vision and language inputs.

        Args:
            vision_input: Tensor whose last dim is ``vision_dim``.
            lang_input: Tensor whose last dim is ``lang_dim``; its leading
                dims must be compatible with ``vision_input`` for attention.

        Returns:
            Tensor with the same leading dims as the attention query and a
            final dim of ``action_dim``.
        """
        vision_feat = self.vision_encoder(vision_input)
        lang_feat = self.lang_encoder(lang_input)
        # Cross-attention: vision queries attend over language keys/values.
        # The attention-weights output is discarded.
        fused, _ = self.fusion(vision_feat, lang_feat, lang_feat)
        return self.action_head(fused)
|