여러 모듈을 붙일 때 중간 Layer 에 Hook_fn 등 을 걸면, 중간 layer 에서
output 을 불러다 사용가능하고,
뭐 explainable ai 의 activation map 을 볼때 등등 다양한 용도로 사용한다.
torch 의 모델은 'dict' 형태로 저장되있으나, 안에 모듈안에 서브모듈등으로 구성되어있어서.
모든 layer 를 검색하려면 계층적으로 탐색해야한다.
어쩃든 머 특정 layer 에서 hook 을 걸려면 아래처럼 register_forward_hook 또는 register_backward_hook
self.DPTbackbone._modules['scratch'].output_conv[6].register_forward_hook(self.forward_hook('output_conv'))
등을 걸면되고,
그안에 훅 funtion 을 넣어주게되면 이 function 이 해당레이어에 data 가 통과될 때, call back 처리되어 거꾸로 호출된다.
거꾸로 호출될때 특징이 input output 이 같이 호출되는 구조이다.
펑션은 아래처럼 만들어주면된다.
def forward_hook(self,layer_name):
def hook(module, input, output):
self.selected_out[layer_name] = torch.squeeze(output,dim=1)
return hook
def get_last_hook(self):
return self.selected_out['output_conv']
'Machine.Learning > torch' 카테고리의 다른 글
torch - cnn concatenate 하기 feature 합치기 (0) | 2021.01.24 |
---|