카테고리 없음

Streamlit 으로 파이썬 실행결과 GUI를 웹서비스 하기

제갈티 2025. 5. 28. 10:10
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import fft2, ifft2, fftshift, ifftshift
from scipy.ndimage import gaussian_filter
import streamlit as st
from PIL import Image
import cv2
import warnings
warnings.filterwarnings('ignore')

# MPS 설정 (Apple Silicon Mac용)
import torch
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS 사용")
else:
    device = torch.device("cpu")
    print("CPU 사용")

class FractionalFourierTransform:
    """개선된 분수 푸리에 변환 클래스"""
    
    def __init__(self):
        self.device = device
    
    def frft2d(self, image, ax, ay):
        """2D 분수 푸리에 변환 - 더 정확한 구현"""
        # Tensor로 변환
        if isinstance(image, np.ndarray):
            image_tensor = torch.from_numpy(image.astype(np.complex128)).to(self.device)
        else:
            image_tensor = image.to(self.device)
        
        # x 방향 FRFT
        temp = self._frft_1d_improved(image_tensor, ax, axis=1)
        # y 방향 FRFT
        result = self._frft_1d_improved(temp, ay, axis=0)
        
        return result.cpu().numpy()
    
    def _frft_1d_improved(self, data, alpha, axis=0):
        """개선된 1D 분수 푸리에 변환"""
        alpha = alpha % 4  # 0-4 범위로 정규화
        
        # 특수한 경우들
        if abs(alpha) < 1e-10:
            return data
        elif abs(abs(alpha) - 2) < 1e-10:
            return torch.flip(data, dims=[axis])
        elif abs(abs(alpha) - 1) < 1e-10:
            if alpha > 0:
                return torch.fft.fftshift(torch.fft.fft2(data), dim=axis)
            else:
                return torch.fft.ifftshift(torch.fft.ifft2(data), dim=axis)
        elif abs(abs(alpha) - 3) < 1e-10:
            if alpha > 0:
                return torch.fft.ifft2(torch.fft.ifftshift(data, dim=axis))
            else:
                return torch.fft.fft2(torch.fft.fftshift(data, dim=axis))
        
        # 일반적인 분수 변환
        return self._general_frft(data, alpha, axis)
    
    def _general_frft(self, data, alpha, axis):
        """일반적인 분수 푸리에 변환 구현"""
        N = data.shape[axis]
        phi = alpha * np.pi / 2
        
        if abs(np.sin(phi)) < 1e-10:
            return data
        
        # 샘플링된 chirp 함수를 사용한 FRFT 구현
        n = torch.arange(N, dtype=torch.float64, device=self.device)
        
        # Pre-chirp multiplication
        if axis == 0:
            pre_chirp = torch.exp(1j * np.pi * (n**2) * np.cos(phi) / np.sin(phi))
            pre_chirp = pre_chirp.unsqueeze(1)
        else:
            pre_chirp = torch.exp(1j * np.pi * (n**2) * np.cos(phi) / np.sin(phi))
            pre_chirp = pre_chirp.unsqueeze(0)
        
        temp1 = data * pre_chirp
        
        # FFT
        if axis == 0:
            temp2 = torch.fft.fft(temp1, dim=0)
        else:
            temp2 = torch.fft.fft(temp1, dim=1)
        
        # Post-chirp multiplication
        scaling = torch.exp(-1j * np.pi * np.sign(np.sin(phi)) / 4) / torch.sqrt(torch.abs(torch.sin(torch.tensor(phi))))
        
        if axis == 0:
            post_chirp = torch.exp(1j * np.pi * (n**2) * np.cos(phi) / np.sin(phi))
            post_chirp = post_chirp.unsqueeze(1)
        else:
            post_chirp = torch.exp(1j * np.pi * (n**2) * np.cos(phi) / np.sin(phi))
            post_chirp = post_chirp.unsqueeze(0)
        
        result = scaling * temp2 * post_chirp
        
        return result

