Spaces:
Sleeping
Sleeping
import logging
import os
import tempfile
from io import BytesIO

import requests
import streamlit as st
import torch
import torch.nn as nn
from fastai.vision import load_learner, open_image
from PIL import Image
# Setup logging: the root logger emits INFO and above, each record prefixed
# with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class FeatureLoss(nn.Module):
    """Perceptual loss over selected layers of a feature extractor.

    The returned total is the sum of
      * the pixel-wise MSE between ``input`` and ``target``;
      * for each hooked layer ``i``: ``MSE(feat_in_i, feat_target_i) * w_i``;
      * for each hooked layer ``i``:
        ``MSE(gram(feat_in_i), gram(feat_target_i)) * w_i**2 * 5e3``.

    After every call the individual terms are available in ``self.feat_losses``
    and, keyed by ``self.metric_names``, in ``self.metrics``.

    NOTE(review): this script never calls the loss itself; the class is
    presumably defined here so ``load_learner`` can unpickle the exported
    learner that references it by name. Keep the name and module stable.

    Args:
        m_feat: Indexable feature-extractor module (e.g. an ``nn.Sequential``);
            ``m_feat[i]`` must be a sub-module for every ``i`` in ``layer_ids``.
        layer_ids: Indices into ``m_feat`` whose outputs are compared.
        layer_wgts: One weight per entry of ``layer_ids``.
    """

    def __init__(self, m_feat, layer_ids, layer_wgts):
        super().__init__()
        self.m_feat = m_feat
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # One slot per hooked layer. A single shared attribute would be
        # overwritten by every hook, keeping only the last layer's output.
        self.stored = [None] * len(self.loss_features)
        self.hooks = [module.register_forward_hook(self.hook_fn) for module in self.loss_features]
        self.wgts = layer_wgts
        self.metric_names = (
            ['pixel']
            + [f'feat_{i}' for i in range(len(layer_ids))]
            + [f'gram_{i}' for i in range(len(layer_ids))]
        )

    def hook_fn(self, module, input, output):
        """Forward hook: record ``output`` in the slot(s) belonging to ``module``."""
        # Deliberately not detached: the feature and Gram terms must
        # back-propagate to the input image.
        for idx, feature_module in enumerate(self.loss_features):
            if feature_module is module:
                self.stored[idx] = output

    def _make_features(self, x, clone=False):
        """Run ``x`` through ``m_feat`` and return the hooked layer outputs in order."""
        self.m_feat(x)
        if clone:
            # Target features are constants for the loss; copy them so the next
            # forward pass (and any in-place op inside m_feat) cannot alter them.
            return [feat.detach().clone() for feat in self.stored]
        return list(self.stored)

    def forward(self, input, target):
        """Return the summed pixel, feature and Gram losses of ``input`` against ``target``."""
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        mse = torch.nn.functional.mse_loss
        self.feat_losses = [mse(input, target)]
        self.feat_losses += [
            mse(f_in, f_out) * w
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]
        self.feat_losses += [
            mse(self.gram_matrix(f_in), self.gram_matrix(f_out)) * w ** 2 * 5e3
            for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)
        ]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    @staticmethod
    def gram_matrix(input):
        """Return the per-sample Gram matrix of a 4-D ``(b, c, h, w)`` tensor.

        The result has shape ``(b, c, c)`` and is normalised by ``c * h * w``.
        It is a staticmethod so that ``self.gram_matrix(x)`` works; without it
        the call would bind ``self`` to ``input``.
        """
        b, c, h, w = input.size()
        features = input.reshape(b, c, h * w)
        gram = torch.bmm(features, features.transpose(1, 2))
        return gram.div(c * h * w)
