# image2sketch / app.py
import streamlit as st
from fastai.vision import open_image, load_learner
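# Note: open_image and load_learner are fastai v1 APIs; this app assumes the fastai 1.x series.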
from PIL import Image
import requests
import os
import logging
import torch
import torch.nn as nn
from io import BytesIO
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
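# The exported learner references this custom perceptual loss, so the class must be defined
# in this module for load_learner to unpickle the model file.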
class FeatureLoss(nn.Module):
def __init__(self, m_feat, layer_ids, layer_wgts):
super().__init__()
self.m_feat = m_feat
self.loss_features = [self.m_feat[i] for i in layer_ids]
self.hooks = [module.register_forward_hook(self.hook_fn) for module in self.loss_features]
        self.wgts = layer_wgts
        self.metric_names = ['pixel'] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'gram_{i}' for i in range(len(layer_ids))]
        self.stored = []

    def hook_fn(self, module, input, output):
        # Every hooked layer appends its activation; forward() clears the list before each pass.
        self.stored.append(output.detach().clone())

    def forward(self, input, target):
        # Perceptual loss: pixel MSE plus one feature term and one Gram-matrix term per hooked layer.
        self.stored = []
        self.m_feat(target)
        out_feat = [o.clone() for o in self.stored]
        self.stored = []
        self.m_feat(input)
        in_feat = self.stored
        self.feat_losses = [torch.nn.functional.mse_loss(input, target)]
        self.feat_losses += [torch.nn.functional.mse_loss(f_in, f_out) * w for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses += [torch.nn.functional.mse_loss(self.gram_matrix(f_in), self.gram_matrix(f_out)) * w**2 * 5e3 for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)
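    # The Gram matrix captures channel-wise correlations of the feature maps (style/texture
    # statistics), normalised by the number of elements per feature map.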
@staticmethod
def gram_matrix(input):
b, c, h, w = input.size()
features = input.view(b, c, h * w)
G = torch.bmm(features, features.transpose(1, 2))
return G.div(c * h * w)
def fetch_image(image_path_or_url):
    """Load an image from a URL, a local path, or an uploaded file object and return it as RGB."""
    if isinstance(image_path_or_url, str) and image_path_or_url.startswith(('http://', 'https://')):
        response = requests.get(image_path_or_url, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        # Also handles Streamlit UploadedFile objects, which PIL can open directly.
        img = Image.open(image_path_or_url).convert("RGB")
    return img
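# Pad the photo with a 250 px white margin on every side, run the exported learner on the
# padded image, and return the prediction as a PIL image.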
def inference(image_path_or_url, learn):
img = fetch_image(image_path_or_url)
img_with_margin = Image.new('RGB', (img.width + 500, img.height + 500), (255, 255, 255))
img_with_margin.paste(img, (250, 250))
temp_image_path = "temp_image.jpg"
img_with_margin.save(temp_image_path, quality=95)
img_fastai = open_image(temp_image_path)
_, img_hr, _ = learn.predict(img_fastai)
return tensor_to_pil(img_hr)
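# Convert a CHW tensor with values in [0, 1] into an 8-bit RGB PIL image.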
def tensor_to_pil(tensor):
tensor = tensor.cpu().clamp(0, 1)
array = tensor.numpy().transpose(1, 2, 0)
return Image.fromarray((array * 255).astype('uint8'))
def load_model(model_url, model_file_path):
    """Download the exported fastai learner once, then load it from disk."""
    if not os.path.exists(model_file_path):
        with st.spinner('Downloading model...'):
            response = requests.get(model_url, timeout=300)
            response.raise_for_status()
            with open(model_file_path, 'wb') as f:
                f.write(response.content)
        st.success('Model downloaded successfully!')
    # fastai v1's load_learner takes the directory and the file name separately.
    learn = load_learner(os.path.dirname(model_file_path) or '.', os.path.basename(model_file_path))
    return learn
# Custom CSS helper (currently unused; the combined CSS below is injected directly)
def set_css(style):
st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)
# Combined dark mode styles
combined_css = """
.main, .sidebar .sidebar-content { background-color: #1c1c1c; color: #f0f2f6; }
.block-container { padding: 1rem 2rem; background-color: #333; border-radius: 10px; box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.5); }
.stButton>button, .stDownloadButton>button { background: linear-gradient(135deg, #ff7e5f, #feb47b); color: white; border: none; padding: 10px 24px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; margin: 4px 2px; cursor: pointer; border-radius: 5px; }
.stSpinner { color: #4CAF50; }
.title {
font-size: 3rem;
font-weight: bold;
display: flex;
align-items: center;
justify-content: center;
}
.colorful-text {
background: -webkit-linear-gradient(135deg, #ff7e5f, #feb47b);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.black-text {
color: black;
}
.white-text {
color: white;
}
.small-input .stTextInput>div>input {
height: 2rem;
font-size: 0.9rem;
}
.small-file-uploader .stFileUploader>div>div {
height: 2rem;
font-size: 0.9rem;
}
.custom-text {
font-size: 1.2rem;
color: #feb47b;
text-align: center;
margin-top: -20px;
margin-bottom: 20px;
}
"""
# Streamlit application
st.set_page_config(layout="wide")
st.markdown(f"<style>{combined_css}</style>", unsafe_allow_html=True)
st.markdown('<div class="title"><span class="colorful-text">Image</span> <span class="black-text">to</span><span class="white-text">Drawing</span>', unsafe_allow_html=True)
st.markdown('<div class="custom-text">Jana\'s embroidery studio. Convert Photo\'s to Drawings using AI</div>', unsafe_allow_html=True)
# Download and load the model
MODEL_URL = "https://huggingface.co/Hammad712/image2sketch/resolve/main/image2sketch.pkl"
MODEL_FILE_PATH = 'image2sketch.pkl'
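# Cache the learner in session_state so the model is downloaded and loaded only once per session.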
if 'learn' not in st.session_state:
st.session_state['learn'] = load_model(MODEL_URL, MODEL_FILE_PATH)
learn = st.session_state['learn']
# Input for image URL or path
with st.expander("Input Options", expanded=True):
image_path_or_url = st.text_input("Enter image URL", "", key="image_url", placeholder="Enter image URL", help="Enter the URL of the image to convert")
uploaded_file = st.file_uploader("Or upload an image", type=["jpg", "jpeg", "png", "webp"], key="upload_file", help="Upload an image file to convert")
if uploaded_file is not None:
image_path_or_url = uploaded_file
# Run inference button
if st.button("Convert"):
if image_path_or_url:
with st.spinner('Processing...'):
try:
high_res_image = inference(image_path_or_url, learn)
original_image = fetch_image(image_path_or_url)
# Display original and high-res images side by side
st.markdown("### Result")
col1, col2 = st.columns(2)
with col1:
st.image(original_image, caption='Original Image', use_column_width=True)
with col2:
st.image(high_res_image, caption='Drawing', use_column_width=True)
# Provide a download button for the generated image
img_byte_arr = BytesIO()
high_res_image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
st.download_button(
label="Download Drawing",
data=img_byte_arr,
file_name="Drawing.jpg",
mime="image/jpeg"
)
st.success("Image processed successfully!")
except Exception as e:
st.error(f"An error occurred: {e}")
logging.error("Error during inference", exc_info=True)
else:
st.error("Please enter a valid image path or URL.")