여러 모듈을 붙일 때 중간 Layer 에 Hook_fn  등 을 걸면, 중간 layer 에서 

output 을 불러다 사용가능하고,

뭐 explainable ai 의 activation map 을 볼때 등등 다양한 용도로 사용한다.

 

torch 의 모델은 'dict' 형태로 저장되있으나, 안에 모듈안에 서브모듈등으로 구성되어있어서.

모든 layer 를 검색하려면 계층적으로 탐색해야한다. 

 

어쩃든 머 특정 layer 에서 hook 을 걸려면 아래처럼 register_forward_hook 또는 register_backward_hook

self.DPTbackbone._modules['scratch'].output_conv[6].register_forward_hook(self.forward_hook('output_conv'))


등을 걸면되고, 

그안에 훅 funtion 을 넣어주게되면 이 function 이 해당레이어에 data 가 통과될 때, call back 처리되어 거꾸로 호출된다.

거꾸로 호출될때 특징이 input output 이 같이 호출되는 구조이다.

펑션은 아래처럼 만들어주면된다.

  def forward_hook(self,layer_name):
        def hook(module, input, output):
            self.selected_out[layer_name] = torch.squeeze(output,dim=1)
        return hook
    def get_last_hook(self):
        return self.selected_out['output_conv']

ref : https://medium.com/the-owl/using-forward-hooks-to-extract-intermediate-layer-outputs-from-a-pre-trained-model-in-pytorch-1ec17af78712

 

Using forward_hooks to Extract Intermediate Layer Outputs from a Pre-trained ResNet Model in…

Here we are again with the fourth article in the “Feature Extraction” series. This article will describe another method (and possibly the…

medium.com

https://hongl.tistory.com/157

'Machine.Learning > torch' 카테고리의 다른 글

torch - cnn concatenate 하기 feature 합치기  (0) 2021.01.24

discuss.pytorch.org/t/concatenate-layer-output-with-additional-input-data/20462

 

Concatenate layer output with additional input data

I want to build a CNN model that takes additional input data besides the image at a certain layer. To do that, I plan to use a standard CNN model, take one of its last FC layers, concatenate it with the additional input data and add FC layers processing bo

discuss.pytorch.org

사실 간단하지만 궁금했던부분이다.

torch 에서는 이런식으로 layer 를 추가해주고 forward 부분에 

x = torch.cat((x1, x2), dim=1

이런식으로 하는게 특징이다.

 

 

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cnn = models.inception_v3(pretrained=False, aux_logits=False)
        self.cnn.fc = nn.Linear(
            self.cnn.fc.in_features, 20)
        
        self.fc1 = nn.Linear(20 + 10, 60)
        self.fc2 = nn.Linear(60, 5)
        
    def forward(self, image, data):
        x1 = self.cnn(image)
        x2 = data
        
        x = torch.cat((x1, x2), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

 

'Machine.Learning > torch' 카테고리의 다른 글

pytorch hook 걸기  (0) 2023.04.10

+ Recent posts