import streamlit as st from fastai.vision import open_image, load_learner from PIL import Image import requests import os import logging import torch import torch.nn as nn from io import BytesIO # Setup logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') class FeatureLoss(nn.Module): def __init__(self, m_feat, layer_ids, layer_wgts): super().__init__() self.m_feat = m_feat self.loss_features = [self.m_feat[i] for i in layer_ids] self.hooks = [module.register_forward_hook(self.hook_fn) for module in self.loss_features] self.wgts = layer_wgts self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'gram_{i}' for i in range(len(layer_ids))] def hook_fn(self, module, input, output): self.stored = output.detach().clone() def forward(self, input, target): self.m_feat(target) out_feat = [self.stored.clone()] self.m_feat(input) in_feat = [self.stored] self.feat_losses = [torch.nn.functional.mse_loss(input, target)] self.feat_losses += [torch.nn.functional.mse_loss(f_in, f_out) * w for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)] self.feat_losses += [torch.nn.functional.mse_loss(self.gram_matrix(f_in), self.gram_matrix(f_out)) * w**2 * 5e3 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)] self.metrics = dict(zip(self.metric_names, self.feat_losses)) return sum(self.feat_losses) @staticmethod def gram_matrix(input): b, c, h, w = input.size() features = input.view(b, c, h * w) G = torch.bmm(features, features.transpose(1, 2)) return G.div(c * h * w) def fetch_image(image_path_or_url): if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')): response = requests.get(image_path_or_url) img = Image.open(BytesIO(response.content)).convert("RGB") else: img = Image.open(image_path_or_url).convert("RGB") return img def inference(image_path_or_url, learn): img = fetch_image(image_path_or_url) img_with_margin = Image.new('RGB', (img.width + 500, img.height + 500), (255, 255, 255)) img_with_margin.paste(img, (250, 250)) temp_image_path = "temp_image.jpg" img_with_margin.save(temp_image_path, quality=95) img_fastai = open_image(temp_image_path) _, img_hr, _ = learn.predict(img_fastai) return tensor_to_pil(img_hr) def tensor_to_pil(tensor): tensor = tensor.cpu().clamp(0, 1) array = tensor.numpy().transpose(1, 2, 0) return Image.fromarray((array * 255).astype('uint8')) def load_model(model_url, model_file_path): if not os.path.exists(model_file_path): with st.spinner('Downloading model...'): response = requests.get(model_url) with open(model_file_path, 'wb') as f: f.write(response.content) st.success('Model downloaded successfully!') learn = load_learner(os.path.dirname(model_file_path), model_file_path) return learn # Custom CSS def set_css(style): st.markdown(f"", unsafe_allow_html=True) # Combined dark mode styles combined_css = """ .main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; } .block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); } .stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; } .stSpinner { color: #4CAF50; } .title { font-size: 3rem; font-weight: bold; display: flex; align-items: center; justify-content: center; } .colorful-text { background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .black-text { color: black; } .white-text { color: white; } .small-input .stTextInput>div>input { height: 2rem; font-size: 0.9rem; } .small-file-uploader .stFileUploader>div>div { height: 2rem; font-size: 0.9rem; } .custom-text { font-size: 1.2rem; color: #feb47b; text-align: center; margin-top: -20px; margin-bottom: 20px; } """ # Streamlit application st.set_page_config(layout="wide") st.markdown(f"", unsafe_allow_html=True) st.markdown('