Deep Learning

A GradCAM Heatmap with an Alpha Channel

제갈티 2024. 9. 10. 20:07

 

Doggos eating, overlaid with GradCAM (GIF)

 

This is the result of applying an alpha channel to the GradCAM heatmap instead of a pseudo-color map.

 

- Replacing the pseudo color normally used for the GradCAM heatmap with alpha-channel values that control transparency produces a video with a kind of peek-through effect, showing only the objects the model is attending to (see the sketch after this list).

- This seems like it could be useful in media art or special-effects work.
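
The core idea fits in a few lines. Below is a minimal sketch (the function name cam_to_rgba and the saved file name are illustrative, not part of the original script): instead of mapping the heatmap through a colormap, the normalized CAM is attached to the frame as a fourth channel, so the model's attention directly controls opacity. The full script further down takes a slightly different route for the GIF output, multiplying each frame by a smoothed mask over a black background, since GIF does not support partial transparency.

import numpy as np
import cv2
from PIL import Image

def cam_to_rgba(frame_rgb, cam):
    """Use a normalized GradCAM heatmap as the alpha channel of an RGBA image.

    frame_rgb: (H, W, 3) uint8 RGB frame.
    cam: (h, w) float heatmap normalized to [0, 1].
    """
    h, w = frame_rgb.shape[:2]
    alpha = cv2.resize(cam, (w, h))                         # match the frame size; cv2.resize expects (width, height)
    alpha = (np.clip(alpha, 0, 1) * 255).astype(np.uint8)   # scale attention to 0-255 opacity
    rgba = np.dstack([frame_rgb, alpha])                    # attach the heatmap as the alpha channel
    return Image.fromarray(rgba, mode="RGBA")

# Regions the model attends to stay opaque, everything else fades out.
# cam_to_rgba(frame, combined_cam).save("frame_rgba.png")  # hypothetical output path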

 

import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
from imageio import mimread, mimsave

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        
        # Hooks capture the target layer's feature maps on the forward pass and their gradients on the backward pass.
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)
        
    def save_activation(self, module, input, output):
        self.activations = output
        
    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]
        
    def generate_cam(self, input_image, target_classes):
        model_output = self.model(input_image)  # forward pass fills self.activations via the hook
        self.model.zero_grad()
        
        # Accumulate one CAM per target class into a single map.
        combined_cam = torch.zeros(self.activations.shape[2:]).to(input_image.device)
        for target_class in target_classes:
            target = model_output[0, target_class]
            target.backward(retain_graph=True)  # fills self.gradients via the backward hook
            
            gradients = self.gradients[0].cpu().data.numpy()
            activations = self.activations[0].cpu().data.numpy()
            
            weights = np.mean(gradients, axis=(1, 2))  # channel weights: global average of the gradients
            cam = np.sum(weights[:, np.newaxis, np.newaxis] * activations, axis=0)  # weighted sum of feature maps
            
            cam = np.maximum(cam, 0)  # ReLU: keep only positive contributions
            combined_cam += torch.from_numpy(cam).to(input_image.device)
        
        combined_cam = combined_cam.cpu().numpy()
        combined_cam = cv2.resize(combined_cam, (320, 240))  # upscale to the output size; cv2.resize expects (width, height)
        combined_cam = combined_cam - np.min(combined_cam)
        combined_cam = combined_cam / (np.max(combined_cam) + 1e-8)  # normalize to [0, 1]; eps guards against an all-zero map
        return combined_cam

def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((240, 320)),  # (height, width) for the model input, matching the 320x240 (width x height) frames below
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)

def create_smooth_mask(heatmap, low_threshold=0.4, high_threshold=0.7):
    # Values below low_threshold become fully transparent, values above high_threshold fully opaque,
    # with a linear ramp in between; the Gaussian blur softens the edge of the cutout.
    heatmap_normalized = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap) + 1e-8)
    mask = np.clip((heatmap_normalized - low_threshold) / (high_threshold - low_threshold), 0, 1)
    mask = cv2.GaussianBlur(mask, (15, 15), 0)
    return mask

def apply_gradcam_mask(image, mask):
    # Expand the single-channel mask to three channels and composite the frame over a black background:
    # attended regions keep their pixels, everything else fades to black.
    mask_3channel = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
    result = image * mask_3channel + np.zeros_like(image) * (1 - mask_3channel)
    return result.astype(np.uint8)

def process_gif(input_gif, output_gif):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # Apple Silicon GPU with CPU fallback
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)  # ImageNet-pretrained backbone
    model.eval()
    
    grad_cam = GradCAM(model, model.layer4[-1])
    
    frames = mimread(input_gif)
    processed_frames = []
    
    dog_classes = range(151, 269)  # ImageNet dog classes, indices 151 through 268
    
    for frame in frames:
        pil_image = Image.fromarray(frame).convert('RGB')
        input_tensor = preprocess_image(pil_image).to(device)
        
        combined_cam = grad_cam.generate_cam(input_tensor, dog_classes)
        
        frame_resized = cv2.resize(np.array(pil_image), (320, 240))  # output frame size; cv2.resize expects (width, height)
        
        mask = create_smooth_mask(combined_cam, low_threshold=0.4, high_threshold=0.7)
        result_frame = apply_gradcam_mask(frame_resized, mask)
        
        processed_frames.append(result_frame)
    
    mimsave(output_gif, processed_frames, format='GIF', duration=100, loop=0)  # duration per frame; whether it is read as seconds or milliseconds depends on the imageio version

# Usage
input_gif = "/Users/m1_16/Desktop/gaenyang.gif"
output_gif = '/Users/m1_16/Desktop/dog_output.gif'
process_gif(input_gif, output_gif)
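
A couple of knobs worth playing with: low_threshold and high_threshold in create_smooth_mask decide where the cutout begins and how quickly it becomes fully opaque, and the (15, 15) Gaussian kernel controls how soft its edge is. If the playback speed of the output GIF looks wrong, adjust the duration argument of mimsave, since its unit differs between imageio versions.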