I want to implement the Style Transfer paper using PyTorch and the VGG19 network.
For this, I need the intermediate output features for some layers.
I named the convolutional modules: ['conv_1', 'conv_2',..., 'conv_16']
To manage the hooks and the captured features, I use the following class and helper function:
class SaveOutput:
    # Callable object for saving the layers' outputs
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


def addHooksToModel(model, layerNames, hookHandles):
    # Remove hooks
    for hook in hookHandles:
        hook.remove()
    hookHandles = []

    features = SaveOutput()
    for name, module in model.named_modules():
        if name in layerNames:
            hookHandles.append(module.register_forward_hook(features))
    return features, hookHandles
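For context, here is a minimal sketch of how I understand forward hooks to behave, using a single nn.Linear layer as a stand-in (the layer, the input shape, and the variable names are assumptions for illustration, not part of my VGG19 setup). A hook registered with register_forward_hook is called on every forward pass of that module until the returned handle's remove() is called:

```python
import torch
import torch.nn as nn


class SaveOutput:
    # Same callable as above: appends each hooked module's output
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)


layer = nn.Linear(4, 2)      # stand-in module, not a VGG19 layer
recorder = SaveOutput()
handle = layer.register_forward_hook(recorder)

x = torch.randn(1, 4)
layer(x)
layer(x)
count_with_hook = len(recorder.outputs)     # one entry per forward pass

handle.remove()              # detach the hook from the module
layer(x)
count_after_remove = len(recorder.outputs)  # unchanged once the hook is removed

print(count_with_hook, count_after_remove)
```

So a SaveOutput object keeps accumulating outputs for as long as its hooks stay registered, regardless of which image is passed through.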
I want to store the features of the content and style separately:
CONTENT_LAYERS = ["conv_14"]
STYLE_LAYERS = ["conv_1","conv_3","conv_5","conv_9","conv_13"]
hook_handles_content = []
hook_handles_style = []
content_features, hook_handles_content = addHooksToModel(model, CONTENT_LAYERS, hook_handles_content)
style_features, hook_handles_style = addHooksToModel(model, STYLE_LAYERS, hook_handles_style)
I then pass contentImage and styleImage through the network, and I expect content_features.outputs to contain 1 tensor and style_features.outputs to contain 5 tensors.
model(contentImage)
contentImg_content = content_features.outputs
content_features.clear()
model(styleImage)
styleImg_style = style_features.outputs
style_features.clear()
In reality, I get 1 tensor in content_features.outputs (as expected) but 10 tensors in style_features.outputs (twice the expected number).
The same thing happens in the other order: if I pass styleImage first and then contentImage, I get 5 tensors in style_features.outputs (as expected) but 2 tensors in content_features.outputs (twice the expected number).
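To narrow this down, here is a minimal sketch that seems to reproduce the same pattern on a small stand-in network instead of VGG19. The three-layer model, the add_hooks helper, the random input, and the recorder names are all assumptions made for illustration. It records the lengths of both output lists after each forward pass rather than only at the end:

```python
import torch
import torch.nn as nn


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


def add_hooks(model, layer_names):
    # Hypothetical simplified helper: one recorder shared by the named layers
    recorder = SaveOutput()
    handles = [
        module.register_forward_hook(recorder)
        for name, module in model.named_modules()
        if name in layer_names
    ]
    return recorder, handles


# Stand-in network with named conv layers (not the real VGG19)
model = nn.Sequential()
for i in range(1, 4):
    model.add_module(f"conv_{i}", nn.Conv2d(3, 3, kernel_size=3, padding=1))

content_rec, _ = add_hooks(model, ["conv_3"])
style_rec, _ = add_hooks(model, ["conv_1", "conv_2"])

image = torch.randn(1, 3, 8, 8)

with torch.no_grad():
    model(image)                     # "content" pass
after_first = (len(content_rec.outputs), len(style_rec.outputs))
content_rec.clear()                  # only the content recorder is cleared

with torch.no_grad():
    model(image)                     # "style" pass
after_second = (len(content_rec.outputs), len(style_rec.outputs))

print(after_first, after_second)
```

In this sketch the style recorder already holds outputs after the first pass, because both sets of hooks stay registered on the same model and fire on every forward call. If the same holds for my VGG19 setup, it would account for the doubled count.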
Could somebody point me in the right direction? I know I'm missing something, probably in the way PyTorch hooks work, but I can't figure out what. Thank you!
question from:
https://stackoverflow.com/questions/65858953/pytorch-adding-hooks-to-model-for-saving-the-intermediate-layers-output-returns