def fetch_image(image_path_or_url, timeout=30):
    """Load an image as RGB from an http(s) URL, a local path, or a file object.

    Args:
        image_path_or_url: A string starting with ``http://`` or ``https://``
            (surrounding whitespace is ignored for this check), any other path
            string, or a binary file-like object such as a Streamlit upload.
        timeout: Seconds to wait for the remote server; used for URLs only.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        requests.HTTPError: if the URL answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    if isinstance(image_path_or_url, str):
        url = image_path_or_url.strip()
        if url.startswith(('http://', 'https://')):
            # Without a timeout a stalled server would hang the request forever.
            response = requests.get(url, timeout=timeout)
            # Report the HTTP status instead of handing an HTML error page to
            # PIL, which would fail with an unhelpful "cannot identify image".
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")
    # NOTE(review): any other string is opened as a server-local path; callers
    # handling untrusted input should reject non-URL strings first.
    return Image.open(image_path_or_url).convert("RGB")
def inference(image_path_or_url, learn, margin=250):
    """Convert a photo into a drawing with the fastai learner ``learn``.

    The photo is pasted onto a white canvas that is ``margin`` pixels larger
    on every side, round-tripped through a JPEG (quality 95) and passed to
    ``learn.predict``. The returned drawing therefore includes that border.

    Args:
        image_path_or_url: Anything ``fetch_image`` accepts (URL, path or file object).
        learn: Loaded fastai learner whose ``predict`` returns a 3-tuple; the
            second element is presumably the output image tensor (TODO confirm).
        margin: White border, in pixels, added on each side before prediction.

    Returns:
        The drawing as a ``PIL.Image.Image``.

    Raises:
        ValueError: if ``margin`` is negative.
    """
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")
    img = fetch_image(image_path_or_url)
    padded = Image.new('RGB', (img.width + 2 * margin, img.height + 2 * margin), (255, 255, 255))
    padded.paste(img, (margin, margin))
    # Use a private temporary directory rather than a fixed "temp_image.jpg"
    # in the working directory. Concurrent sessions would otherwise overwrite
    # each other's input, and the file was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_image_path = os.path.join(tmp_dir, "temp_image.jpg")
        padded.save(temp_image_path, quality=95)
        img_fastai = open_image(temp_image_path)
    _, img_hr, _ = learn.predict(img_fastai)
    return tensor_to_pil(img_hr)
def tensor_to_pil(tensor):
    """Convert an image tensor with values in [0, 1] into an 8-bit PIL image.

    Args:
        tensor: A ``(C, H, W)`` tensor (C = 3 for RGB, or 1 for greyscale) or
            an ``(H, W)`` greyscale tensor. Values outside [0, 1] are clamped.
            It may live on any device and may require grad.

    Returns:
        A ``PIL.Image.Image`` built from ``uint8`` pixels.

    Raises:
        ValueError: if the tensor is neither 2-D nor 3-D.
    """
    tensor = tensor.detach().cpu().clamp(0, 1)
    # Round rather than truncate, so e.g. 0.999 maps to 255 instead of 254.
    pixels = tensor.mul(255).round().to(torch.uint8)
    if pixels.dim() == 3:
        pixels = pixels.permute(1, 2, 0)
        if pixels.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array.
            pixels = pixels[:, :, 0]
    elif pixels.dim() != 2:
        raise ValueError(f"expected a 2-D or 3-D image tensor, got shape {tuple(tensor.shape)}")
    return Image.fromarray(pixels.contiguous().numpy())
def load_model(model_url, model_file_path, timeout=60):
    """Return the fastai learner at ``model_file_path``, downloading it first if absent.

    Args:
        model_url: URL of the exported learner; fetched only when
            ``model_file_path`` does not exist yet.
        model_file_path: Local path of the exported learner (``.pkl``).
        timeout: Seconds to wait for the server during the download.

    Returns:
        The learner returned by ``fastai.vision.load_learner``.

    Raises:
        requests.HTTPError: if the download answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    if not os.path.exists(model_file_path):
        with st.spinner('Downloading model...'):
            _download_to_file(model_url, model_file_path, timeout=timeout)
        st.success('Model downloaded successfully!')
    # fastai v1's load_learner(path, file) opens Path(path) / file, so pass the
    # directory and the bare file name. Passing the full path as `file` only
    # worked while the path had no directory component.
    model_dir = os.path.dirname(model_file_path) or "."
    # NOTE(review): load_learner unpickles this file (via torch.load), which can
    # execute arbitrary code; only point model_url at a trusted source.
    learn = load_learner(model_dir, os.path.basename(model_file_path))
    return learn


