딥러닝
[Pytorch] Yolo Txt 로 만든 데이타셋으로 faster R-CNN 학습하기
제갈티
2024. 10. 2. 09:55
import os
import torch
import torch.utils.data
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
import numpy as np
import torchvision.transforms.functional as F
import random
# (The Compose, RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, ToTensor,
#  and YOLODataset classes below are kept as in the previous post.)
class Compose:
    """Chain several (image, target) transforms into one callable."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # Thread the (image, target) pair through every transform in order.
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
class RandomHorizontalFlip:
    """Horizontally flip a PIL image and its pixel-space boxes with probability `prob`."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            img_width, _ = image.size
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            boxes = target["boxes"]
            # Mirror x-coords: new x_min = W - old x_max, new x_max = W - old x_min
            # (edited in place on the tensor, matching the original behavior).
            boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
            target["boxes"] = boxes
        return image, target
class RandomVerticalFlip:
    """Vertically flip a PIL image and its pixel-space boxes with probability `prob`."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            _, img_height = image.size
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
            boxes = target["boxes"]
            # Mirror y-coords: new y_min = H - old y_max, new y_max = H - old y_min
            # (edited in place on the tensor, matching the original behavior).
            boxes[:, [1, 3]] = img_height - boxes[:, [3, 1]]
            target["boxes"] = boxes
        return image, target
class ColorJitter:
    """Randomly jitter brightness/contrast/saturation/hue of the image.

    Photometric changes do not move objects, so the target passes through
    untouched.
    """

    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.01):
        self.color_jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, image, target):
        return self.color_jitter(image), target
class ToTensor:
    """Convert the PIL image to a torch tensor; the target is left untouched."""

    def __call__(self, image, target):
        return F.to_tensor(image), target
class YOLODataset(Dataset):
    """Detection dataset reading images plus YOLO-format ``.txt`` annotations.

    Each ``foo.jpg``/``foo.png`` in ``root_dir`` must have a sibling ``foo.txt``
    whose lines are ``class_id x_center y_center width height`` with coordinates
    normalized to [0, 1] (the standard YOLO convention).

    Returned targets follow the torchvision detection contract:
    ``boxes`` — float32 tensor of absolute-pixel ``[x_min, y_min, x_max, y_max]``
    (shape ``(N, 4)``, valid even when N == 0); ``labels`` — int64 tensor with
    class ids shifted by +1, because torchvision detection models reserve
    label 0 for the background class.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted for a deterministic sample order across runs/filesystems.
        self.image_files = sorted(
            f for f in os.listdir(root_dir) if f.endswith(('.jpg', '.png'))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        txt_path = os.path.splitext(img_path)[0] + '.txt'
        image = Image.open(img_path).convert("RGB")
        img_w, img_h = image.size

        boxes = []
        labels = []
        with open(txt_path, 'r') as file:
            for line in file:
                parts = line.split()
                if not parts:
                    continue  # tolerate blank lines in annotation files
                class_id, x, y, w, h = map(float, parts)
                # YOLO stores a normalized center/size; Faster R-CNN needs
                # absolute-pixel corner coordinates, so scale by image size.
                boxes.append([
                    (x - w / 2) * img_w,
                    (y - h / 2) * img_h,
                    (x + w / 2) * img_w,
                    (y + h / 2) * img_h,
                ])
                # +1 shift: label 0 is the implicit background class.
                labels.append(int(class_id) + 1)

        # reshape(-1, 4) keeps an empty annotation file as a valid (0, 4) tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}

        if self.transform:
            image, target = self.transform(image, target)
        return image, target
def get_model(num_classes, weights="DEFAULT"):
    """Build a Faster R-CNN (ResNet-50 FPN v2) with a freshly-sized box head.

    Args:
        num_classes: Number of output classes *including* background.
        weights: Weight spec forwarded to torchvision. ``"DEFAULT"`` (the
            backward-compatible default, equivalent to the old
            ``pretrained=True``) downloads COCO-pretrained weights; pass
            ``None`` for a randomly initialized model (e.g. offline tests).

    Returns:
        A torchvision detection model whose classifier predicts ``num_classes``.
    """
    # ``pretrained=True`` is deprecated (and removed in recent torchvision);
    # the ``weights`` enum/str API is the supported replacement.
    model = fasterrcnn_resnet50_fpn_v2(weights=weights)
    # Replace the COCO-sized predictor with one matching our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
def evaluate(model, data_loader, device):
    """Compute the mean summed loss over a validation loader.

    Torchvision detection models return predictions in eval mode and their
    loss dict only in train mode, so the model is toggled to ``train()`` just
    for each forward pass; gradients stay disabled via ``torch.no_grad()``,
    and the model is left in eval mode afterwards.

    Args:
        model: Detection model returning a dict of loss tensors in train mode.
        data_loader: Yields ``(images, targets)`` batches.
        device: Device to move batches onto.

    Returns:
        Average summed loss per batch, as a plain ``float``.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            # Temporarily switch to train mode so the model reports losses.
            model.train()
            loss_dict = model(images, targets)
            model.eval()
            # float() detaches immediately so no tensors accumulate.
            total_loss += float(sum(loss_dict.values()))
    return total_loss / len(data_loader)
def main():
    """Train Faster R-CNN on a YOLO-txt dataset with an 80/20 train/val split.

    Saves the best checkpoint (lowest validation loss) to
    ``best_frcnn_res50fpn.pth``.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(device)

    # Dataset path and class count (background class included).
    data_dir = "/home/elev/Videos/0930_frcnn_ssg/NG"
    num_classes = 2  # 1 foreground class + background

    # Data augmentation pipeline.
    data_transform = Compose([
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ColorJitter(),
        ToTensor(),
    ])

    full_dataset = YOLODataset(data_dir, transform=data_transform)

    # 80% train / 20% validation split.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    def collate_fn(batch):
        # Detection samples vary in box count, so keep per-sample lists
        # instead of stacking into a single tensor.
        return tuple(zip(*batch))

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    model = get_model(num_classes)
    model.to(device)

    # Optimize only trainable parameters.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

    num_epochs = 100
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        for images, targets in train_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # In train mode the model returns a dict of losses.
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        # Validation pass; log and checkpoint only when the loss improves.
        val_loss = evaluate(model, val_loader, device)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss:.4f}")
            torch.save(model.state_dict(), 'best_frcnn_res50fpn.pth')

    print("Training completed.")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()