discuss.pytorch.org/t/concatenate-layer-output-with-additional-input-data/20462
사실 간단하지만 궁금했던부분이다.
torch 에서는 이런식으로 layer 를 추가해주고 forward 부분에
x = torch.cat((x1, x2), dim=1)
이런식으로 하는게 특징이다.
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.cnn = models.inception_v3(pretrained=False, aux_logits=False)
self.cnn.fc = nn.Linear(
self.cnn.fc.in_features, 20)
self.fc1 = nn.Linear(20 + 10, 60)
self.fc2 = nn.Linear(60, 5)
def forward(self, image, data):
x1 = self.cnn(image)
x2 = data
x = torch.cat((x1, x2), dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
'Machine.Learning > torch' 카테고리의 다른 글
pytorch hook 걸기 (0) | 2023.04.10 |
---|