딥러닝

[파이썬] Focal Loss 로 Resnet 모델 학습하기

제갈티 2024. 9. 11. 17:34

Focal Loss의 장단점은 다음과 같습니다:

장점:

1. 클래스 불균형 해결:
- 불균형한 데이터셋에서 소수 클래스의 학습을 개선합니다.
- 객체 탐지와 같이 배경이 대부분인 문제에 특히 효과적입니다.

2. 어려운 예제에 집중:
- 쉽게 분류되는 예제보다 어려운 예제에 더 높은 가중치를 부여합니다.
- 이는 모델이 더 challenging한 케이스를 학습하는 데 도움을 줍니다.

3. 과대적합 감소:
- 쉬운 예제의 영향을 줄임으로써 과대적합을 어느 정도 방지할 수 있습니다.

4. 성능 향상:
- 특히 불균형 데이터셋에서 기존 Cross Entropy Loss보다 더 나은 성능을 보일 수 있습니다.

단점:

1. 하이퍼파라미터 조정:
- α와 γ 값을 적절히 설정해야 하며, 이는 추가적인 튜닝 작업을 필요로 합니다.

2. 계산 복잡도:
- Cross Entropy Loss에 비해 계산이 더 복잡하여 학습 시간이 늘어날 수 있습니다.

3. 균형 잡힌 데이터셋에서의 효과:
- 클래스 간 균형이 잘 잡힌 데이터셋에서는 기존 loss 함수와 비교해 큰 이점이 없을 수 있습니다.

4. 과도한 집중:
- 극단적으로 어려운 예제에 과도하게 집중할 경우, 전반적인 성능이 저하될 수 있습니다.

5. 해석의 어려움:
- Cross Entropy에 비해 loss 값의 해석이 직관적이지 않을 수 있습니다.

Focal Loss는 특정 상황, 특히 클래스 불균형 문제에서 매우 효과적일 수 있지만, 모든 상황에 적합한 것은 아닙니다. 데이터셋의 특성과 문제의 성격에 따라 적절히 선택하고 하이퍼파라미터를 조정해야 합니다.


Focal Loss 사용시 주의사항

이미지 분류 문제에서 단순히 손실 함수만 Focal Loss로 바꾼다고 해서 즉시 성능 향상이 나타나지 않을 가능성이 높습니다. 이에 대한 몇 가지 이유와 고려사항을 설명드리겠습니다:

1. 데이터셋 특성:
- Focal Loss는 주로 클래스 불균형 문제를 해결하기 위해 설계되었습니다. 만약 데이터셋이 이미 잘 균형잡혀 있다면, Focal Loss의 이점이 크지 않을 수 있습니다.

2. 문제의 복잡성:
- 이미지 분류 문제가 단순하거나 모델이 이미 높은 성능을 보이고 있다면, Focal Loss로 인한 개선 여지가 제한적일 수 있습니다.

3. 하이퍼파라미터 튜닝:
- Focal Loss의 α와 γ 파라미터를 적절히 조정해야 합니다. 이 값들을 최적화하지 않으면 오히려 성능이 저하될 수 있습니다.

4. 학습률 조정:
- Focal Loss를 사용하면 기존 Cross Entropy Loss와는 다른 스케일의 그래디언트가 생성될 수 있습니다. 따라서 학습률을 재조정해야 할 수 있습니다.

5. 모델 아키텍처:
- 일부 모델 아키텍처는 Focal Loss와 더 잘 작동할 수 있지만, 다른 아키텍처에서는 그렇지 않을 수 있습니다.

6. 학습 전략:
- Focal Loss를 효과적으로 사용하려면 전체적인 학습 전략(예: 데이터 증강, 배치 크기, 에폭 수 등)을 함께 조정해야 할 수 있습니다.

7. 평가 지표:
- Focal Loss가 전체적인 정확도를 개선하지 않더라도, 소수 클래스의 성능을 향상시킬 수 있습니다. 따라서 다양한 평가 지표를 고려해야 합니다.

8. 시간:
- 일부 경우에는 Focal Loss의 효과가 즉시 나타나지 않고, 학습이 진행됨에 따라 서서히 나타날 수 있습니다.

결론적으로, Focal Loss를 효과적으로 사용하려면 단순히 손실 함수만 바꾸는 것이 아니라, 데이터셋 분석, 하이퍼파라미터 튜닝, 전체적인 학습 전략 조정 등 종합적인 접근이 필요합니다. 또한, Focal Loss가 모든 이미지 분류 문제에 최적의 선택은 아니며, 문제의 특성에 따라 적절히 선택해야 합니다.

# Author: Sasank Chilamkurthy
# Modified by WonwooPark, 2024.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

- 필요한 패키지들을 로딩합니다.

cudnn.benchmark = True
plt.ion()   # interactive mode

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

##data_dir = 'data/hymenoptera_data'
data_dir = "/Users/m1_16/Desktop/0909NG_labeling_NoEme/output"

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=0)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

##device = torch.device("mps" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")
print(device)

- 준비된 데이타셋을 로딩합니다.

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

##imshow(out, title=[class_names[x] for x in classes])

- 데이타 로더를 테스트합니다.

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

def train_model(model, optimizer, scheduler, num_epochs=25, alpha=1, gamma=2):
    since = time.time()
    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0

        # Initialize Focal Loss
        criterion = FocalLoss(alpha=alpha, gamma=gamma)

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)
            # Each epoch has a training and validation phase
            for phase in ['train', 'val']:
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()   # Set model to evaluate mode
                running_loss = 0.0
                running_corrects = 0
                # Iterate over data.
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    # zero the parameter gradients
                    optimizer.zero_grad()
                    # forward
                    # track history if only in train
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        # backward + optimize only if in training phase
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
                    # statistics
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                if phase == 'train':
                    scheduler.step()
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects.float() / dataset_sizes[phase]
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                # deep copy the model
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)
            print()
        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:4f}')
        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path))
    return model

- Focal Loss 함수를 정의 하고 모델학습 함수도 수정 합니다.

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

- 파인튜닝할 모델을 정의하고 학습 파라미터등을 설정합니다.

model_ft = train_model(model_ft, optimizer_ft, exp_lr_scheduler, num_epochs=25, alpha=0.25, gamma=2)

torch.save(model_ft, "resnet18_focal_fatnorm.pth")

- 비로소 학습을 시작합니다.

모델 학습이 진행되는 모습니다. (파이썬 개발 도구인 Thonny 스크린샷)
m1 맥미니에서 mps 모드로 학습중이라 GPU를 잘 활용하여 학습합니다.