여러 모듈을 붙일 때 중간 Layer 에 Hook_fn  등 을 걸면, 중간 layer 에서 

output 을 불러다 사용가능하고,

뭐 explainable ai 의 activation map 을 볼때 등등 다양한 용도로 사용한다.

 

torch 의 모델은 'dict' 형태로 저장되있으나, 안에 모듈안에 서브모듈등으로 구성되어있어서.

모든 layer 를 검색하려면 계층적으로 탐색해야한다. 

 

어쩃든 머 특정 layer 에서 hook 을 걸려면 아래처럼 register_forward_hook 또는 register_backward_hook

self.DPTbackbone._modules['scratch'].output_conv[6].register_forward_hook(self.forward_hook('output_conv'))


등을 걸면되고, 

그안에 훅 funtion 을 넣어주게되면 이 function 이 해당레이어에 data 가 통과될 때, call back 처리되어 거꾸로 호출된다.

거꾸로 호출될때 특징이 input output 이 같이 호출되는 구조이다.

펑션은 아래처럼 만들어주면된다.

  def forward_hook(self,layer_name):
        def hook(module, input, output):
            self.selected_out[layer_name] = torch.squeeze(output,dim=1)
        return hook
    def get_last_hook(self):
        return self.selected_out['output_conv']

ref : https://medium.com/the-owl/using-forward-hooks-to-extract-intermediate-layer-outputs-from-a-pre-trained-model-in-pytorch-1ec17af78712

 

Using forward_hooks to Extract Intermediate Layer Outputs from a Pre-trained ResNet Model in…

Here we are again with the fourth article in the “Feature Extraction” series. This article will describe another method (and possibly the…

medium.com

https://hongl.tistory.com/157

'Machine.Learning > torch' 카테고리의 다른 글

torch - cnn concatenate 하기 feature 합치기  (0) 2021.01.24

+ Recent posts