class ImprovedNoiseRemovalApp:
    def __init__(self):
        self.frft = FractionalFourierTransform()
    
    def generate_airplane_like_image(self, size=256):
        """비행기와 유사한 이미지 생성 (논문과 유사)"""
        x, y = np.meshgrid(np.linspace(-2, 2, size), np.linspace(-2, 2, size))
        
        # 비행기 몸체 (타원)
        body = np.exp(-((x/0.8)**2 + (y/0.3)**2) * 2)
        
        # 날개 (수평선)
        wings = np.exp(-((y/0.05)**2 + (np.abs(x) - 0.5)**2) * 8)
        
        # 프로펠러 (원형)
        prop_x, prop_y = -1.2, 0
        propeller = np.exp(-((x - prop_x)**2 + (y - prop_y)**2) * 20)
        
        # 프로펠러 블레이드
        blade1 = np.exp(-(((x - prop_x)*np.cos(np.pi/4) - (y - prop_y)*np.sin(np.pi/4))/0.1)**2 * 50)
        blade2 = np.exp(-(((x - prop_x)*np.cos(-np.pi/4) - (y - prop_y)*np.sin(-np.pi/4))/0.1)**2 * 50)
        
        # 조합
        airplane = body + wings + propeller + blade1 + blade2
        
        # 배경 텍스처 추가
        texture = 0.1 * np.sin(5*x) * np.sin(5*y)
        
        # 정규화
        result = airplane + texture
        result = (result - result.min()) / (result.max() - result.min())
        
        return result
    
    def add_structured_noise(self, image, noise_strength=1.0):
        """논문과 유사한 구조화된 노이즈 추가"""
        h, w = image.shape
        
        # 1. 가우시안 노이즈
        gaussian_noise = np.random.normal(0, noise_strength * 0.05, (h, w))
        
        # 2. 강한 줄무늬 노이즈 (논문과 유사)
        x_coords = np.arange(w)
        y_coords = np.arange(h)
        X, Y = np.meshgrid(x_coords, y_coords)
        
        # 대각선 줄무늬 패턴 (논문의 그림과 유사)
        stripe_noise = noise_strength * 0.8 * (
            np.sin(0.3 * (X + Y)) + 
            np.sin(0.2 * (X - Y)) +
            np.sin(0.15 * X) * 0.5
        )
        
        # 3. 고주파 노이즈
        high_freq_noise = noise_strength * 0.3 * np.random.normal(0, 1, (h, w))
        high_freq_noise = gaussian_filter(high_freq_noise, sigma=0.5)
        
        # 노이즈 결합
        total_noise = gaussian_noise + stripe_noise + high_freq_noise
        
        # 노이즈가 추가된 이미지
        noisy_image = image + total_noise
        
        # SNR 계산
        signal_power = np.var(image)
        noise_power = np.var(total_noise)
        snr = 10 * np.log10(signal_power / noise_power) if noise_power > 0 else float('inf')
        
        return noisy_image, snr
    
    def fourier_denoise(self, noisy_image):
        """개선된 일반 푸리에 변환 기반 노이즈 제거"""
        # 복소수로 변환
        complex_image = noisy_image.astype(np.complex128)
        
        # FFT 적용
        fft_image = fft2(complex_image)
        fft_shifted = fftshift(fft_image)
        
        # 적응적 필터링
        magnitude = np.abs(fft_shifted)
        
        # Wiener 필터와 유사한 접근
        noise_var = np.var(noisy_image) * 0.1  # 추정된 노이즈 분산
        wiener_filter = magnitude**2 / (magnitude**2 + noise_var)
        
        # 로우패스 필터와 결합
        h, w = noisy_image.shape
        center_h, center_w = h // 2, w // 2
        y, x = np.ogrid[:h, :w]
        
        # 부드러운 로우패스 필터
        sigma = min(h, w) * 0.25
        lowpass = np.exp(-((x - center_w)**2 + (y - center_h)**2) / (2 * sigma**2))
        
        # 필터 결합
        combined_filter = wiener_filter * lowpass
        filtered_fft = fft_shifted * combined_filter
        
        # 역변환
        ifft_shifted = ifftshift(filtered_fft)
        denoised = ifft2(ifft_shifted)
        
        return np.real(denoised)
    
    def frft_denoise_optimized(self, noisy_image, original_image):
        """최적화된 FRFT 기반 노이즈 제거"""
        best_mse = float('inf')
        best_result = noisy_image.copy()
        best_ax = best_ay = 0
        
        # 논문에서 제시한 범위를 중심으로 탐색
        ax_range = np.arange(-1.0, 1.0, 0.2)
        ay_range = np.arange(-1.0, 1.0, 0.2)
        
        # 진행률 표시
        total_combinations = len(ax_range) * len(ay_range)
        current_combination = 0
        
        progress_placeholder = st.empty()
        
        for ax in ax_range:
            for ay in ay_range:
                try:
                    current_combination += 1
                    progress = current_combination / total_combinations
                    progress_placeholder.progress(progress)
                    
                    # FRFT 적용
                    complex_image = noisy_image.astype(np.complex128)
                    frft_image = self.frft.frft2d(complex_image, ax, ay)
                    
                    # 최적 필터링
                    filtered_frft = self.apply_advanced_filter(frft_image, noisy_image)
                    
                    # 역 FRFT
                    denoised_complex = self.frft.frft2d(filtered_frft, -ax, -ay)
                    denoised = np.real(denoised_complex)
                    
                    # MSE 계산
                    mse = self.calculate_mse(original_image, denoised)
                    
                    if mse < best_mse:
                        best_mse = mse
                        best_result = denoised
                        best_ax = ax
                        best_ay = ay
                        
                except Exception as e:
                    print(f"Error with ax={ax}, ay={ay}: {e}")
                    continue
        
        progress_placeholder.empty()
        return best_result, best_ax, best_ay
    
    def apply_advanced_filter(self, frft_image, original_noisy):
        """고급 적응적 필터링"""
        magnitude = np.abs(frft_image)
        phase = np.angle(frft_image)
        
        # 통계 기반 임계값
        mean_mag = np.mean(magnitude)
        std_mag = np.std(magnitude)
        
        # 적응적 임계값
        threshold = mean_mag + 0.5 * std_mag
        
        # 부드러운 마스크 생성
        mask = 1 / (1 + np.exp(-10 * (magnitude - threshold) / std_mag))
        
        # Wiener 필터링 추가
        noise_var = np.var(original_noisy) * 0.05
        wiener_factor = magnitude**2 / (magnitude**2 + noise_var)
        
        # 필터 결합
        combined_filter = mask * wiener_factor
        
        # 필터 적용
        filtered_magnitude = magnitude * combined_filter
        filtered_frft = filtered_magnitude * np.exp(1j * phase)
        
        return filtered_frft
    
    def calculate_mse(self, original, processed):
        """MSE 계산"""
        if original.shape != processed.shape:
            return float('inf')
        
        # 정규화하여 비교
        orig_norm = (original - original.min()) / (original.max() - original.min())
        proc_norm = (processed - processed.min()) / (processed.max() - processed.min())
        
        return np.mean((orig_norm - proc_norm) ** 2)

