안녕하세요 성민석입니다.
딥러닝을 하다보면, 모델 내부 구조에서 어떤 연산이 되는지 궁금하게 됩니다. 현재 Conv 연산에서 어떤 필터가 학습되어 사용되고 있는지, 혹은 그에 따른 결과는 어떻게 나오는지 다양한 부분들이 궁금합니다. 다만, 사용자가 모델 내부를 들여다보긴 쉽지 않습니다.
PyTorch에서는 Hook은 모델의 forward 및 backward pass 중에 실행되는 사용자 정의 함수입니다. 이를 통해 우리는 모델이 학습되는 중간 계산 과정이나 그래디언트를 참조하거나 수정할 수 있습니다.
여기에서는 MNIST 숫자 데이터 중 하나를 가지고 예시를 들어보겠습니다. MNIST의 가장 첫 번째 데이터인 숫자 5 이미지를 가져와보겠습니다.
1. Forward Hook
Forward Hook은 모듈의 forward pass 중에 호출됩니다. 아래는 간단한 forward hook 예시입니다. 기본적으로 Hook함수에는 module과 input 그리고 output 를 입력으로 받습니다.
forward_hook_input = None
forward_hook_output = None
def forward_hook_fn(module, input, output):
global forward_hook_input, forward_hook_output
module_name = module._get_name() # named_modules
print(f"[Forward Hook] {module_name}의 입력: {input[0].shape}") # torch.Size([1, 64, 56, 56])
print(f"[Forward Hook] {module_name}의 출력: {output[0].shape}") # torch.Size([256, 56, 56])
#############################################################################
# Input
#############################################################################
forward_hook_input = input
plt.figure(figsize=(16,16))
for i, _ in enumerate(np.random.randint(0, input[0][0].shape[0], 16), start=1):
ax = plt.subplot(4,4, i)
ax.imshow(input[0][0][i-1].detach().numpy())
plt.suptitle('[Forward Hook] Input', fontsize=20)
plt.tight_layout()
plt.show()
#############################################################################
# Output
#############################################################################
forward_hook_output = output
plt.figure(figsize=(16,16))
for i, _ in enumerate(np.random.randint(0, output[0].shape[0], 16), start=1):
ax = plt.subplot(4,4, i)
ax.imshow(output[0][i-1].detach().numpy())
plt.suptitle('[Forward Hook] Output', fontsize=20)
plt.tight_layout()
plt.show()
위의 forward_hook_fn은 모듈의 forward 메서드가 호출될 때마다 실행되며, 입력 및 출력 값을 인쇄합니다. 여기 예시에서는 제가 간단하 CNN을 직접 만들어서 사용했지만, VGG나 ResNet과 같이 널리 알려져있는 모델 구조도 사용할 수 있습니다.
# Forward Hook Example
x, y = dataloader_train[0]
display(x)
x = preprocess_image(x)
handle = model.block3.register_forward_hook(forward_hook_fn)
output = model(x)
handle.remove()
2. Backward Hook
Backward Hook은 모듈의 backward pass 중에 호출됩니다. Backward Hook을 사용하면 gradient 값에 접근하거나 이를 수정할 수 있습니다.
backward_hook_input = None
backward_hook_output = None
def backward_hook_fn(module, grad_input, grad_output):
global backward_hook_input, backward_hook_output
module_name = module._get_name() # named_modules
print(f"[Backward Hook] {module_name}의 입력: {grad_input[0].shape}")
print(f"[Backward Hook] {module_name}의 출력: {grad_output[0].shape}")
#############################################################################
# Input
#############################################################################
backward_hook_input = grad_input
plt.figure(figsize=(16,16))
for i, _ in enumerate(np.random.randint(0, backward_hook_input[0].shape[0], 16), start=1):
ax = plt.subplot(4,4, i)
ax.imshow(backward_hook_input[0][0][i-1].detach().numpy())
plt.suptitle('[Backward Hook] Input', fontsize=20)
plt.tight_layout()
plt.show()
#############################################################################
# OutPut
#############################################################################
backward_hook_output = grad_output
plt.figure(figsize=(16,16))
for i, _ in enumerate(np.random.randint(0, backward_hook_output[0].shape[0], 16), start=1):
ax = plt.subplot(4,4, i)
ax.imshow(backward_hook_output[0][0][i-1].detach().numpy())
plt.suptitle('[Backward Hook] Output', fontsize=20)
plt.tight_layout()
plt.show()
위의 backward_hook_fn은 모듈의 backward 메서드가 호출될 때마다 실행되며, 그래디언트 입력 및 출력 값을 인쇄합니다.
# Backward Hook Example
handle = model.block3.register_backward_hook(backward_hook_fn)
criterion = nn.KLDivLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer.zero_grad()
outputs = model(x)
labels = torch.eye(10)[3]
loss = criterion(outputs, labels)
loss.backward()
handle.remove()
이처럼 PyTorch에서의 Hooking은 모델의 작동 방식을 이해하는 데 도움이 되며, 내부 계산이나 그래디언트에 접근하여 커스텀 동작을 적용할 수 있게 해줍니다. 사실 더욱 많은 내용이 있지만, 일단 기본적으로 사용할 때는 이 정도만 알아도 충분할 것 같습니다.
'딥러닝' 카테고리의 다른 글
[PyTorch] Random seed 고정하기 (0) | 2024.03.07 |
---|---|
[PyTorch] 단 한줄로 PyTorch와 관련된 정보 확인하기 (0) | 2021.10.04 |