딥러닝

[PyTorch] YOLO txt로 만든 데이터셋으로 Faster R-CNN 학습하기

제갈티 2024. 10. 2. 09:55
import os
import torch
import torch.utils.data
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import numpy as np
import torchvision.transforms.functional as F
import random

# (이전의 Compose, RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, ToTensor, YOLODataset 클래스는 그대로 유지)
class Compose:
    """Chain several paired ``(image, target)`` transforms into one callable.

    Every transform must accept ``(image, target)`` and return a two-item
    result; the transforms run in the order they appear in the list.
    """

    def __init__(self, transforms):
        # Stored as given (no copy), so later edits to the list are visible.
        self.transforms = transforms

    def __call__(self, image, target):
        """Pass ``image`` and ``target`` through each transform in turn."""
        pair = (image, target)
        for step in self.transforms:
            # Unpack on every step so a malformed return value fails here.
            img, tgt = step(*pair)
            pair = (img, tgt)
        return pair

class RandomHorizontalFlip:
    """Mirror the image left-to-right, with its boxes, with probability ``prob``.

    ``target["boxes"]`` is indexed as a 2-D array whose columns 0 and 2 are
    the x extents of each box; those columns are rewritten in place.
    NOTE(review): the arithmetic only mirrors boxes correctly when they are
    expressed in the same units as ``image.size`` (pixels) - confirm against
    the dataset that produces the targets.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped when the random draw is below ``prob``."""
        should_flip = random.random() < self.prob
        if not should_flip:
            return image, target

        img_width, _ = image.size
        mirrored = image.transpose(Image.FLIP_LEFT_RIGHT)

        boxes = target["boxes"]
        # new x_min = width - old x_max, new x_max = width - old x_min.
        # The list index on the right-hand side produces a copy, so the
        # in-place write cannot read half-updated columns.
        swapped_x = img_width - boxes[:, [2, 0]]
        boxes[:, [0, 2]] = swapped_x
        target["boxes"] = boxes
        return mirrored, target

class RandomVerticalFlip:
    """Mirror the image top-to-bottom, with its boxes, with probability ``prob``.

    ``target["boxes"]`` is indexed as a 2-D array whose columns 1 and 3 are
    the y extents of each box; those columns are rewritten in place.
    NOTE(review): the arithmetic only mirrors boxes correctly when they are
    expressed in the same units as ``image.size`` (pixels) - confirm against
    the dataset that produces the targets.
    """

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped when the random draw is below ``prob``."""
        should_flip = random.random() < self.prob
        if not should_flip:
            return image, target

        _, img_height = image.size
        mirrored = image.transpose(Image.FLIP_TOP_BOTTOM)

        boxes = target["boxes"]
        # new y_min = height - old y_max, new y_max = height - old y_min.
        # The list index on the right-hand side produces a copy, so the
        # in-place write cannot read half-updated columns.
        swapped_y = img_height - boxes[:, [3, 1]]
        boxes[:, [1, 3]] = swapped_y
        target["boxes"] = boxes
        return mirrored, target

class ColorJitter:
    """Paired-transform adapter around ``transforms.ColorJitter``.

    Only the image is passed to the jitter; the target is handed back
    unchanged.
    """

    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01):
        jitter = transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )
        self.color_jitter = jitter

    def __call__(self, image, target):
        """Return the colour-jittered image together with the untouched target."""
        jittered = self.color_jitter(image)
        return jittered, target

class ToTensor:
    """Paired transform that converts the image with ``F.to_tensor``.

    The target is returned exactly as received.
    """

    def __call__(self, image, target):
        """Return ``(F.to_tensor(image), target)``."""
        return F.to_tensor(image), target

class YOLODataset(Dataset):
    """Detection dataset pairing each image with a same-named YOLO ``.txt`` file.

    Each non-blank label line is ``class_id x_center y_center width height``.
    Following the YOLO txt convention the four geometry values are taken to
    be normalised to [0, 1] by the image size, so they are scaled to absolute
    pixels here: the flip transforms in this file subtract boxes from the
    pixel width/height, and torchvision detection models expect
    ``[x_min, y_min, x_max, y_max]`` boxes in pixels.

    Args:
        root_dir: Directory containing the images and their label files.
        transform: Optional callable ``(image, target) -> (image, target)``.
        label_offset: Value added to every class id read from disk. YOLO ids
            start at 0 whereas torchvision detection heads reserve label 0
            for the background, hence the default of 1. Pass 0 if the label
            files are already 1-based.
    """

    # Matched case-insensitively against directory entries.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, label_offset=1):
        self.root_dir = root_dir
        self.transform = transform
        self.label_offset = label_offset
        # os.listdir order is arbitrary; sorting keeps index -> file stable
        # between runs and between dataset instances.
        self.image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def _read_targets(self, txt_path, width, height):
        """Parse one label file into pixel-space boxes and shifted labels.

        Returns:
            ``(boxes, labels)``: a float32 tensor of shape (N, 4) and an
            int64 tensor of shape (N,). N is 0 for an empty file.

        Raises:
            FileNotFoundError: If ``txt_path`` does not exist.
            ValueError: If a non-blank line does not hold exactly 5 numbers.
        """
        boxes = []
        labels = []
        with open(txt_path, 'r', encoding='utf-8') as file:
            for line in file:
                fields = line.split()
                if not fields:
                    # Blank (often trailing) lines carry no annotation.
                    continue
                class_id, x, y, w, h = map(float, fields)
                boxes.append([
                    (x - w / 2) * width,
                    (y - h / 2) * height,
                    (x + w / 2) * width,
                    (y + h / 2) * height,
                ])
                labels.append(int(class_id) + self.label_offset)

        # reshape keeps the (N, 4) layout even when there are no boxes, so
        # column indexing in the transforms still works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is a dict with ``"boxes"`` and ``"labels"`` tensors; the
        image is an RGB PIL image unless ``transform`` converts it.
        """
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        txt_path = os.path.splitext(img_path)[0] + '.txt'
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        width, height = image.size
        boxes, labels = self._read_targets(txt_path, width, height)
        target = {"boxes": boxes, "labels": labels}

        if self.transform is not None:
            image, target = self.transform(image, target)

        return image, target