def main():
    st.set_page_config(page_title="개선된 FRFT 노이즈 제거", layout="wide")
    
    st.title("개선된 FRFT 기반 이미지 노이즈 제거")
    st.markdown("### 논문 'Applications of the fractional Fourier transform' 정확한 구현")
    
    app = ImprovedNoiseRemovalApp()
    
    # 사이드바 설정
    st.sidebar.header("설정")
    
    # 이미지 옵션
    image_option = st.sidebar.selectbox(
        "이미지 선택",
        ["논문과 유사한 비행기 이미지", "이미지 업로드"]
    )
    
    original_image = None
    
    if image_option == "논문과 유사한 비행기 이미지":
        image_size = st.sidebar.slider("이미지 크기", 128, 512, 256)
        if st.sidebar.button("비행기 이미지 생성"):
            original_image = app.generate_airplane_like_image(image_size)
            st.session_state.original_image = original_image
            st.sidebar.success("비행기 이미지 생성 완료!")
    
    elif image_option == "이미지 업로드":
        uploaded_file = st.sidebar.file_uploader(
            "이미지 업로드",
            type=['png', 'jpg', 'jpeg', 'bmp']
        )
        
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert('L')
            image = np.array(image).astype(np.float64) / 255.0
            
            # 크기 조정
            if max(image.shape) > 400:
                scale = 400 / max(image.shape)
                new_size = (int(image.shape[1] * scale), int(image.shape[0] * scale))
                image = cv2.resize(image, new_size)
            
            original_image = image
            st.session_state.original_image = original_image
    
    # 세션에서 이미지 가져오기
    if 'original_image' in st.session_state:
        original_image = st.session_state.original_image
    
    if original_image is not None:
        # 노이즈 설정
        noise_strength = st.sidebar.slider("노이즈 강도", 0.5, 2.0, 1.0, 0.1)
        
        if st.sidebar.button("🚀 개선된 노이즈 제거 실행"):
            # 노이즈 추가
            with st.spinner("구조화된 노이즈 추가 중..."):
                noisy_image, snr = app.add_structured_noise(original_image, noise_strength)
            
            st.success(f"노이즈 추가 완료! SNR: {snr:.2f} dB")
            
            # 디노이징 처리
            col1, col2 = st.columns(2)
            
            with col1:
                with st.spinner("일반 푸리에 변환 디노이징..."):
                    fft_result = app.fourier_denoise(noisy_image)
                    fft_mse = app.calculate_mse(original_image, fft_result)
            
            with col2:
                with st.spinner("FRFT 최적 노이즈 제거... (시간이 좀 걸립니다)"):
                    frft_result, optimal_ax, optimal_ay = app.frft_denoise_optimized(noisy_image, original_image)
                    frft_mse = app.calculate_mse(original_image, frft_result)
            
            # 개선율 계산
            improvement = fft_mse / frft_mse if frft_mse > 0 else float('inf')
            
            # 결과 표시
            col1, col2 = st.columns(2)
            
            with col1:
                st.subheader("원본 이미지")
                fig1, ax1 = plt.subplots(figsize=(6, 6))
                ax1.imshow(original_image, cmap='gray', vmin=0, vmax=1)
                ax1.set_title("Original Image")
                ax1.axis('off')
                st.pyplot(fig1)
                
                st.subheader("일반 푸리에 변환 디노이징")
                fig3, ax3 = plt.subplots(figsize=(6, 6))
                ax3.imshow(fft_result, cmap='gray')
                ax3.set_title(f"FFT Denoising (MSE: {fft_mse:.6f})")
                ax3.axis('off')
                st.pyplot(fig3)
            
            with col2:
                st.subheader("노이즈 추가된 이미지")
                fig2, ax2 = plt.subplots(figsize=(6, 6))
                ax2.imshow(noisy_image, cmap='gray')
                ax2.set_title(f"Noisy Image (SNR: {snr:.2f} dB)")
                ax2.axis('off')
                st.pyplot(fig2)
                
                st.subheader("FRFT 최적 노이즈 제거")
                fig4, ax4 = plt.subplots(figsize=(6, 6))
                ax4.imshow(frft_result, cmap='gray')
                ax4.set_title(f"FRFT Denoising (MSE: {frft_mse:.6f})")
                ax4.axis('off')
                st.pyplot(fig4)
            
            # 결과 요약
            st.markdown("---")
            st.header("🎯 결과 분석")
            
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("FFT MSE", f"{fft_mse:.6f}")
            with col2:
                st.metric("FRFT MSE", f"{frft_mse:.6f}")
            with col3:
                st.metric("🚀 개선율", f"{improvement:.2f}배")
            
            st.info(f"🎯 최적 FRFT 파라미터: ax = {optimal_ax:.2f}, ay = {optimal_ay:.2f}")
            
            # 성능 평가
            if improvement > 3:
                st.success(f"🎉 FRFT 방법이 FFT보다 {improvement:.1f}배 우수한 성능을 보입니다!")
                st.balloons()
            elif improvement > 1.5:
                st.success(f"✅ FRFT 방법이 FFT보다 {improvement:.1f}배 개선된 결과를 보입니다.")
            else:
                st.warning(f"⚠️ 이 경우 개선율이 {improvement:.1f}배로 미미합니다. 노이즈 강도를 조정해보세요.")
    
    else:
        st.info("👈 좌측 사이드바에서 이미지를 생성하거나 업로드하세요.")
        
        # 사용 안내
        st.markdown("""
        ### 📚 사용 방법
        1. **이미지 선택**: 논문과 유사한 비행기 이미지 생성 또는 직접 업로드
        2. **노이즈 강도 조정**: 0.5 (약함) ~ 2.0 (강함)
        3. **실행**: "개선된 노이즈 제거 실행" 버튼 클릭
        
        ### 🔬 구현 특징
        - **정확한 FRFT**: PyTorch 기반 MPS 가속화
        - **구조화된 노이즈**: 논문과 유사한 줄무늬 + 가우시안 노이즈
        - **적응적 필터링**: Wiener 필터 + 통계적 임계값
        - **최적화 탐색**: 다양한 분수 차수 조합 테스트
        """)

if __name__ == "__main__":
    main()