def _download_to_file(url, dest_path, timeout=60, chunk_size=1 << 20):
    """Stream ``url`` into ``dest_path``; ``dest_path`` is written only on success."""
    dest_dir = os.path.dirname(dest_path) or "."
    os.makedirs(dest_dir, exist_ok=True)
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as the model.
        # Since load_model skips the download whenever the file exists, that
        # broken file would never be re-fetched.
        response.raise_for_status()
        fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as fh:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    fh.write(chunk)
            # Atomic rename: a crash mid-download never leaves a truncated
            # file at dest_path.
            os.replace(tmp_path, dest_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise
# Custom CSS
def set_css(style):
    """Inject ``style`` into the page wrapped in a ``<style>`` element.

    ``unsafe_allow_html=True`` is passed so the markup is rendered as HTML.
    """
    css_markup = "<style>{}</style>".format(style)
    st.markdown(css_markup, unsafe_allow_html=True)
# Combined dark mode styles, injected into the page once by the Streamlit
# application code below.
# Classes used by the header markup: .title (flex-centred heading),
# .colorful-text (gradient-filled text via -webkit-background-clip),
# .black-text / .white-text (solid text colours) and .custom-text (subtitle).
# NOTE(review): .small-input and .small-file-uploader are not referenced by any
# markup visible in this file. Selectors such as .main and
# .sidebar .sidebar-content rely on Streamlit-internal class names, which may
# change between Streamlit versions; confirm they still match.
# NOTE(review): .black-text is black on the dark #1c1c1c .main background, so
# it may be hard to read; confirm this is intended.
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
font-size: 3rem;
font-weight: bold;
display: flex;
align-items: center;
justify-content: center;
}
.colorful-text {
background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.black-text {
color: black;
}
.white-text {
color: white;
}
.small-input .stTextInput>div>input {
height: 2rem;
font-size: 0.9rem;
}
.small-file-uploader .stFileUploader>div>div {
height: 2rem;
font-size: 0.9rem;
}
.custom-text {
font-size: 1.2rem;
color: #feb47b;
text-align: center;
margin-top: -20px;
margin-bottom: 20px;
}
"""
# Streamlit application
st.set_page_config(layout="wide")
set_css(combined_css)
# .title is a flex container, which drops plain whitespace between its child
# spans; &nbsp; separators keep the heading reading "Image to Drawing".
st.markdown(
    '<div class="title"><span class="colorful-text">Image</span>&nbsp;'
    '<span class="black-text">to</span>&nbsp;'
    '<span class="white-text">Drawing</span></div>',
    unsafe_allow_html=True,
)
st.markdown(
    '<div class="custom-text">Jana\'s embroidery studio. Convert Photos to Drawings using AI</div>',
    unsafe_allow_html=True,
)

# Download and load the model; keeping it in st.session_state loads it once
# per session instead of on every rerun of this script.
MODEL_URL = "https://huggingface.co/Hammad712/image2sketch/resolve/main/image2sketch.pkl"
MODEL_FILE_PATH = 'image2sketch.pkl'
if 'learn' not in st.session_state:
    st.session_state['learn'] = load_model(MODEL_URL, MODEL_FILE_PATH)
learn = st.session_state['learn']

# Input: an image URL, or an uploaded file (which takes precedence).
with st.expander("Input Options", expanded=True):
    image_path_or_url = st.text_input("Enter image URL", "", key="image_url", placeholder="Enter image URL", help="Enter the URL of the image to convert")
    uploaded_file = st.file_uploader("Or upload an image", type=["jpg", "jpeg", "png", "webp"], key="upload_file", help="Upload an image file to convert")
if uploaded_file is not None:
    image_path_or_url = uploaded_file
elif image_path_or_url:
    image_path_or_url = image_path_or_url.strip()

# Run inference button
if st.button("Convert"):
    if not image_path_or_url:
        st.error("Please enter an image URL or upload an image.")
    elif isinstance(image_path_or_url, str) and not image_path_or_url.startswith(('http://', 'https://')):
        # The text box is untrusted input. fetch_image opens any non-URL string
        # as a path on the server's own filesystem, so accept URLs only.
        # NOTE(review): URLs are fetched server-side; consider also rejecting
        # loopback/private hosts.
        st.error("Please enter a valid image URL starting with http:// or https://.")
    else:
        with st.spinner('Processing...'):
            try:
                original_image = fetch_image(image_path_or_url)
                # Hand inference a lossless in-memory copy of the image already
                # fetched. A URL is then downloaded only once, and the drawing
                # is made from exactly the picture shown as "Original Image".
                original_buffer = BytesIO()
                original_image.save(original_buffer, format='PNG', compress_level=1)
                original_buffer.seek(0)
                high_res_image = inference(original_buffer, learn)

                # Display original and drawing side by side
                st.markdown("### Result")
                col1, col2 = st.columns(2)
                with col1:
                    st.image(original_image, caption='Original Image', use_column_width=True)
                with col2:
                    st.image(high_res_image, caption='Drawing', use_column_width=True)

                # Provide a download button for the generated image
                drawing_buffer = BytesIO()
                high_res_image.save(drawing_buffer, format='JPEG')
                st.download_button(
                    label="Download Drawing",
                    data=drawing_buffer.getvalue(),
                    file_name="Drawing.jpg",
                    mime="image/jpeg",
                )
                st.success("Image processed successfully!")
            except Exception as e:
                # Top-level UI boundary: show the message to the user and keep
                # the full traceback in the log.
                st.error(f"An error occurred: {e}")
                logging.error("Error during inference", exc_info=True)