def get_model(num_classes):
    """Build a pretrained Faster R-CNN (ResNet-50 FPN v2) with a fresh box head.

    Args:
        num_classes: Number of output classes for the new predictor,
            including the background class.

    Returns:
        The detection model whose ``roi_heads.box_predictor`` has been
        replaced by a ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    # weights="DEFAULT" replaces the legacy pretrained=True flag, which
    # torchvision deprecated in favour of the weights API.
    model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
    # Reuse the existing head's input width so only the output layer changes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def evaluate(model, data_loader, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    The model is called as ``model(images, targets)`` in train mode, because
    that is the mode in which it returns the loss dict this function sums.
    Batch-norm layers are put back in eval mode first so that validation
    batches do not update their running statistics, and gradients are
    disabled throughout.

    Args:
        model: Module returning a dict of scalar losses in train mode.
        data_loader: Iterable of ``(images, targets)`` batches, where
            ``targets`` is a sequence of dicts of tensors.
        device: Device the batch tensors are moved to.

    Returns:
        The average summed loss per batch, as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no batches.

    The model is left in eval mode on return.
    """
    batch_norm_types = (
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.SyncBatchNorm,
    )

    # Switch modes once instead of on every batch.
    model.train()
    for module in model.modules():
        if isinstance(module, batch_norm_types):
            module.eval()

    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for images, targets in data_loader:
                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                loss_dict = model(images, targets)
                # float() moves the scalar off the device so the running
                # total does not hold on to device tensors.
                total_loss += float(sum(loss_dict.values()))
                num_batches += 1
    finally:
        model.eval()

    if num_batches == 0:
        raise ValueError("evaluate() received a data_loader with no batches")
    return total_loss / num_batches

def main(
    data_dir="/home/elev/Videos/0930_frcnn_ssg/NG",
    num_classes=2,
    num_epochs=100,
    batch_size=4,
    checkpoint_path='best_frcnn_res50fpn.pth',
):
    """Train Faster R-CNN on a YOLO-txt dataset and keep the best checkpoint.

    The images in ``data_dir`` are split 80/20 into training and validation
    subsets. Only the training subset is augmented; the validation subset is
    merely converted to tensors so its loss is comparable between epochs.
    After every epoch the validation loss is computed, and whenever it
    improves the model's ``state_dict`` is written to ``checkpoint_path``.

    Args:
        data_dir: Directory with images and same-named YOLO label files.
        num_classes: Number of classes including the background class.
        num_epochs: Number of passes over the training subset.
        batch_size: Batch size for both loaders.
        checkpoint_path: Destination of the best-model ``state_dict``.

    Raises:
        ValueError: If ``data_dir`` holds fewer than two images, so that a
            non-empty train/validation split is impossible.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    # Random augmentation for training; a deterministic pipeline for validation.
    train_transform = Compose([
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ColorJitter(),
        ToTensor(),
    ])
    eval_transform = Compose([ToTensor()])

    # Two views over the same files that differ only in their transform.
    train_view = YOLODataset(data_dir, transform=train_transform)
    val_view = YOLODataset(data_dir, transform=eval_transform)
    # Share one file list so an index refers to the same image in both views
    # regardless of the order the directory listing came back in.
    val_view.image_files = list(train_view.image_files)

    num_samples = len(train_view)
    if num_samples < 2:
        raise ValueError(
            f"need at least 2 images in {data_dir!r} to split train/val, "
            f"found {num_samples}"
        )

    # 80% training / 20% validation over a random permutation of indices.
    train_size = int(0.8 * num_samples)
    shuffled = torch.randperm(num_samples).tolist()
    train_dataset = torch.utils.data.Subset(train_view, shuffled[:train_size])
    val_dataset = torch.utils.data.Subset(val_view, shuffled[train_size:])

    def collate(batch):
        # Images differ in size, so keep them as tuples instead of stacking.
        return tuple(zip(*batch))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = get_model(num_classes)
    model.to(device)

    # Optimise only the parameters that are not frozen.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        for images, targets in train_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        # float() accepts both a plain number and a 0-dim tensor.
        val_loss = float(evaluate(model, val_loader, device))

        # Log and checkpoint only when the validation loss improves.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
            torch.save(model.state_dict(), checkpoint_path)

    print("Training completed.")

# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()