ishworrsubedii
commited on
Commit
•
477d077
1
Parent(s):
fdb594c
initial commit
Browse files- .github/workflows/workflow.yml +39 -0
- .gitignore +144 -0
- Dockerfile +25 -0
- app.py +70 -0
- requirements.txt +83 -0
- setup.py +20 -0
- src/__init__.py +5 -0
- src/api/__init__.py +5 -0
- src/api/batch_api.py +752 -0
- src/api/image_prep_api.py +367 -0
- src/api/image_regeneration_api.py +172 -0
- src/api/makeup_tryon_api.py +171 -0
- src/api/mannequin_to_model_api.py +135 -0
- src/api/nto_api.py +911 -0
- src/components/__init__.py +0 -0
- src/components/auto_crop.py +61 -0
- src/components/color_extraction.py +38 -0
- src/components/makeup_try_on.py +104 -0
- src/components/necklaceTryOn.py +471 -0
- src/components/title_des_gen.py +46 -0
- src/pipelines/__init__.py +0 -0
- src/pipelines/completePipeline.py +64 -0
- src/utils/__init__.py +98 -0
- src/utils/backgroundEnhancerArchitecture.py +454 -0
- src/utils/exceptions.py +16 -0
- src/utils/logger.py +22 -0
.github/workflows/workflow.yml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Publish Docker image
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
|
7 |
+
jobs:
|
8 |
+
push_to_registry:
|
9 |
+
name: Push Docker image to Docker Hub
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
permissions:
|
12 |
+
packages: write
|
13 |
+
contents: read
|
14 |
+
attestations: write
|
15 |
+
steps:
|
16 |
+
- name: Check out the repo
|
17 |
+
uses: actions/checkout@v4
|
18 |
+
|
19 |
+
- name: Log in to Docker Hub
|
20 |
+
uses: docker/login-action@f4ef78c080cd8ba55a85445d5b36e214a81df20a
|
21 |
+
with:
|
22 |
+
username: ${{ secrets.DOCKER_USERNAME }}
|
23 |
+
password: ${{ secrets.DOCKER_PASSWORD }}
|
24 |
+
|
25 |
+
- name: Extract metadata (tags, labels) for Docker
|
26 |
+
id: meta
|
27 |
+
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
|
28 |
+
with:
|
29 |
+
images: techconsp/jecmdvhgyyqtmjvzggkbukxyiphuvbwjdjadfagx
|
30 |
+
|
31 |
+
- name: Build and push Docker image
|
32 |
+
id: push
|
33 |
+
uses: docker/build-push-action@3b5e8027fcad23fda98b2e3ac259d8d67585f671
|
34 |
+
with:
|
35 |
+
context: .
|
36 |
+
file: ./Dockerfile
|
37 |
+
push: true
|
38 |
+
tags: ${{ steps.meta.outputs.tags }}
|
39 |
+
labels: ${{ steps.meta.outputs.labels }}
|
.gitignore
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# .gitignore
|
2 |
+
|
3 |
+
# Byte-compiled / optimized / DLL files
|
4 |
+
__pycache__/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
7 |
+
|
8 |
+
# C extensions
|
9 |
+
*.so
|
10 |
+
|
11 |
+
# Distribution / packaging
|
12 |
+
.Python
|
13 |
+
build/
|
14 |
+
develop-eggs/
|
15 |
+
dist/
|
16 |
+
downloads/
|
17 |
+
eggs/
|
18 |
+
.eggs/
|
19 |
+
lib/
|
20 |
+
lib64/
|
21 |
+
parts/
|
22 |
+
sdist/
|
23 |
+
var/
|
24 |
+
wheels/
|
25 |
+
share/python-wheels/
|
26 |
+
*.egg-info/
|
27 |
+
.installed.cfg
|
28 |
+
*.egg
|
29 |
+
MANIFEST
|
30 |
+
|
31 |
+
# PyInstaller
|
32 |
+
# Usually these files are written by a python script from a template
|
33 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
34 |
+
*.manifest
|
35 |
+
*.spec
|
36 |
+
|
37 |
+
# Installer logs
|
38 |
+
pip-log.txt
|
39 |
+
pip-delete-this-directory.txt
|
40 |
+
|
41 |
+
# Unit test / coverage reports
|
42 |
+
htmlcov/
|
43 |
+
.tox/
|
44 |
+
.nox/
|
45 |
+
.coverage
|
46 |
+
.coverage.*
|
47 |
+
.cache
|
48 |
+
nosetests.xml
|
49 |
+
coverage.xml
|
50 |
+
*.cover
|
51 |
+
*.py,cover
|
52 |
+
.hypothesis/
|
53 |
+
.pytest_cache/
|
54 |
+
cover/
|
55 |
+
|
56 |
+
# Translations
|
57 |
+
*.mo
|
58 |
+
*.pot
|
59 |
+
|
60 |
+
# Django stuff:
|
61 |
+
*.log
|
62 |
+
local_settings.py
|
63 |
+
db.sqlite3
|
64 |
+
db.sqlite3-journal
|
65 |
+
|
66 |
+
# Flask stuff:
|
67 |
+
instance/
|
68 |
+
.webassets-cache
|
69 |
+
|
70 |
+
# Scrapy stuff:
|
71 |
+
.scrapy
|
72 |
+
|
73 |
+
# Sphinx documentation
|
74 |
+
docs/_build/
|
75 |
+
|
76 |
+
# PyBuilder
|
77 |
+
target/
|
78 |
+
|
79 |
+
# Jupyter Notebook
|
80 |
+
.ipynb_checkpoints
|
81 |
+
|
82 |
+
# IPython
|
83 |
+
profile_default/
|
84 |
+
ipython_config.py
|
85 |
+
|
86 |
+
# pyenv
|
87 |
+
.python-version
|
88 |
+
|
89 |
+
# pipenv
|
90 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
91 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
92 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
93 |
+
# install all needed dependencies.
|
94 |
+
# Pipfile.lock
|
95 |
+
|
96 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
97 |
+
__pypackages__/
|
98 |
+
|
99 |
+
# Celery stuff
|
100 |
+
celerybeat-schedule
|
101 |
+
celerybeat.pid
|
102 |
+
|
103 |
+
# SageMath parsed files
|
104 |
+
*.sage.py
|
105 |
+
|
106 |
+
# Environments
|
107 |
+
.env
|
108 |
+
.venv
|
109 |
+
env/
|
110 |
+
venv/
|
111 |
+
ENV/
|
112 |
+
env.bak/
|
113 |
+
venv.bak/
|
114 |
+
|
115 |
+
# Spyder project settings
|
116 |
+
.spyderproject
|
117 |
+
.spyderworkspace
|
118 |
+
|
119 |
+
# Rope project settings
|
120 |
+
.ropeproject
|
121 |
+
|
122 |
+
# mkdocs documentation
|
123 |
+
/site
|
124 |
+
|
125 |
+
# mypy
|
126 |
+
.mypy_cache/
|
127 |
+
.dmypy.json
|
128 |
+
dmypy.json
|
129 |
+
|
130 |
+
# Pyre type checker
|
131 |
+
.pyre/
|
132 |
+
|
133 |
+
# pytype static type analyzer
|
134 |
+
.pytype/
|
135 |
+
|
136 |
+
# Cython debug symbols
|
137 |
+
cython_debug/
|
138 |
+
|
139 |
+
# PyCharm
|
140 |
+
.idea/
|
141 |
+
|
142 |
+
# Local logs
|
143 |
+
*.log
|
144 |
+
examples
|
Dockerfile
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10-slim
|
2 |
+
|
3 |
+
WORKDIR /api
|
4 |
+
|
5 |
+
COPY . /api
|
6 |
+
|
7 |
+
RUN apt-get update && apt-get install -y \
|
8 |
+
libgl1-mesa-glx \
|
9 |
+
ffmpeg \
|
10 |
+
libsm6 \
|
11 |
+
libxext6
|
12 |
+
|
13 |
+
RUN pip install -r requirements.txt
|
14 |
+
|
15 |
+
RUN mkdir -p /.cache /root/.cache /tmp/.cache /api/.cache && \
|
16 |
+
chmod -R 777 /.cache /root/.cache /tmp/.cache /api/.cache
|
17 |
+
|
18 |
+
ENV MPLCONFIGDIR=/tmp/.cache/matplotlib
|
19 |
+
ENV HUGGINGFACE_HUB_CACHE=/tmp/.cache/huggingface
|
20 |
+
|
21 |
+
RUN chmod -R 777 /api /tmp /root /.cache
|
22 |
+
|
23 |
+
EXPOSE 7860
|
24 |
+
|
25 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import time
|
7 |
+
from datetime import datetime, timedelta, timezone
|
8 |
+
|
9 |
+
from fastapi import FastAPI, Depends, HTTPException
|
10 |
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
11 |
+
from src.api.batch_api import batch_router
|
12 |
+
from src.api.image_prep_api import preprocessing_router
|
13 |
+
from src.api.image_regeneration_api import image_regeneration_router
|
14 |
+
from src.api.makeup_tryon_api import makeup_tryon_router
|
15 |
+
from src.api.mannequin_to_model_api import mto_router
|
16 |
+
from src.api.nto_api import nto_cto_router
|
17 |
+
from src.api.nto_api import supabase
|
18 |
+
from starlette.middleware.cors import CORSMiddleware
|
19 |
+
from starlette.responses import JSONResponse
|
20 |
+
|
21 |
+
security = HTTPBearer()
|
22 |
+
|
23 |
+
|
24 |
+
async def verify_login_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
25 |
+
try:
|
26 |
+
response = supabase.table("JewelMirrorUserLogins").select("*").eq("LoginToken",
|
27 |
+
credentials.credentials).execute()
|
28 |
+
|
29 |
+
if not response.data:
|
30 |
+
raise HTTPException(status_code=401, detail="Unauthorized: Token not found")
|
31 |
+
|
32 |
+
token_data = response.data[0]
|
33 |
+
|
34 |
+
created_at = datetime.fromisoformat(token_data["UpdatedAt"].replace("Z", "+00:00"))
|
35 |
+
current_time = datetime.now(timezone.utc)
|
36 |
+
time_difference = current_time - created_at
|
37 |
+
|
38 |
+
if time_difference <= timedelta(minutes=30):
|
39 |
+
return JSONResponse({"status": "Authorized", "message": "Token is valid"}, status_code=200)
|
40 |
+
|
41 |
+
raise HTTPException(status_code=401, detail="Unauthorized: Token expired")
|
42 |
+
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Token verification error: {e}")
|
45 |
+
raise HTTPException(status_code=401, detail="Invalid token")
|
46 |
+
|
47 |
+
|
48 |
+
app = FastAPI(title="NTO-TCP-HF", version="1.0.0")
|
49 |
+
app.include_router(nto_cto_router, tags=["NTO-CTO"])
|
50 |
+
app.include_router(preprocessing_router, tags=["Image-Preprocessing"])
|
51 |
+
app.include_router(image_regeneration_router, tags=["Image-Regeneration"])
|
52 |
+
time.sleep(1)
|
53 |
+
app.include_router(batch_router, tags=['Realtime'])
|
54 |
+
|
55 |
+
app.include_router(mto_router, tags=["MTO"])
|
56 |
+
|
57 |
+
app.include_router(makeup_tryon_router, tags=["makeup_tryon"])
|
58 |
+
|
59 |
+
app.add_middleware(
|
60 |
+
CORSMiddleware,
|
61 |
+
allow_origins=["*"],
|
62 |
+
allow_credentials=True,
|
63 |
+
allow_methods=["*"],
|
64 |
+
allow_headers=["*"],
|
65 |
+
)
|
66 |
+
|
67 |
+
if __name__ == '__main__':
|
68 |
+
import uvicorn
|
69 |
+
|
70 |
+
uvicorn.run(app)
|
requirements.txt
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
annotated-types==0.7.0
|
3 |
+
anyio==4.4.0
|
4 |
+
attrs==23.2.0
|
5 |
+
certifi==2024.7.4
|
6 |
+
cffi==1.16.0
|
7 |
+
charset-normalizer==3.3.2
|
8 |
+
click==8.1.7
|
9 |
+
contourpy==1.2.1
|
10 |
+
cvzone==1.6.1
|
11 |
+
cycler==0.12.1
|
12 |
+
deprecation==2.1.0
|
13 |
+
dnspython==2.6.1
|
14 |
+
email_validator==2.2.0
|
15 |
+
exceptiongroup==1.2.1
|
16 |
+
fastapi==0.111.0
|
17 |
+
fastapi-cli==0.0.4
|
18 |
+
flatbuffers==24.3.25
|
19 |
+
fonttools==4.53.1
|
20 |
+
gotrue==2.5.4
|
21 |
+
h11==0.14.0
|
22 |
+
httpcore==1.0.5
|
23 |
+
httptools==0.6.1
|
24 |
+
httpx==0.27.0
|
25 |
+
idna==3.7
|
26 |
+
jax==0.4.30
|
27 |
+
jaxlib==0.4.30
|
28 |
+
Jinja2==3.1.4
|
29 |
+
kiwisolver==1.4.5
|
30 |
+
markdown-it-py==3.0.0
|
31 |
+
MarkupSafe==2.1.5
|
32 |
+
matplotlib==3.9.1
|
33 |
+
mdurl==0.1.2
|
34 |
+
mediapipe
|
35 |
+
ml-dtypes==0.4.0
|
36 |
+
numpy==2.0.0
|
37 |
+
opencv-contrib-python==4.10.0.84
|
38 |
+
opencv-python==4.10.0.84
|
39 |
+
opt-einsum==3.3.0
|
40 |
+
orjson==3.10.6
|
41 |
+
packaging==24.1
|
42 |
+
pandas==2.2.2
|
43 |
+
pillow==10.4.0
|
44 |
+
postgrest==0.16.8
|
45 |
+
protobuf==4.25.3
|
46 |
+
pycparser==2.22
|
47 |
+
pydantic==2.8.2
|
48 |
+
pydantic_core==2.20.1
|
49 |
+
Pygments==2.18.0
|
50 |
+
pyparsing==3.1.2
|
51 |
+
python-dateutil==2.9.0.post0
|
52 |
+
python-dotenv==1.0.1
|
53 |
+
python-multipart==0.0.9
|
54 |
+
pytz==2024.1
|
55 |
+
PyYAML==6.0.1
|
56 |
+
realtime==1.0.6
|
57 |
+
requests==2.32.3
|
58 |
+
rich==13.7.1
|
59 |
+
scikit-build==0.18.0
|
60 |
+
scipy==1.14.0
|
61 |
+
shellingham==1.5.4
|
62 |
+
six==1.16.0
|
63 |
+
sniffio==1.3.1
|
64 |
+
sounddevice==0.4.7
|
65 |
+
starlette==0.37.2
|
66 |
+
storage3==0.7.6
|
67 |
+
StrEnum==0.4.15
|
68 |
+
supabase==2.5.1
|
69 |
+
supafunc==0.4.6
|
70 |
+
tomli==2.0.1
|
71 |
+
typer==0.12.3
|
72 |
+
typing_extensions==4.12.2
|
73 |
+
tzdata==2024.1
|
74 |
+
ujson==5.10.0
|
75 |
+
urllib3==2.2.2
|
76 |
+
uvicorn==0.30.1
|
77 |
+
uvloop==0.19.0
|
78 |
+
watchfiles==0.22.0
|
79 |
+
websockets==12.0
|
80 |
+
replicate
|
81 |
+
google-generativeai
|
82 |
+
aiohttp
|
83 |
+
-e .
|
setup.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
HYPER_E_DOT = "-e ."
|
4 |
+
|
5 |
+
|
6 |
+
def getRequirements(requirementsPath: str) -> list[str]:
|
7 |
+
with open(requirementsPath) as file:
|
8 |
+
requirements = file.read().split("\n")
|
9 |
+
requirements.remove(HYPER_E_DOT)
|
10 |
+
return requirements
|
11 |
+
|
12 |
+
|
13 |
+
setup(
|
14 |
+
name="Jewel Mirror",
|
15 |
+
author="Ishwor Subedi",
|
16 |
+
author_email="ishworr.subedi@gmail.com",
|
17 |
+
version="0.1",
|
18 |
+
packages=find_packages(),
|
19 |
+
install_requires=getRequirements(requirementsPath="./requirements.txt")
|
20 |
+
)
|
src/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github/ishworrsubedii
|
5 |
+
"""
|
src/api/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github/ishworrsubedii
|
5 |
+
"""
|
src/api/batch_api.py
ADDED
@@ -0,0 +1,752 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ CTO_TCP_ZERO_GPU
|
3 |
+
created @ 2024-11-14
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
import base64
|
8 |
+
import gc
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
from io import BytesIO
|
12 |
+
from typing import List
|
13 |
+
|
14 |
+
import aiohttp
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
from PIL import Image
|
18 |
+
from fastapi import File, UploadFile, Form
|
19 |
+
from fastapi.responses import JSONResponse
|
20 |
+
from fastapi.responses import StreamingResponse
|
21 |
+
from fastapi.routing import APIRouter
|
22 |
+
from pydantic import BaseModel
|
23 |
+
|
24 |
+
from src.api.makeup_tryon_api import get_colors
|
25 |
+
from src.api.nto_api import pipeline, replicate_run_cto, supabase_upload_and_return_url
|
26 |
+
from src.utils import returnBytesData
|
27 |
+
from src.utils.logger import logger
|
28 |
+
|
29 |
+
batch_router = APIRouter()
|
30 |
+
colors = get_colors()
|
31 |
+
|
32 |
+
|
33 |
+
class ClothingRequest(BaseModel):
|
34 |
+
c_list: List[str]
|
35 |
+
|
36 |
+
|
37 |
+
@batch_router.post("/rt_cto")
|
38 |
+
async def rt_cto(
|
39 |
+
image: UploadFile = File(...),
|
40 |
+
c_list: str = Form(...)
|
41 |
+
):
|
42 |
+
logger.info("-" * 50)
|
43 |
+
logger.info(">>> REAL-TIME CTO STARTED <<<")
|
44 |
+
logger.info(f"Parameters: clothing_list={c_list}")
|
45 |
+
|
46 |
+
setup_start_time = time.time()
|
47 |
+
try:
|
48 |
+
clothing_list = [item.strip() for item in c_list.split(",")]
|
49 |
+
image_bytes = await image.read()
|
50 |
+
pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
51 |
+
setup_time = round(time.time() - setup_start_time, 2)
|
52 |
+
logger.info(f">>> IMAGE LOADED SUCCESSFULLY in {setup_time}s <<<")
|
53 |
+
except Exception as e:
|
54 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
55 |
+
return {"error": "Error reading image", "code": 500}
|
56 |
+
|
57 |
+
async def generate():
|
58 |
+
logger.info("-" * 50)
|
59 |
+
logger.info(">>> CLOTHING TRY ON V2 STARTED <<<")
|
60 |
+
|
61 |
+
# Mask generation timing
|
62 |
+
mask_start_time = time.time()
|
63 |
+
try:
|
64 |
+
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=pil_image)
|
65 |
+
mask_time = round(time.time() - mask_start_time, 2)
|
66 |
+
logger.info(f">>> MASK GENERATION COMPLETED in {mask_time}s <<<")
|
67 |
+
except Exception as e:
|
68 |
+
logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
|
69 |
+
yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
|
70 |
+
await asyncio.sleep(0.1)
|
71 |
+
return
|
72 |
+
|
73 |
+
# Encoding timing
|
74 |
+
encoding_start_time = time.time()
|
75 |
+
try:
|
76 |
+
mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
|
77 |
+
mask.save(mask_img_base_64, format="WEBP")
|
78 |
+
pil_image.save(act_img_base_64, format="WEBP")
|
79 |
+
mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
|
80 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
81 |
+
|
82 |
+
mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
|
83 |
+
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
84 |
+
encoding_time = round(time.time() - encoding_start_time, 2)
|
85 |
+
logger.info(f">>> IMAGE ENCODING COMPLETED in {encoding_time}s <<<")
|
86 |
+
except Exception as e:
|
87 |
+
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
88 |
+
yield json.dumps({"error": "Error converting images to base64", "code": 500}) + "\n"
|
89 |
+
await asyncio.sleep(0.1)
|
90 |
+
return
|
91 |
+
|
92 |
+
for idx, clothing_type in enumerate(clothing_list):
|
93 |
+
if not clothing_type:
|
94 |
+
continue
|
95 |
+
|
96 |
+
iteration_start_time = time.time()
|
97 |
+
try:
|
98 |
+
inference_start_time = time.time()
|
99 |
+
output = replicate_run_cto({
|
100 |
+
"mask": mask_data_uri,
|
101 |
+
"image": image_data_uri,
|
102 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
103 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
104 |
+
"num_inference_steps": 25
|
105 |
+
})
|
106 |
+
inference_time = round(time.time() - inference_start_time, 2)
|
107 |
+
logger.info(f">>> REPLICATE PROCESSING COMPLETED FOR {clothing_type} in {inference_time}s <<<")
|
108 |
+
|
109 |
+
output_url = str(output[0]) if output and output[0] else None
|
110 |
+
iteration_time = round(time.time() - iteration_start_time, 2)
|
111 |
+
|
112 |
+
result = {
|
113 |
+
"code": 200,
|
114 |
+
"output": output_url,
|
115 |
+
"timing": {
|
116 |
+
"setup": setup_time,
|
117 |
+
"mask_generation": mask_time,
|
118 |
+
"encoding": encoding_time,
|
119 |
+
"inference": inference_time,
|
120 |
+
"iteration": iteration_time
|
121 |
+
},
|
122 |
+
"clothing_type": clothing_type,
|
123 |
+
"progress": f"{idx + 1}/{len(clothing_list)}"
|
124 |
+
}
|
125 |
+
yield json.dumps(result) + "\n"
|
126 |
+
await asyncio.sleep(0.1)
|
127 |
+
|
128 |
+
except Exception as e:
|
129 |
+
logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
|
130 |
+
error_result = {
|
131 |
+
"error": str(e),
|
132 |
+
"details": str(e),
|
133 |
+
"code": 500,
|
134 |
+
"clothing_type": clothing_type,
|
135 |
+
"progress": f"{idx + 1}/{len(clothing_list)}"
|
136 |
+
}
|
137 |
+
yield json.dumps(error_result) + "\n"
|
138 |
+
await asyncio.sleep(0.1)
|
139 |
+
|
140 |
+
return StreamingResponse(
|
141 |
+
generate(),
|
142 |
+
media_type="application/x-ndjson",
|
143 |
+
headers={
|
144 |
+
"Cache-Control": "no-cache",
|
145 |
+
"Connection": "keep-alive",
|
146 |
+
"X-Accel-Buffering": "no",
|
147 |
+
"Transfer-Encoding": "chunked"
|
148 |
+
}
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
@batch_router.post("/rt_nto")
|
153 |
+
async def rt_nto(
|
154 |
+
image: UploadFile = File(...),
|
155 |
+
necklace_id_list: str = Form(...),
|
156 |
+
category_list: str = Form(...),
|
157 |
+
storename_list: str = Form(...),
|
158 |
+
offset_x_list: str = Form(...),
|
159 |
+
offset_y_list: str = Form(...)
|
160 |
+
):
|
161 |
+
logger.info("-" * 50)
|
162 |
+
logger.info(">>> REAL-TIME NECKLACE TRY ON STARTED <<<")
|
163 |
+
logger.info(f"Parameters: storename={storename_list}, categories={category_list}, necklace_ids={necklace_id_list}")
|
164 |
+
|
165 |
+
try:
|
166 |
+
# Parse all lists
|
167 |
+
necklace_ids = [id.strip() for id in necklace_id_list.split(",")]
|
168 |
+
categories = [cat.strip() for cat in category_list.split(",")]
|
169 |
+
stores = [store.strip() for store in storename_list.split(",")]
|
170 |
+
offset_x_values = [float(x.strip()) for x in offset_x_list.split(",")]
|
171 |
+
offset_y_values = [float(y.strip()) for y in offset_y_list.split(",")]
|
172 |
+
|
173 |
+
# Validate list lengths
|
174 |
+
if len(necklace_ids) != len(categories) or \
|
175 |
+
len(necklace_ids) != len(stores) or \
|
176 |
+
len(necklace_ids) != len(offset_x_values) or \
|
177 |
+
len(necklace_ids) != len(offset_y_values):
|
178 |
+
return JSONResponse(
|
179 |
+
content={
|
180 |
+
"error": "Number of necklace IDs must match number of categories, stores, and offset values",
|
181 |
+
"code": 400
|
182 |
+
},
|
183 |
+
status_code=400
|
184 |
+
)
|
185 |
+
|
186 |
+
# Load the source image
|
187 |
+
image_bytes = await image.read()
|
188 |
+
source_image = Image.open(BytesIO(image_bytes))
|
189 |
+
logger.info(">>> SOURCE IMAGE LOADED SUCCESSFULLY <<<")
|
190 |
+
except Exception as e:
|
191 |
+
logger.error(f">>> INITIAL SETUP ERROR: {str(e)} <<<")
|
192 |
+
return JSONResponse(
|
193 |
+
content={"error": "Error in initial setup", "details": str(e), "code": 500},
|
194 |
+
status_code=500
|
195 |
+
)
|
196 |
+
|
197 |
+
async def generate():
|
198 |
+
setup_start_time = time.time() # Add setup timing
|
199 |
+
|
200 |
+
# After loading images
|
201 |
+
setup_time = round(time.time() - setup_start_time, 2)
|
202 |
+
logger.info(f">>> SETUP COMPLETED in {setup_time}s <<<")
|
203 |
+
|
204 |
+
for idx, (necklace_id, category, storename, offset_x, offset_y) in enumerate(
|
205 |
+
zip(necklace_ids, categories, stores, offset_x_values, offset_y_values)
|
206 |
+
):
|
207 |
+
iteration_start_time = time.time()
|
208 |
+
try:
|
209 |
+
# Load jewellery timing
|
210 |
+
jewellery_load_start = time.time()
|
211 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{category}/image/{necklace_id}.png"
|
212 |
+
jewellery = Image.open(returnBytesData(url=jewellery_url))
|
213 |
+
jewellery_time = round(time.time() - jewellery_load_start, 2)
|
214 |
+
logger.info(f">>> JEWELLERY LOADED in {jewellery_time}s <<<")
|
215 |
+
|
216 |
+
# NTO timing
|
217 |
+
nto_start_time = time.time()
|
218 |
+
result, headetText, mask = await pipeline.necklaceTryOnDynamicOffset_(
|
219 |
+
image=source_image,
|
220 |
+
jewellery=jewellery,
|
221 |
+
storename=storename,
|
222 |
+
offset=[offset_x, offset_y]
|
223 |
+
)
|
224 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
225 |
+
|
226 |
+
# Upload timing
|
227 |
+
upload_start_time = time.time()
|
228 |
+
upload_tasks = [
|
229 |
+
supabase_upload_and_return_url(prefix="NTO", image=result, necklace_id=necklace_id),
|
230 |
+
supabase_upload_and_return_url(prefix="NTO", image=mask, necklace_id=necklace_id)
|
231 |
+
]
|
232 |
+
result_url, mask_url = await asyncio.gather(*upload_tasks)
|
233 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
234 |
+
|
235 |
+
result = {
|
236 |
+
"code": 200,
|
237 |
+
"output": result_url,
|
238 |
+
"mask": mask_url,
|
239 |
+
"timing": {
|
240 |
+
"setup": setup_time,
|
241 |
+
"jewellery_load": jewellery_time,
|
242 |
+
"nto_inference": nto_time,
|
243 |
+
"upload": upload_time,
|
244 |
+
"total_iteration": round(time.time() - iteration_start_time, 2)
|
245 |
+
},
|
246 |
+
"necklace_id": necklace_id,
|
247 |
+
"category": category,
|
248 |
+
"progress": f"{idx + 1}/{len(necklace_ids)}"
|
249 |
+
}
|
250 |
+
yield json.dumps(result) + "\n"
|
251 |
+
await asyncio.sleep(0.1)
|
252 |
+
|
253 |
+
del result
|
254 |
+
del mask
|
255 |
+
gc.collect()
|
256 |
+
|
257 |
+
except Exception as e:
|
258 |
+
logger.error(f">>> PROCESSING ERROR FOR {necklace_id}: {str(e)} <<<")
|
259 |
+
error_result = {
|
260 |
+
"error": f"Error processing necklace {necklace_id}",
|
261 |
+
"details": str(e),
|
262 |
+
"code": 500,
|
263 |
+
"necklace_id": necklace_id,
|
264 |
+
"category": category,
|
265 |
+
"progress": f"{idx + 1}/{len(necklace_ids)}"
|
266 |
+
}
|
267 |
+
yield json.dumps(error_result) + "\n"
|
268 |
+
await asyncio.sleep(0.1)
|
269 |
+
|
270 |
+
return StreamingResponse(
|
271 |
+
generate(),
|
272 |
+
media_type="application/x-ndjson",
|
273 |
+
headers={
|
274 |
+
"Cache-Control": "no-cache",
|
275 |
+
"Connection": "keep-alive",
|
276 |
+
"X-Accel-Buffering": "no",
|
277 |
+
"Transfer-Encoding": "chunked"
|
278 |
+
}
|
279 |
+
)
|
280 |
+
|
281 |
+
|
282 |
+
@batch_router.post("/rt_cto_nto")
|
283 |
+
async def rt_cto_nto(
|
284 |
+
image: UploadFile = File(...),
|
285 |
+
c_list: str = Form(...),
|
286 |
+
necklace_id_list: str = Form(...),
|
287 |
+
necklace_category_list: str = Form(...),
|
288 |
+
storename_list: str = Form(...),
|
289 |
+
offset_x_list: str = Form(...),
|
290 |
+
offset_y_list: str = Form(...)
|
291 |
+
):
|
292 |
+
logger.info("-" * 50)
|
293 |
+
logger.info(">>> REAL-TIME CTO-NTO STARTED <<<")
|
294 |
+
logger.info(f"Parameters: storenames={storename_list}, necklace_categories={necklace_category_list}, "
|
295 |
+
f"necklace_ids={necklace_id_list}, clothing_list={c_list}")
|
296 |
+
|
297 |
+
try:
|
298 |
+
# Parse all input lists
|
299 |
+
clothing_list = [item.strip() for item in c_list.split(",")]
|
300 |
+
necklace_ids = [id.strip() for id in necklace_id_list.split(",")]
|
301 |
+
categories = [cat.strip() for cat in necklace_category_list.split(",")]
|
302 |
+
stores = [store.strip() for store in storename_list.split(",")]
|
303 |
+
offset_x_values = [float(x.strip()) for x in offset_x_list.split(",")]
|
304 |
+
offset_y_values = [float(y.strip()) for y in offset_y_list.split(",")]
|
305 |
+
|
306 |
+
# Validate list lengths
|
307 |
+
if len(necklace_ids) != len(categories) or \
|
308 |
+
len(necklace_ids) != len(stores) or \
|
309 |
+
len(necklace_ids) != len(offset_x_values) or \
|
310 |
+
len(necklace_ids) != len(offset_y_values):
|
311 |
+
return JSONResponse(
|
312 |
+
content={
|
313 |
+
"error": "Number of necklace IDs must match number of categories, stores, and offset values",
|
314 |
+
"code": 400
|
315 |
+
},
|
316 |
+
status_code=400
|
317 |
+
)
|
318 |
+
|
319 |
+
# Load source image
|
320 |
+
image_bytes = await image.read()
|
321 |
+
source_image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
322 |
+
|
323 |
+
# Load all jewellery images
|
324 |
+
jewellery_data = []
|
325 |
+
for nid, cat, store in zip(necklace_ids, categories, stores):
|
326 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{store}/{cat}/image/{nid}.png"
|
327 |
+
jewellery = Image.open(returnBytesData(url=jewellery_url)).convert("RGBA")
|
328 |
+
jewellery_data.append({
|
329 |
+
'image': jewellery,
|
330 |
+
'id': nid,
|
331 |
+
'category': cat,
|
332 |
+
'store': store
|
333 |
+
})
|
334 |
+
|
335 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
336 |
+
except Exception as e:
|
337 |
+
logger.error(f">>> INITIAL SETUP ERROR: {str(e)} <<<")
|
338 |
+
return JSONResponse(
|
339 |
+
content={"error": "Error in initial setup", "details": str(e), "code": 500},
|
340 |
+
status_code=500
|
341 |
+
)
|
342 |
+
|
343 |
+
async def generate():
|
344 |
+
setup_start_time = time.time()
|
345 |
+
|
346 |
+
try:
|
347 |
+
# Mask generation timing
|
348 |
+
mask_start_time = time.time()
|
349 |
+
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
|
350 |
+
mask_time = round(time.time() - mask_start_time, 2)
|
351 |
+
logger.info(f">>> MASK GENERATION COMPLETED in {mask_time}s <<<")
|
352 |
+
|
353 |
+
# Encoding
|
354 |
+
encoding_start_time = time.time()
|
355 |
+
mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
|
356 |
+
mask.save(mask_img_base_64, format="WEBP")
|
357 |
+
source_image.save(act_img_base_64, format="WEBP")
|
358 |
+
mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
|
359 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
360 |
+
|
361 |
+
mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
|
362 |
+
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
363 |
+
encoding_time = round(time.time() - encoding_start_time, 2)
|
364 |
+
logger.info(f">>> IMAGE ENCODING COMPLETED in {encoding_time}s <<<")
|
365 |
+
|
366 |
+
total_combinations = len(clothing_list) * len(jewellery_data)
|
367 |
+
current_combination = 0
|
368 |
+
|
369 |
+
for clothing_type in clothing_list:
|
370 |
+
# First do CTO
|
371 |
+
cto_start_time = time.time()
|
372 |
+
cto_output = replicate_run_cto({
|
373 |
+
"mask": mask_data_uri,
|
374 |
+
"image": image_data_uri,
|
375 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
376 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
377 |
+
"num_inference_steps": 25
|
378 |
+
})
|
379 |
+
cto_time = round(time.time() - cto_start_time, 2)
|
380 |
+
logger.info(f">>> CTO COMPLETED for {clothing_type} in {cto_time}s <<<")
|
381 |
+
|
382 |
+
# Get CTO result image
|
383 |
+
async with aiohttp.ClientSession() as session:
|
384 |
+
async with session.get(str(cto_output[0])) as response:
|
385 |
+
if response.status != 200:
|
386 |
+
raise ValueError("Failed to fetch CTO output")
|
387 |
+
cto_result_bytes = await response.read()
|
388 |
+
|
389 |
+
cto_result_image = Image.open(BytesIO(cto_result_bytes)).convert("RGB")
|
390 |
+
|
391 |
+
# Now apply each necklace to the CTO result
|
392 |
+
for jewellery_item, offset_x, offset_y in zip(jewellery_data, offset_x_values, offset_y_values):
|
393 |
+
current_combination += 1
|
394 |
+
iteration_start_time = time.time()
|
395 |
+
|
396 |
+
try:
|
397 |
+
# Perform NTO
|
398 |
+
nto_start_time = time.time()
|
399 |
+
result, headerText, nto_mask = await pipeline.necklaceTryOnDynamicOffset_(
|
400 |
+
image=cto_result_image,
|
401 |
+
jewellery=jewellery_item['image'],
|
402 |
+
storename=jewellery_item['store'],
|
403 |
+
offset=[offset_x, offset_y]
|
404 |
+
)
|
405 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
406 |
+
logger.info(f">>> NTO COMPLETED for {jewellery_item['id']} in {nto_time}s <<<")
|
407 |
+
|
408 |
+
# Upload result
|
409 |
+
upload_start_time = time.time()
|
410 |
+
result_url = await supabase_upload_and_return_url(
|
411 |
+
prefix="NTOCTO",
|
412 |
+
image=result,
|
413 |
+
necklace_id=jewellery_item['id']
|
414 |
+
)
|
415 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
416 |
+
|
417 |
+
output_result = {
|
418 |
+
"code": 200,
|
419 |
+
"output": result_url,
|
420 |
+
"timing": {
|
421 |
+
"setup": round(time.time() - setup_start_time, 2),
|
422 |
+
"mask_generation": mask_time,
|
423 |
+
"encoding": encoding_time,
|
424 |
+
"cto_inference": cto_time,
|
425 |
+
"nto_inference": nto_time,
|
426 |
+
"upload": upload_time,
|
427 |
+
"total_iteration": round(time.time() - iteration_start_time, 2)
|
428 |
+
},
|
429 |
+
"clothing_type": clothing_type,
|
430 |
+
"necklace_id": jewellery_item['id'],
|
431 |
+
"necklace_category": jewellery_item['category'],
|
432 |
+
"store": jewellery_item['store'],
|
433 |
+
"progress": f"{current_combination}/{total_combinations}"
|
434 |
+
}
|
435 |
+
yield json.dumps(output_result) + "\n"
|
436 |
+
await asyncio.sleep(0.1)
|
437 |
+
|
438 |
+
del result
|
439 |
+
gc.collect()
|
440 |
+
|
441 |
+
except Exception as e:
|
442 |
+
logger.error(
|
443 |
+
f">>> PROCESSING ERROR FOR {clothing_type} with {jewellery_item['id']}: {str(e)} <<<")
|
444 |
+
error_result = {
|
445 |
+
"error": f"Error processing combination",
|
446 |
+
"details": str(e),
|
447 |
+
"code": 500,
|
448 |
+
"clothing_type": clothing_type,
|
449 |
+
"necklace_id": jewellery_item['id'],
|
450 |
+
"necklace_category": jewellery_item['category'],
|
451 |
+
"store": jewellery_item['store'],
|
452 |
+
"progress": f"{current_combination}/{total_combinations}"
|
453 |
+
}
|
454 |
+
yield json.dumps(error_result) + "\n"
|
455 |
+
await asyncio.sleep(0.1)
|
456 |
+
|
457 |
+
except Exception as e:
|
458 |
+
logger.error(f">>> GENERAL PROCESSING ERROR: {str(e)} <<<")
|
459 |
+
yield json.dumps({
|
460 |
+
"error": "General processing error",
|
461 |
+
"details": str(e),
|
462 |
+
"code": 500
|
463 |
+
}) + "\n"
|
464 |
+
await asyncio.sleep(0.1)
|
465 |
+
|
466 |
+
return StreamingResponse(
|
467 |
+
generate(),
|
468 |
+
media_type="application/x-ndjson",
|
469 |
+
headers={
|
470 |
+
"Cache-Control": "no-cache",
|
471 |
+
"Connection": "keep-alive",
|
472 |
+
"X-Accel-Buffering": "no",
|
473 |
+
"Transfer-Encoding": "chunked"
|
474 |
+
}
|
475 |
+
)
|
476 |
+
|
477 |
+
|
478 |
+
@batch_router.post("/rt_makeup")
|
479 |
+
async def rt_makeup(
|
480 |
+
image: UploadFile = File(...),
|
481 |
+
lipstick_colors: str = Form(...),
|
482 |
+
eyeliner_colors: str = Form(...),
|
483 |
+
eyeshadow_colors: str = Form(...)
|
484 |
+
):
|
485 |
+
logger.info("-" * 50)
|
486 |
+
logger.info(">>> REAL-TIME MAKEUP STARTED <<<")
|
487 |
+
logger.info(f"Parameters: lipstick_colors={lipstick_colors}, "
|
488 |
+
f"eyeliner_colors={eyeliner_colors}, eyeshadow_colors={eyeshadow_colors}")
|
489 |
+
|
490 |
+
try:
|
491 |
+
# Parse color lists
|
492 |
+
lipstick_colors_list = [color.strip() for color in lipstick_colors.split(",")]
|
493 |
+
eyeliner_colors_list = [color.strip() for color in eyeliner_colors.split(",")]
|
494 |
+
eyeshadow_colors_list = [color.strip() for color in eyeshadow_colors.split(",")]
|
495 |
+
|
496 |
+
# Read uploaded image
|
497 |
+
image_bytes = await image.read()
|
498 |
+
source_image = Image.open(BytesIO(image_bytes)).convert("RGB").resize((1080, 1080))
|
499 |
+
cv_image = cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR)
|
500 |
+
|
501 |
+
logger.info(">>> INPUT PARSING SUCCESSFUL <<<")
|
502 |
+
except Exception as e:
|
503 |
+
logger.error(f">>> INPUT PARSING ERROR: {str(e)} <<<")
|
504 |
+
return JSONResponse(
|
505 |
+
content={"error": "Error in parsing inputs", "details": str(e), "code": 500},
|
506 |
+
status_code=500
|
507 |
+
)
|
508 |
+
|
509 |
+
async def generate():
|
510 |
+
setup_start_time = time.time()
|
511 |
+
|
512 |
+
try:
|
513 |
+
combinations = list(zip(
|
514 |
+
lipstick_colors_list,
|
515 |
+
eyeliner_colors_list,
|
516 |
+
eyeshadow_colors_list
|
517 |
+
))
|
518 |
+
total_combinations = len(combinations)
|
519 |
+
current_combination = 0
|
520 |
+
|
521 |
+
for lipstick_color, eyeliner_color, eyeshadow_color in combinations:
|
522 |
+
lipstick_color_set = tuple(colors["Lipstick"][lipstick_color])
|
523 |
+
eyeliner_color_set = tuple(colors["Eyeliner"][eyeliner_color])
|
524 |
+
eyeshadow_color_set = tuple(colors["Eyeshadow"][eyeshadow_color])
|
525 |
+
|
526 |
+
current_combination += 1
|
527 |
+
iteration_start_time = time.time()
|
528 |
+
|
529 |
+
try:
|
530 |
+
makeup_start_time = time.time()
|
531 |
+
result = await pipeline.makeup_tryon(
|
532 |
+
image=cv_image.copy(),
|
533 |
+
lipstick_color=lipstick_color_set,
|
534 |
+
eyeliner_color=eyeliner_color_set,
|
535 |
+
eyeshadow_color=eyeshadow_color_set
|
536 |
+
)
|
537 |
+
makeup_time = round(time.time() - makeup_start_time, 2)
|
538 |
+
logger.info(f">>> MAKEUP APPLICATION COMPLETED in {makeup_time}s <<<")
|
539 |
+
|
540 |
+
# Upload result
|
541 |
+
upload_start_time = time.time()
|
542 |
+
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
543 |
+
|
544 |
+
result_url = await supabase_upload_and_return_url(
|
545 |
+
prefix="MTO",
|
546 |
+
image=result_image,
|
547 |
+
necklace_id="JEWELMIRRORMAKEUP"
|
548 |
+
)
|
549 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
550 |
+
logger.info(f">>> RESULT UPLOADED in {upload_time}s <<<")
|
551 |
+
|
552 |
+
output_result = {
|
553 |
+
"code": 200,
|
554 |
+
"output": result_url,
|
555 |
+
"timing": {
|
556 |
+
"setup": round(time.time() - setup_start_time, 2),
|
557 |
+
"makeup_application": makeup_time,
|
558 |
+
"upload": upload_time,
|
559 |
+
"total_iteration": round(time.time() - iteration_start_time, 2)
|
560 |
+
},
|
561 |
+
"progress": f"{current_combination}/{total_combinations}",
|
562 |
+
"lipstick_color": lipstick_color,
|
563 |
+
"eyeliner_color": eyeliner_color,
|
564 |
+
"eyeshadow_color": eyeshadow_color
|
565 |
+
}
|
566 |
+
yield json.dumps(output_result) + "\n"
|
567 |
+
await asyncio.sleep(0.1)
|
568 |
+
|
569 |
+
del result
|
570 |
+
gc.collect()
|
571 |
+
|
572 |
+
except Exception as e:
|
573 |
+
logger.error(f">>> PROCESSING ERROR: {str(e)} <<<")
|
574 |
+
error_result = {
|
575 |
+
"error": "Error processing makeup application",
|
576 |
+
"details": str(e),
|
577 |
+
"code": 500,
|
578 |
+
"progress": f"{current_combination}/{total_combinations}",
|
579 |
+
"lipstick_color": lipstick_color,
|
580 |
+
"eyeliner_color": eyeliner_color,
|
581 |
+
"eyeshadow_color": eyeshadow_color
|
582 |
+
}
|
583 |
+
yield json.dumps(error_result) + "\n"
|
584 |
+
await asyncio.sleep(0.1)
|
585 |
+
|
586 |
+
except Exception as e:
|
587 |
+
logger.error(f">>> GENERAL PROCESSING ERROR: {str(e)} <<<")
|
588 |
+
yield json.dumps({
|
589 |
+
"error": "General processing error",
|
590 |
+
"details": str(e),
|
591 |
+
"code": 500
|
592 |
+
}) + "\n"
|
593 |
+
await asyncio.sleep(0.1)
|
594 |
+
|
595 |
+
return StreamingResponse(
|
596 |
+
generate(),
|
597 |
+
media_type="application/x-ndjson",
|
598 |
+
headers={
|
599 |
+
"Cache-Control": "no-cache",
|
600 |
+
"Connection": "keep-alive",
|
601 |
+
"X-Accel-Buffering": "no",
|
602 |
+
"Transfer-Encoding": "chunked"
|
603 |
+
}
|
604 |
+
)
|
605 |
+
|
606 |
+
|
607 |
+
@batch_router.post("/batch_rt_makeup")
|
608 |
+
async def batch_rt_makeup(
|
609 |
+
image_urls: str = Form(...),
|
610 |
+
lipstick_colors: str = Form(...),
|
611 |
+
eyeliner_colors: str = Form(...),
|
612 |
+
eyeshadow_colors: str = Form(...)
|
613 |
+
):
|
614 |
+
logger.info("-" * 50)
|
615 |
+
logger.info(">>> REAL-TIME MAKEUP STARTED <<<")
|
616 |
+
logger.info(f"Parameters: image_urls={image_urls}, lipstick_colors={lipstick_colors}, "
|
617 |
+
f"eyeliner_colors={eyeliner_colors}, eyeshadow_colors={eyeshadow_colors}")
|
618 |
+
|
619 |
+
try:
|
620 |
+
# Parse all input lists
|
621 |
+
image_urls_list = [url.strip() for url in image_urls.split(",")]
|
622 |
+
lipstick_colors_list = [color.strip() for color in lipstick_colors.split(",")]
|
623 |
+
eyeliner_colors_list = [color.strip() for color in eyeliner_colors.split(",")]
|
624 |
+
eyeshadow_colors_list = [color.strip() for color in eyeshadow_colors.split(",")]
|
625 |
+
|
626 |
+
# Validate list lengths
|
627 |
+
if not (len(image_urls_list) == len(lipstick_colors_list) ==
|
628 |
+
len(eyeliner_colors_list) == len(eyeshadow_colors_list)):
|
629 |
+
return JSONResponse(
|
630 |
+
content={
|
631 |
+
"error": "All input lists must have the same number of elements.",
|
632 |
+
"code": 400
|
633 |
+
},
|
634 |
+
status_code=400
|
635 |
+
)
|
636 |
+
|
637 |
+
logger.info(">>> INPUT PARSING SUCCESSFUL <<<")
|
638 |
+
except Exception as e:
|
639 |
+
logger.error(f">>> INPUT PARSING ERROR: {str(e)} <<<")
|
640 |
+
return JSONResponse(
|
641 |
+
content={"error": "Error in parsing inputs", "details": str(e), "code": 500},
|
642 |
+
status_code=500
|
643 |
+
)
|
644 |
+
|
645 |
+
async def generate():
|
646 |
+
setup_start_time = time.time()
|
647 |
+
|
648 |
+
try:
|
649 |
+
total_combinations = len(image_urls_list)
|
650 |
+
current_combination = 0
|
651 |
+
|
652 |
+
for image_url, lipstick_color, eyeliner_color, eyeshadow_color in zip(
|
653 |
+
image_urls_list, lipstick_colors_list, eyeliner_colors_list, eyeshadow_colors_list):
|
654 |
+
lipstick_color_set = tuple(colors["Lipstick"][lipstick_color])
|
655 |
+
eyeliner_color_set = tuple(colors["Eyeliner"][eyeliner_color])
|
656 |
+
eyeshadow_color_set = tuple(colors["Eyeshadow"][eyeshadow_color])
|
657 |
+
|
658 |
+
current_combination += 1
|
659 |
+
iteration_start_time = time.time()
|
660 |
+
lipstick_color
|
661 |
+
|
662 |
+
try:
|
663 |
+
fetch_start_time = time.time()
|
664 |
+
async with aiohttp.ClientSession() as session:
|
665 |
+
async with session.get(image_url) as response:
|
666 |
+
if response.status != 200:
|
667 |
+
raise ValueError("Failed to fetch image")
|
668 |
+
image_bytes = await response.read()
|
669 |
+
|
670 |
+
source_image = Image.open(BytesIO(image_bytes)).convert("RGB").resize((1080, 1080))
|
671 |
+
fetch_time = round(time.time() - fetch_start_time, 2)
|
672 |
+
logger.info(f">>> IMAGE FETCHED SUCCESSFULLY in {fetch_time}s <<<")
|
673 |
+
image = cv2.cvtColor(np.array(source_image), cv2.COLOR_RGB2BGR)
|
674 |
+
|
675 |
+
makeup_start_time = time.time()
|
676 |
+
result = await pipeline.makeup_tryon(
|
677 |
+
image=image,
|
678 |
+
lipstick_color=lipstick_color_set,
|
679 |
+
eyeliner_color=eyeliner_color_set,
|
680 |
+
eyeshadow_color=eyeshadow_color_set
|
681 |
+
)
|
682 |
+
makeup_time = round(time.time() - makeup_start_time, 2)
|
683 |
+
logger.info(f">>> MAKEUP APPLICATION COMPLETED in {makeup_time}s <<<")
|
684 |
+
|
685 |
+
# Upload result
|
686 |
+
upload_start_time = time.time()
|
687 |
+
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
688 |
+
|
689 |
+
result_url = await supabase_upload_and_return_url(
|
690 |
+
prefix="MTO",
|
691 |
+
image=result_image,
|
692 |
+
necklace_id="JEWELMIRRORMAKEUP"
|
693 |
+
)
|
694 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
695 |
+
logger.info(f">>> RESULT UPLOADED in {upload_time}s <<<")
|
696 |
+
|
697 |
+
output_result = {
|
698 |
+
"code": 200,
|
699 |
+
"output": result_url,
|
700 |
+
"timing": {
|
701 |
+
"setup": round(time.time() - setup_start_time, 2),
|
702 |
+
"fetch": fetch_time,
|
703 |
+
"makeup_application": makeup_time,
|
704 |
+
"upload": upload_time,
|
705 |
+
"total_iteration": round(time.time() - iteration_start_time, 2)
|
706 |
+
},
|
707 |
+
"progress": f"{current_combination}/{total_combinations}",
|
708 |
+
"lipstick_color": lipstick_color,
|
709 |
+
"eyeliner_color": eyeliner_color,
|
710 |
+
"eyeshadow_color": eyeshadow_color
|
711 |
+
}
|
712 |
+
yield json.dumps(output_result) + "\n"
|
713 |
+
await asyncio.sleep(0.1)
|
714 |
+
|
715 |
+
del result
|
716 |
+
gc.collect()
|
717 |
+
|
718 |
+
except Exception as e:
|
719 |
+
logger.error(
|
720 |
+
f">>> PROCESSING ERROR FOR IMAGE {image_url}: {str(e)} <<<")
|
721 |
+
error_result = {
|
722 |
+
"error": "Error processing makeup application",
|
723 |
+
"details": str(e),
|
724 |
+
"code": 500,
|
725 |
+
"progress": f"{current_combination}/{total_combinations}",
|
726 |
+
"lipstick_color": lipstick_color,
|
727 |
+
"eyeliner_color": eyeliner_color,
|
728 |
+
"eyeshadow_color": eyeshadow_color
|
729 |
+
}
|
730 |
+
yield json.dumps(error_result) + "\n"
|
731 |
+
await asyncio.sleep(0.1)
|
732 |
+
|
733 |
+
|
734 |
+
except Exception as e:
|
735 |
+
logger.error(f">>> GENERAL PROCESSING ERROR: {str(e)} <<<")
|
736 |
+
yield json.dumps({
|
737 |
+
"error": "General processing error",
|
738 |
+
"details": str(e),
|
739 |
+
"code": 500
|
740 |
+
}) + "\n"
|
741 |
+
await asyncio.sleep(0.1)
|
742 |
+
|
743 |
+
return StreamingResponse(
|
744 |
+
generate(),
|
745 |
+
media_type="application/x-ndjson",
|
746 |
+
headers={
|
747 |
+
"Cache-Control": "no-cache",
|
748 |
+
"Connection": "keep-alive",
|
749 |
+
"X-Accel-Buffering": "no",
|
750 |
+
"Transfer-Encoding": "chunked"
|
751 |
+
}
|
752 |
+
)
|
src/api/image_prep_api.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import base64
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
from io import BytesIO
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import replicate
|
14 |
+
import requests
|
15 |
+
from PIL import Image
|
16 |
+
from fastapi import APIRouter, UploadFile, File, HTTPException
|
17 |
+
from fastapi.responses import JSONResponse
|
18 |
+
|
19 |
+
from src.components.auto_crop import crop_transparent_image
|
20 |
+
from src.components.color_extraction import ColorExtractionRMBG
|
21 |
+
from src.components.title_des_gen import NecklaceProductListing
|
22 |
+
from src.utils.logger import logger
|
23 |
+
|
24 |
+
preprocessing_router = APIRouter()
|
25 |
+
|
26 |
+
rmbg: str = os.getenv("RMBG")
|
27 |
+
|
28 |
+
enhancer: str = os.getenv("ENHANCER")
|
29 |
+
prod_listing_api_key: str = os.getenv("PROD_LISTING_API_KEY")
|
30 |
+
|
31 |
+
color_extraction_rmbg = ColorExtractionRMBG()
|
32 |
+
product_listing_obj = NecklaceProductListing(prod_listing_api_key)
|
33 |
+
|
34 |
+
|
35 |
+
def replicate_bg(input):
|
36 |
+
output = replicate.run(
|
37 |
+
rmbg,
|
38 |
+
input=input
|
39 |
+
)
|
40 |
+
return output
|
41 |
+
|
42 |
+
|
43 |
+
def replicate_enhancer(input):
|
44 |
+
output = replicate.run(
|
45 |
+
enhancer,
|
46 |
+
input=input
|
47 |
+
)
|
48 |
+
return output
|
49 |
+
|
50 |
+
|
51 |
+
@preprocessing_router.post("/rem_bg")
|
52 |
+
async def remove_background(image: UploadFile = File(...)):
|
53 |
+
logger.info("-" * 50)
|
54 |
+
logger.info(">>> REMOVE BACKGROUND STARTED <<<")
|
55 |
+
start_time = time.time()
|
56 |
+
|
57 |
+
try:
|
58 |
+
image_bytes = await image.read()
|
59 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
60 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
61 |
+
except Exception as e:
|
62 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
63 |
+
return JSONResponse(status_code=500, content={"error": f"Error reading image: {str(e)}", "code": 500})
|
64 |
+
|
65 |
+
try:
|
66 |
+
act_img_base_64 = BytesIO()
|
67 |
+
image.save(act_img_base_64, format="WEBP")
|
68 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
69 |
+
image_data_uri = f"data:image/WEBP;base64,{image_bytes_}"
|
70 |
+
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
71 |
+
except Exception as e:
|
72 |
+
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
73 |
+
return JSONResponse(status_code=500,
|
74 |
+
content={"error": f"Error converting image to base64: {str(e)}", "code": 500})
|
75 |
+
|
76 |
+
try:
|
77 |
+
output = replicate_bg({"image": image_data_uri})
|
78 |
+
logger.info(">>> BACKGROUND REMOVAL COMPLETED <<<")
|
79 |
+
except Exception as e:
|
80 |
+
logger.error(f">>> BACKGROUND REMOVAL ERROR: {str(e)} <<<")
|
81 |
+
return JSONResponse(status_code=500,
|
82 |
+
content={"error": f"Error running background removal: {str(e)}", "code": 500})
|
83 |
+
|
84 |
+
try:
|
85 |
+
response = requests.get(output)
|
86 |
+
base_64 = base64.b64encode(response.content).decode('utf-8')
|
87 |
+
base64_prefix = "data:image/WEBP;base64,"
|
88 |
+
|
89 |
+
total_inference_time = round((time.time() - start_time), 2)
|
90 |
+
|
91 |
+
response = {
|
92 |
+
"output": f"{base64_prefix}{base_64}",
|
93 |
+
"inference_time": total_inference_time,
|
94 |
+
"code": 200
|
95 |
+
}
|
96 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
97 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
98 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
99 |
+
logger.info("-" * 50)
|
100 |
+
return JSONResponse(content=response, status_code=200)
|
101 |
+
|
102 |
+
except Exception as e:
|
103 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
104 |
+
return JSONResponse(status_code=500,
|
105 |
+
content={"error": f"Error processing response: {str(e)}", "code": 500})
|
106 |
+
|
107 |
+
|
108 |
+
@preprocessing_router.post("/upscale_image")
|
109 |
+
async def upscale_image(image: UploadFile = File(...), scale: int = 1):
|
110 |
+
logger.info("-" * 50)
|
111 |
+
logger.info(">>> IMAGE UPSCALING STARTED <<<")
|
112 |
+
start_time = time.time()
|
113 |
+
|
114 |
+
try:
|
115 |
+
image_bytes = await image.read()
|
116 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGBA")
|
117 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
118 |
+
except Exception as e:
|
119 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
120 |
+
return JSONResponse(status_code=500, content={"error": f"Error reading image: {str(e)}", "code": 500})
|
121 |
+
|
122 |
+
try:
|
123 |
+
act_img_base_64 = BytesIO()
|
124 |
+
image.save(act_img_base_64, format="PNG")
|
125 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
126 |
+
image_data_uri = f"data:image/png;base64,{image_bytes_}"
|
127 |
+
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
128 |
+
except Exception as e:
|
129 |
+
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
130 |
+
return JSONResponse(status_code=500,
|
131 |
+
content={"error": f"Error converting image to base64: {str(e)}", "code": 500})
|
132 |
+
|
133 |
+
try:
|
134 |
+
input = {
|
135 |
+
"image": image_data_uri,
|
136 |
+
"scale": scale,
|
137 |
+
"face_enhance": False
|
138 |
+
}
|
139 |
+
output = replicate_enhancer(input)
|
140 |
+
logger.info(">>> IMAGE ENHANCEMENT COMPLETED <<<")
|
141 |
+
except Exception as e:
|
142 |
+
logger.error(f">>> IMAGE ENHANCEMENT ERROR: {str(e)} <<<")
|
143 |
+
return JSONResponse(status_code=500,
|
144 |
+
content={"error": f"Error running image enhancement: {str(e)}", "code": 500})
|
145 |
+
|
146 |
+
try:
|
147 |
+
response = requests.get(output)
|
148 |
+
base_64 = base64.b64encode(response.content).decode('utf-8')
|
149 |
+
base64_prefix = image_data_uri.split(",")[0] + ","
|
150 |
+
|
151 |
+
total_inference_time = round((time.time() - start_time), 2)
|
152 |
+
|
153 |
+
response = {
|
154 |
+
"output": f"{base64_prefix}{base_64}",
|
155 |
+
"inference_time": total_inference_time,
|
156 |
+
"code": 200
|
157 |
+
}
|
158 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
159 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
160 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
161 |
+
logger.info("-" * 50)
|
162 |
+
return JSONResponse(content=response, status_code=200)
|
163 |
+
|
164 |
+
except Exception as e:
|
165 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
166 |
+
return JSONResponse(status_code=500,
|
167 |
+
content={"error": f"Error processing response: {str(e)}", "code": 500})
|
168 |
+
|
169 |
+
|
170 |
+
@preprocessing_router.post("/crop_transparent")
|
171 |
+
async def crop_transparent(image: UploadFile):
|
172 |
+
logger.info("-" * 50)
|
173 |
+
logger.info(">>> CROP TRANSPARENT STARTED <<<")
|
174 |
+
start_time = time.time()
|
175 |
+
|
176 |
+
try:
|
177 |
+
if not image.content_type == "image/png":
|
178 |
+
logger.error(">>> INVALID FILE TYPE: NOT PNG <<<")
|
179 |
+
return JSONResponse(status_code=400,
|
180 |
+
content={"error": "Only PNG files are supported", "code": 400})
|
181 |
+
except Exception as e:
|
182 |
+
logger.error(f">>> FILE TYPE CHECK ERROR: {str(e)} <<<")
|
183 |
+
return JSONResponse(status_code=500,
|
184 |
+
content={"error": f"Error checking file type: {str(e)}", "code": 500})
|
185 |
+
|
186 |
+
try:
|
187 |
+
contents = await image.read()
|
188 |
+
cropped_image_bytes, metadata = crop_transparent_image(contents)
|
189 |
+
logger.info(">>> IMAGE CROPPING COMPLETED <<<")
|
190 |
+
except Exception as e:
|
191 |
+
logger.error(f">>> IMAGE CROPPING ERROR: {str(e)} <<<")
|
192 |
+
return JSONResponse(status_code=500,
|
193 |
+
content={"error": f"Error cropping image: {str(e)}", "code": 500})
|
194 |
+
|
195 |
+
try:
|
196 |
+
base64_image = base64.b64encode(cropped_image_bytes).decode('utf-8')
|
197 |
+
base64_prefix = "data:image/png;base64,"
|
198 |
+
|
199 |
+
total_inference_time = round((time.time() - start_time), 2)
|
200 |
+
|
201 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
202 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
203 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
204 |
+
logger.info("-" * 50)
|
205 |
+
|
206 |
+
return JSONResponse(content={
|
207 |
+
"status": "success",
|
208 |
+
"code": 200,
|
209 |
+
"data": {
|
210 |
+
"image": f"{base64_prefix}{base64_image}",
|
211 |
+
"metadata": metadata,
|
212 |
+
"inference_time": total_inference_time
|
213 |
+
}
|
214 |
+
}, status_code=200)
|
215 |
+
except Exception as e:
|
216 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
217 |
+
return JSONResponse(status_code=500,
|
218 |
+
content={"error": f"Error processing response: {str(e)}", "code": 500})
|
219 |
+
|
220 |
+
|
221 |
+
@preprocessing_router.post("/background_replace")
|
222 |
+
async def bg_replace(image: UploadFile = File(...), bg_image: UploadFile = File(...)):
|
223 |
+
logger.info("-" * 50)
|
224 |
+
logger.info(">>> BACKGROUND REPLACE STARTED <<<")
|
225 |
+
start_time = time.time()
|
226 |
+
|
227 |
+
try:
|
228 |
+
image_bytes = await image.read()
|
229 |
+
bg_bytes = await bg_image.read()
|
230 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGBA")
|
231 |
+
bg_image = Image.open(BytesIO(bg_bytes)).convert("RGB")
|
232 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
233 |
+
except Exception as e:
|
234 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
235 |
+
return JSONResponse(status_code=500,
|
236 |
+
content={"error": f"Error reading images: {str(e)}", "code": 500})
|
237 |
+
|
238 |
+
try:
|
239 |
+
width, height = bg_image.size
|
240 |
+
background = Image.fromarray(np.array(bg_image)).resize((width, height))
|
241 |
+
orig_img = Image.fromarray(np.array(image)).resize((width, height))
|
242 |
+
background.paste(orig_img, (0, 0), mask=orig_img)
|
243 |
+
logger.info(">>> IMAGE PROCESSING COMPLETED <<<")
|
244 |
+
except Exception as e:
|
245 |
+
logger.error(f">>> IMAGE PROCESSING ERROR: {str(e)} <<<")
|
246 |
+
return JSONResponse(status_code=500,
|
247 |
+
content={"error": f"Error processing images: {str(e)}", "code": 500})
|
248 |
+
|
249 |
+
try:
|
250 |
+
act_img_base_64 = BytesIO()
|
251 |
+
background.save(act_img_base_64, format="WEBP")
|
252 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
253 |
+
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
254 |
+
|
255 |
+
total_inference_time = round((time.time() - start_time), 2)
|
256 |
+
|
257 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
258 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
259 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
260 |
+
logger.info("-" * 50)
|
261 |
+
|
262 |
+
return JSONResponse(content={
|
263 |
+
"output": image_data_uri,
|
264 |
+
"code": 200,
|
265 |
+
"inference_time": total_inference_time
|
266 |
+
}, status_code=200)
|
267 |
+
except Exception as e:
|
268 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
269 |
+
return JSONResponse(status_code=500,
|
270 |
+
content={"error": f"Error creating response: {str(e)}", "code": 500})
|
271 |
+
|
272 |
+
|
273 |
+
@preprocessing_router.post("/rem_bg_color_extraction")
|
274 |
+
async def remove_background_color_extraction(image: UploadFile = File(...),
|
275 |
+
hex_color: str = "#FFFFFF",
|
276 |
+
threshold: int = 30):
|
277 |
+
logger.info("-" * 50)
|
278 |
+
logger.info(">>> COLOR EXTRACTION STARTED <<<")
|
279 |
+
start_time = time.time()
|
280 |
+
|
281 |
+
try:
|
282 |
+
image_bytes = await image.read()
|
283 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGBA")
|
284 |
+
image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
285 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
286 |
+
except Exception as e:
|
287 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
288 |
+
return JSONResponse(status_code=500,
|
289 |
+
content={"error": f"Error reading image: {str(e)}", "code": 500})
|
290 |
+
|
291 |
+
try:
|
292 |
+
result = color_extraction_rmbg.extract_color(image, hex_color, threshold)
|
293 |
+
result = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_RGB2BGRA)).convert("RGBA")
|
294 |
+
logger.info(">>> COLOR EXTRACTION COMPLETED <<<")
|
295 |
+
except Exception as e:
|
296 |
+
logger.error(f">>> COLOR EXTRACTION ERROR: {str(e)} <<<")
|
297 |
+
return JSONResponse(status_code=500,
|
298 |
+
content={"error": f"Error extracting colors: {str(e)}", "code": 500})
|
299 |
+
|
300 |
+
try:
|
301 |
+
act_img_base_64 = BytesIO()
|
302 |
+
result.save(act_img_base_64, format="PNG")
|
303 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
304 |
+
image_data_uri = f"data:image/png;base64,{image_bytes_}"
|
305 |
+
|
306 |
+
total_inference_time = round((time.time() - start_time), 2)
|
307 |
+
|
308 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
309 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
310 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
311 |
+
logger.info("-" * 50)
|
312 |
+
|
313 |
+
return JSONResponse(content={
|
314 |
+
"output": image_data_uri,
|
315 |
+
"code": 200,
|
316 |
+
"inference_time": total_inference_time
|
317 |
+
}, status_code=200)
|
318 |
+
except Exception as e:
|
319 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
320 |
+
return JSONResponse(status_code=500,
|
321 |
+
content={"error": f"Error creating response: {str(e)}", "code": 500})
|
322 |
+
|
323 |
+
|
324 |
+
@preprocessing_router.post("/title_description_generator")
|
325 |
+
async def product_title_description_generator(image: UploadFile = File(...)):
|
326 |
+
logger.info("-" * 50)
|
327 |
+
logger.info(">>> TITLE DESCRIPTION GENERATION STARTED <<<")
|
328 |
+
start_time = time.time()
|
329 |
+
|
330 |
+
try:
|
331 |
+
image_bytes = await image.read()
|
332 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
333 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
334 |
+
except Exception as e:
|
335 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
336 |
+
return JSONResponse(status_code=500,
|
337 |
+
content={"error": f"Error reading image: {str(e)}", "code": 500})
|
338 |
+
|
339 |
+
try:
|
340 |
+
result = product_listing_obj.gen_title_desc(image=image)
|
341 |
+
title = result.split("Title:")[1].split("Description:")[0]
|
342 |
+
description = result.split("Description:")[1]
|
343 |
+
logger.info(">>> TITLE AND DESCRIPTION GENERATION COMPLETED <<<")
|
344 |
+
except Exception as e:
|
345 |
+
logger.error(">>> TITLE DESCRIPTION GENERATION ERROR <<<")
|
346 |
+
return JSONResponse(status_code=500,
|
347 |
+
content={"error": "Please make sure the image is clear and necklaces are visible",
|
348 |
+
"code": 500})
|
349 |
+
|
350 |
+
try:
|
351 |
+
total_inference_time = round((time.time() - start_time), 2)
|
352 |
+
|
353 |
+
logger.info(">>> RESPONSE PREPARATION COMPLETED <<<")
|
354 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
355 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
356 |
+
logger.info("-" * 50)
|
357 |
+
|
358 |
+
return JSONResponse(content={
|
359 |
+
"code": 200,
|
360 |
+
"title": title,
|
361 |
+
"description": description,
|
362 |
+
"inference_time": total_inference_time
|
363 |
+
}, status_code=200)
|
364 |
+
except Exception as e:
|
365 |
+
logger.error(f">>> RESPONSE PROCESSING ERROR: {str(e)} <<<")
|
366 |
+
return JSONResponse(status_code=500,
|
367 |
+
content={"error": f"Error creating response: {str(e)}", "code": 500})
|
src/api/image_regeneration_api.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-29
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import base64
|
7 |
+
import time
|
8 |
+
from io import BytesIO
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import replicate
|
12 |
+
import requests
|
13 |
+
from PIL import Image
|
14 |
+
from fastapi import APIRouter, UploadFile, File, Form
|
15 |
+
from fastapi.responses import JSONResponse
|
16 |
+
from src.utils.logger import logger
|
17 |
+
|
18 |
+
image_regeneration_router = APIRouter()
|
19 |
+
|
20 |
+
|
21 |
+
def image_regeneration_replicate(input):
|
22 |
+
output = replicate.run(
|
23 |
+
"konieshadow/fooocus-api:fda927242b1db6affa1ece4f54c37f19b964666bf23b0d06ae2439067cd344a4",
|
24 |
+
input=input
|
25 |
+
)
|
26 |
+
return output
|
27 |
+
|
28 |
+
|
29 |
+
@image_regeneration_router.post("/image_redesign")
|
30 |
+
async def image_re_gen(
|
31 |
+
prompt: str = Form(...),
|
32 |
+
negative_prompt: str = Form(""),
|
33 |
+
image: UploadFile = File(...),
|
34 |
+
mask_image: Optional[UploadFile] = File(default=None),
|
35 |
+
reference_image_c1: Optional[UploadFile] = File(default=None),
|
36 |
+
reference_image_c1_type: Optional[str] = Form(default=""),
|
37 |
+
reference_image_c1_weight: Optional[float] = Form(default=0.0),
|
38 |
+
reference_image_c1_stop: Optional[float] = Form(default=0.0),
|
39 |
+
reference_image_c2: Optional[UploadFile] = File(default=None),
|
40 |
+
reference_image_c2_type: Optional[str] = Form(default=""),
|
41 |
+
reference_image_c2_weight: Optional[float] = Form(default=0.0),
|
42 |
+
reference_image_c2_stop: Optional[float] = Form(default=0.0),
|
43 |
+
reference_image_c3: Optional[UploadFile] = File(default=None),
|
44 |
+
reference_image_c3_type: Optional[str] = Form(default=""),
|
45 |
+
reference_image_c3_weight: Optional[float] = Form(default=0.0),
|
46 |
+
reference_image_c3_stop: Optional[float] = Form(default=0.0),
|
47 |
+
reference_image_c4: Optional[UploadFile] = File(default=None),
|
48 |
+
reference_image_c4_type: Optional[str] = Form(default=""),
|
49 |
+
reference_image_c4_weight: Optional[float] = Form(default=0.0),
|
50 |
+
reference_image_c4_stop: Optional[float] = Form(default=0.0),
|
51 |
+
):
|
52 |
+
logger.info("-" * 50)
|
53 |
+
logger.info(">>> IMAGE REDESIGN STARTED <<<")
|
54 |
+
start_time = time.time()
|
55 |
+
|
56 |
+
try:
|
57 |
+
async def process_reference_image(reference_image: Optional[UploadFile]) -> Optional[str]:
|
58 |
+
if reference_image is not None:
|
59 |
+
reference_image_bytes = await reference_image.read()
|
60 |
+
reference_image = Image.open(BytesIO(reference_image_bytes)).convert("RGB")
|
61 |
+
ref_img_base64 = BytesIO()
|
62 |
+
reference_image.save(ref_img_base64, format="WEBP")
|
63 |
+
reference_image_b64 = base64.b64encode(ref_img_base64.getvalue()).decode("utf-8")
|
64 |
+
return f"data:image/WEBP;base64,{reference_image_b64}"
|
65 |
+
return None
|
66 |
+
logger.info(">>> REFERENCE IMAGE PROCESSING FUNCTION INITIALIZED <<<")
|
67 |
+
except Exception as e:
|
68 |
+
logger.error(f">>> REFERENCE IMAGE PROCESSING ERROR: {str(e)} <<<")
|
69 |
+
return JSONResponse(status_code=500,
|
70 |
+
content={"error": f"Error processing reference image: {str(e)}", "code": 500})
|
71 |
+
|
72 |
+
try:
|
73 |
+
image_bytes = await image.read()
|
74 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
75 |
+
img_base64 = BytesIO()
|
76 |
+
image.save(img_base64, format="WEBP")
|
77 |
+
image_data_uri = f"data:image/WEBP;base64,{base64.b64encode(img_base64.getvalue()).decode('utf-8')}"
|
78 |
+
logger.info(">>> MAIN IMAGE PROCESSED SUCCESSFULLY <<<")
|
79 |
+
except Exception as e:
|
80 |
+
logger.error(f">>> MAIN IMAGE PROCESSING ERROR: {str(e)} <<<")
|
81 |
+
return JSONResponse(status_code=500,
|
82 |
+
content={"error": f"Error processing main image: {str(e)}", "code": 500})
|
83 |
+
|
84 |
+
try:
|
85 |
+
reference_images = {
|
86 |
+
'c1': await process_reference_image(reference_image_c1),
|
87 |
+
'c2': await process_reference_image(reference_image_c2),
|
88 |
+
'c3': await process_reference_image(reference_image_c3),
|
89 |
+
'c4': await process_reference_image(reference_image_c4)
|
90 |
+
}
|
91 |
+
logger.info(">>> REFERENCE IMAGES PROCESSED SUCCESSFULLY <<<")
|
92 |
+
except Exception as e:
|
93 |
+
logger.error(f">>> REFERENCE IMAGES PROCESSING ERROR: {str(e)} <<<")
|
94 |
+
return JSONResponse(status_code=500,
|
95 |
+
content={"error": f"Error processing reference images: {str(e)}", "code": 500})
|
96 |
+
|
97 |
+
try:
|
98 |
+
input_data = {
|
99 |
+
"prompt": prompt,
|
100 |
+
"inpaint_input_image": image_data_uri,
|
101 |
+
"sharpness": 2,
|
102 |
+
"guidance_scale": 4,
|
103 |
+
"refiner_switch": 0.5,
|
104 |
+
"performance_selection": "Quality",
|
105 |
+
"aspect_ratios_selection": "1024*1024"
|
106 |
+
}
|
107 |
+
|
108 |
+
if negative_prompt:
|
109 |
+
input_data["negative_prompt"] = negative_prompt
|
110 |
+
|
111 |
+
if mask_image is not None:
|
112 |
+
mask_image_bytes = await mask_image.read()
|
113 |
+
mask_image = Image.open(BytesIO(mask_image_bytes)).convert("RGB")
|
114 |
+
mask_base64 = BytesIO()
|
115 |
+
mask_image.save(mask_base64, format="WEBP")
|
116 |
+
mask_image_data_uri = f"data:image/WEBP;base64,{base64.b64encode(mask_base64.getvalue()).decode('utf-8')}"
|
117 |
+
input_data["inpaint_input_mask"] = mask_image_data_uri
|
118 |
+
logger.info(">>> INPUT DATA PREPARED SUCCESSFULLY <<<")
|
119 |
+
except Exception as e:
|
120 |
+
logger.error(f">>> INPUT DATA PREPARATION ERROR: {str(e)} <<<")
|
121 |
+
return JSONResponse(status_code=500,
|
122 |
+
content={"error": f"Error preparing input data: {str(e)}", "code": 500})
|
123 |
+
|
124 |
+
try:
|
125 |
+
for i in range(1, 5):
|
126 |
+
c = f'c{i}'
|
127 |
+
if reference_images[c] is not None:
|
128 |
+
input_data[f"cn_img{i}"] = reference_images[c]
|
129 |
+
|
130 |
+
type_value = locals()[f'reference_image_{c}_type']
|
131 |
+
if type_value:
|
132 |
+
input_data[f"cn_type{i}"] = type_value
|
133 |
+
|
134 |
+
weight_value = locals()[f'reference_image_{c}_weight']
|
135 |
+
if weight_value != 0.0:
|
136 |
+
input_data[f"cn_weight{i}"] = weight_value
|
137 |
+
|
138 |
+
stop_value = locals()[f'reference_image_{c}_stop']
|
139 |
+
if stop_value != 0.0 or stop_value != 0:
|
140 |
+
input_data[f"cn_stop{i}"] = stop_value
|
141 |
+
logger.info(">>> REFERENCE IMAGE PARAMETERS PROCESSED <<<")
|
142 |
+
except Exception as e:
|
143 |
+
logger.error(f">>> REFERENCE IMAGE PARAMETERS ERROR: {str(e)} <<<")
|
144 |
+
return JSONResponse(status_code=500,
|
145 |
+
content={"error": f"Error processing reference image parameters: {str(e)}", "code": 500})
|
146 |
+
|
147 |
+
try:
|
148 |
+
output = image_regeneration_replicate(input_data)
|
149 |
+
response = requests.get(output[0])
|
150 |
+
output_base64 = base64.b64encode(response.content).decode('utf-8')
|
151 |
+
base64_prefix = image_data_uri.split(",")[0] + ","
|
152 |
+
logger.info(">>> IMAGE REGENERATION COMPLETED <<<")
|
153 |
+
except Exception as e:
|
154 |
+
logger.error(f">>> IMAGE REGENERATION ERROR: {str(e)} <<<")
|
155 |
+
return JSONResponse(status_code=500,
|
156 |
+
content={"error": f"Error generating image: {str(e)}", "code": 500})
|
157 |
+
|
158 |
+
try:
|
159 |
+
inference_time = round(time.time() - start_time, 2)
|
160 |
+
response = {
|
161 |
+
"output": f"{base64_prefix}{output_base64}",
|
162 |
+
"inference_time": inference_time,
|
163 |
+
"code": 200,
|
164 |
+
}
|
165 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {inference_time}s <<<")
|
166 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
167 |
+
logger.info("-" * 50)
|
168 |
+
return JSONResponse(content=response, status_code=200)
|
169 |
+
except Exception as e:
|
170 |
+
logger.error(f">>> RESPONSE CREATION ERROR: {str(e)} <<<")
|
171 |
+
return JSONResponse(status_code=500,
|
172 |
+
content={"error": f"Error creating response: {str(e)}", "code": 500})
|
src/api/makeup_tryon_api.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-12-05
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
import gc
|
8 |
+
from io import BytesIO
|
9 |
+
from time import time
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import requests
|
14 |
+
from PIL import Image
|
15 |
+
from fastapi import APIRouter, File, Form, UploadFile
|
16 |
+
from fastapi.responses import JSONResponse
|
17 |
+
from src.api.nto_api import pipeline, supabase_upload_and_return_url
|
18 |
+
from src.utils.logger import logger
|
19 |
+
|
20 |
+
makeup_tryon_router = APIRouter(tags=["makeup_tryon"], prefix="/makeup_tryon")
|
21 |
+
|
22 |
+
|
23 |
+
def get_colors():
|
24 |
+
logger.info(">>> FETCHING COLORS FROM BACKEND <<<")
|
25 |
+
start_time = time()
|
26 |
+
try:
|
27 |
+
response = requests.get(
|
28 |
+
"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/JewelMirrorMakeupTryOn/jewelmirror_makeup_colors.json"
|
29 |
+
)
|
30 |
+
response.raise_for_status()
|
31 |
+
colors = response.json()
|
32 |
+
fetch_time = round(time() - start_time, 2)
|
33 |
+
logger.info(f">>> COLORS FETCHED SUCCESSFULLY IN {fetch_time}s <<<")
|
34 |
+
return colors
|
35 |
+
except requests.exceptions.RequestException as e:
|
36 |
+
logger.error(f">>> ERROR FETCHING COLORS: {str(e)} <<<")
|
37 |
+
return None
|
38 |
+
|
39 |
+
|
40 |
+
@makeup_tryon_router.post("/tryon")
|
41 |
+
async def tryon(image: UploadFile = File(...), lipstick_color: str = Form(...), eyeliner_color: str = Form(...),
|
42 |
+
eyeshadow_color: str = Form(...)):
|
43 |
+
logger.info("-" * 50)
|
44 |
+
logger.info(">>> MAKE UP TRY ON STARTED <<<")
|
45 |
+
logger.info(
|
46 |
+
f">>> LIPSTICK COLOR: {lipstick_color}, EYELINER COLOR: {eyeliner_color}, EYESHADOW COLOR: {eyeshadow_color} <<<")
|
47 |
+
start_time = time()
|
48 |
+
|
49 |
+
colors = get_colors()
|
50 |
+
if colors is None:
|
51 |
+
return JSONResponse(status_code=500, content={"error": "Error fetching colors", "code": 500})
|
52 |
+
|
53 |
+
# Validate and fetch colors
|
54 |
+
try:
|
55 |
+
lipstick_set = tuple(colors["Lipstick"][lipstick_color])
|
56 |
+
eyeliner_set = tuple(colors["Eyeliner"][eyeliner_color])
|
57 |
+
eyeshadow_set = tuple(colors["Eyeshadow"][eyeshadow_color])
|
58 |
+
logger.info(
|
59 |
+
f">>> COLORS FETCHED SUCCESSFULLY LIPSTICK:{lipstick_color}|{lipstick_set} EYELINER:{eyeliner_color}|{eyeliner_set} EYESHADOW:{eyeshadow_color}|{eyeshadow_set} <<<")
|
60 |
+
except KeyError as e:
|
61 |
+
logger.error(f"Invalid color selection: {str(e)}")
|
62 |
+
return JSONResponse(status_code=400, content={"error": f"Invalid color: {str(e)}", "code": 400})
|
63 |
+
|
64 |
+
try:
|
65 |
+
image_bytes = await image.read()
|
66 |
+
image = Image.open(BytesIO(image_bytes))
|
67 |
+
image = np.array(image)
|
68 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
69 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
70 |
+
except Exception as e:
|
71 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
72 |
+
return JSONResponse(status_code=500, content={"error": "Error reading image", "code": 500})
|
73 |
+
|
74 |
+
try:
|
75 |
+
result = await pipeline.makeup_tryon(image, lipstick_set, eyeliner_set, eyeshadow_set)
|
76 |
+
logger.info(">>> MAKE UP TRY ON COMPLETED <<<")
|
77 |
+
except Exception as e:
|
78 |
+
logger.error(f">>> MAKE UP TRY ON PROCESSING ERROR: {str(e)} <<<")
|
79 |
+
return JSONResponse(status_code=500, content={"error": "Error during makeup try-on process", "code": 500})
|
80 |
+
|
81 |
+
try:
|
82 |
+
# Convert numpy array to PIL Image
|
83 |
+
result_image = Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
|
84 |
+
|
85 |
+
upload_start_time = time()
|
86 |
+
upload_tasks = [
|
87 |
+
supabase_upload_and_return_url(prefix="MTO", image=result_image,necklace_id="JEWELMIRRORMAKEUPTRYON")
|
88 |
+
]
|
89 |
+
result_url = await asyncio.gather(*upload_tasks)
|
90 |
+
upload_time = round(time() - upload_start_time, 2)
|
91 |
+
|
92 |
+
if result_url[0] is None:
|
93 |
+
raise Exception("Failed to upload result image")
|
94 |
+
|
95 |
+
logger.info(f">>> RESULT IMAGES SAVED IN {upload_time}s <<<")
|
96 |
+
except Exception as e:
|
97 |
+
logger.error(f">>> RESULT SAVING ERROR: {str(e)} <<<")
|
98 |
+
return JSONResponse(content={"error": f"Error saving result images", "code": 500}, status_code=500)
|
99 |
+
|
100 |
+
try:
|
101 |
+
total_time = round(time() - start_time, 2)
|
102 |
+
response = {
|
103 |
+
"code": 200,
|
104 |
+
"output": f"{result_url[0]}",
|
105 |
+
"timing": {
|
106 |
+
"upload": upload_time,
|
107 |
+
"total": total_time
|
108 |
+
}
|
109 |
+
}
|
110 |
+
|
111 |
+
logger.info(f">>> TIMING BREAKDOWN <<<")
|
112 |
+
logger.info(f"Upload Time: {upload_time}s")
|
113 |
+
logger.info(f"Total Time: {total_time}s")
|
114 |
+
logger.info(">>> MAKE UP TRY ON COMPLETED SUCCESSFULLY <<<")
|
115 |
+
logger.info("-" * 50)
|
116 |
+
|
117 |
+
return JSONResponse(content=response, status_code=200)
|
118 |
+
|
119 |
+
except Exception as e:
|
120 |
+
logger.error(f">>> RESPONSE GENERATION ERROR: {str(e)} <<<")
|
121 |
+
return JSONResponse(content={"error": f"Error generating response", "code": 500}, status_code=500)
|
122 |
+
|
123 |
+
finally:
|
124 |
+
if 'result' in locals(): del result
|
125 |
+
gc.collect()
|
126 |
+
|
127 |
+
|
128 |
+
def bgr_to_hex(bgr):
|
129 |
+
"""Convert BGR to HEX."""
|
130 |
+
b, g, r = bgr
|
131 |
+
return f"#{r:02x}{g:02x}{b:02x}"
|
132 |
+
|
133 |
+
|
134 |
+
@makeup_tryon_router.get("/available_colors")
|
135 |
+
async def color_fetch():
|
136 |
+
logger.info(">>> FETCHING AVAILABLE COLORS <<<")
|
137 |
+
start_time = time()
|
138 |
+
|
139 |
+
colors = get_colors()
|
140 |
+
if colors is None:
|
141 |
+
return JSONResponse(status_code=500, content={"error": "Error fetching colors", "code": 500})
|
142 |
+
|
143 |
+
try:
|
144 |
+
colors_with_hex = {
|
145 |
+
category: {
|
146 |
+
name: {
|
147 |
+
"bgr": bgr,
|
148 |
+
"hex": bgr_to_hex(bgr)
|
149 |
+
}
|
150 |
+
for name, bgr in items.items()
|
151 |
+
}
|
152 |
+
for category, items in colors.items()
|
153 |
+
}
|
154 |
+
|
155 |
+
fetch_time = round(time() - start_time, 2)
|
156 |
+
logger.info(f">>> COLORS FETCHED SUCCESSFULLY IN {fetch_time}s <<<")
|
157 |
+
|
158 |
+
return JSONResponse(content={
|
159 |
+
"code": 200,
|
160 |
+
"colors": colors_with_hex,
|
161 |
+
"timing": {
|
162 |
+
"total": fetch_time
|
163 |
+
}
|
164 |
+
})
|
165 |
+
|
166 |
+
except Exception as e:
|
167 |
+
logger.error(f">>> ERROR PROCESSING COLORS: {str(e)} <<<")
|
168 |
+
return JSONResponse(
|
169 |
+
status_code=500,
|
170 |
+
content={"error": "Error processing color data", "code": 500}
|
171 |
+
)
|
src/api/mannequin_to_model_api.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ CTO_TCP_ZERO_GPU
|
3 |
+
created @ 2024-11-14
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import base64
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
|
11 |
+
import requests
|
12 |
+
from fastapi.routing import APIRouter
|
13 |
+
from fastapi import File, UploadFile, Form
|
14 |
+
import replicate
|
15 |
+
from starlette.responses import JSONResponse
|
16 |
+
from src.api.nto_api import supabase
|
17 |
+
from src.utils.logger import logger
|
18 |
+
|
19 |
+
mto_router = APIRouter()
|
20 |
+
|
21 |
+
|
22 |
+
def run_mto(input):
|
23 |
+
try:
|
24 |
+
logger.info("Starting mannequin to model conversion")
|
25 |
+
output = replicate.run(
|
26 |
+
"xiankgx/face-swap:cff87316e31787df12002c9e20a78a017a36cb31fde9862d8dedd15ab29b7288",
|
27 |
+
input=input
|
28 |
+
)
|
29 |
+
logger.info("Mannequin to model conversion completed successfully")
|
30 |
+
return output
|
31 |
+
except Exception as e:
|
32 |
+
logger.error(f"Error in mannequin to model conversion: {str(e)}")
|
33 |
+
return None
|
34 |
+
|
35 |
+
|
36 |
+
def read_return(url):
|
37 |
+
try:
|
38 |
+
res = requests.get(url)
|
39 |
+
logger.info("Image fetched successfully")
|
40 |
+
return res.content
|
41 |
+
except Exception as e:
|
42 |
+
logger.error(f"Error fetching image: {str(e)}")
|
43 |
+
return None
|
44 |
+
|
45 |
+
|
46 |
+
@mto_router.post("/mto_image")
|
47 |
+
async def mto_image(image: UploadFile = File(...), store_name: str = Form(...),
|
48 |
+
clothing_category: str = Form(...),
|
49 |
+
product_id: str = Form(...),
|
50 |
+
body_structure: str = Form(...),
|
51 |
+
skin_complexion: str = Form(...),
|
52 |
+
facial_structure: str = Form(...), ):
|
53 |
+
start_time = time.time()
|
54 |
+
try:
|
55 |
+
|
56 |
+
logger.info(f"Starting MTO image process for store: {store_name}")
|
57 |
+
|
58 |
+
if body_structure == "medium":
|
59 |
+
body_structure = "fat"
|
60 |
+
logger.info("Body structure adjusted from 'medium' to 'fat'")
|
61 |
+
|
62 |
+
image_bytes = await image.read()
|
63 |
+
logger.info("Source image read successfully")
|
64 |
+
|
65 |
+
mannequin_image_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/ClothingTryOn/{store_name}/{clothing_category}/{product_id}/{product_id}_{skin_complexion}_{facial_structure}_{body_structure}.webp"
|
66 |
+
logger.info(f"Fetching mannequin image")
|
67 |
+
|
68 |
+
reference_image_bytes = read_return(mannequin_image_url)
|
69 |
+
if reference_image_bytes is None:
|
70 |
+
logger.error("Failed to fetch reference image")
|
71 |
+
return JSONResponse({"error": "Failed to fetch reference image"}, status_code=500)
|
72 |
+
|
73 |
+
image_uri = f"data:image/jpeg;base64,{base64.b64encode(image_bytes).decode()}"
|
74 |
+
reference_image_uri = f"data:image/jpeg;base64,{base64.b64encode(reference_image_bytes).decode()}"
|
75 |
+
|
76 |
+
input = {
|
77 |
+
"local_source": image_uri,
|
78 |
+
"local_target": reference_image_uri
|
79 |
+
}
|
80 |
+
|
81 |
+
output = run_mto(input)
|
82 |
+
if output is None:
|
83 |
+
logger.error("Face swap process failed")
|
84 |
+
return JSONResponse({"error": "Face swap process failed"}, status_code=500)
|
85 |
+
|
86 |
+
try:
|
87 |
+
response = requests.get(str(output['image']))
|
88 |
+
image_content = response.content
|
89 |
+
|
90 |
+
base64_image = base64.b64encode(image_content).decode('utf-8')
|
91 |
+
|
92 |
+
logger.info("MTO image process completed successfully")
|
93 |
+
return JSONResponse(content={
|
94 |
+
"output": f"data:image/webp;base64,{base64_image}",
|
95 |
+
"status": "success",
|
96 |
+
"inference_time": round((time.time() - start_time), 2)
|
97 |
+
}, status_code=200)
|
98 |
+
|
99 |
+
except Exception as e:
|
100 |
+
logger.error(f"Error converting output to base64: {str(e)}")
|
101 |
+
return JSONResponse({"error": "Error processing output image"}, status_code=500)
|
102 |
+
|
103 |
+
except Exception as e:
|
104 |
+
logger.error(f"Error in MTO image process: {str(e)}")
|
105 |
+
return JSONResponse({"error": str(e)}, status_code=500)
|
106 |
+
|
107 |
+
|
108 |
+
@mto_router.get("/mannequin_catalogue")
|
109 |
+
async def returnJsonData(gender: str):
|
110 |
+
try:
|
111 |
+
logger.info(f"Fetching mannequin catalogue for gender: {gender}")
|
112 |
+
|
113 |
+
folderImageURL = supabase.storage.get_bucket("JSON").create_signed_url(
|
114 |
+
path=os.path.join("MannequinInfo.json"),
|
115 |
+
expires_in=3600
|
116 |
+
)["signedURL"]
|
117 |
+
|
118 |
+
logger.info("Fetching JSON data from Supabase")
|
119 |
+
r = requests.get(folderImageURL).content.decode()
|
120 |
+
mannequin_data = json.loads(r)
|
121 |
+
|
122 |
+
if gender.lower() == "female":
|
123 |
+
res = [item for item in mannequin_data if item["gender"] == "female"]
|
124 |
+
elif gender.lower() == "male":
|
125 |
+
res = [item for item in mannequin_data if item["gender"] == "male"]
|
126 |
+
else:
|
127 |
+
res = []
|
128 |
+
logger.warning(f"Invalid gender parameter: {gender}")
|
129 |
+
|
130 |
+
logger.info(f"Successfully retrieved {len(res)} mannequin entries")
|
131 |
+
return res
|
132 |
+
|
133 |
+
except Exception as e:
|
134 |
+
logger.error(f"Error in mannequin catalogue: {str(e)}")
|
135 |
+
return JSONResponse({"error": str(e)}, status_code=500)
|
src/api/nto_api.py
ADDED
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import asyncio
|
7 |
+
import base64
|
8 |
+
import gc
|
9 |
+
import os
|
10 |
+
import secrets
|
11 |
+
import time
|
12 |
+
from io import BytesIO
|
13 |
+
|
14 |
+
import aiohttp
|
15 |
+
import cv2
|
16 |
+
import numpy as np
|
17 |
+
import replicate
|
18 |
+
import requests
|
19 |
+
from PIL import Image
|
20 |
+
from PIL.ImageOps import grayscale
|
21 |
+
from fastapi import File, UploadFile, HTTPException, Form, Depends, APIRouter
|
22 |
+
from fastapi.responses import JSONResponse
|
23 |
+
from pydantic import BaseModel
|
24 |
+
from supabase import create_client
|
25 |
+
|
26 |
+
from src.utils import deductAndTrackCredit, returnBytesData
|
27 |
+
from src.utils.logger import logger
|
28 |
+
|
29 |
+
nto_cto_router = APIRouter()
|
30 |
+
|
31 |
+
url: str = os.getenv("SUPABASE_URL")
|
32 |
+
key: str = os.getenv("SUPABASE_KEY")
|
33 |
+
|
34 |
+
supabase = create_client(supabase_url=url, supabase_key=key)
|
35 |
+
supabase_storage: str = os.getenv("SUPABASE_STORAGE")
|
36 |
+
|
37 |
+
from src.pipelines.completePipeline import Pipeline
|
38 |
+
|
39 |
+
pipeline = Pipeline()
|
40 |
+
|
41 |
+
cto_replicate: str = os.getenv(
|
42 |
+
"CTO")
|
43 |
+
|
44 |
+
bucket = supabase.storage.from_("JewelMirrorOutputs")
|
45 |
+
|
46 |
+
|
47 |
+
def replicate_run_cto(input):
|
48 |
+
output = replicate.run(
|
49 |
+
cto_replicate,
|
50 |
+
input=input)
|
51 |
+
return output
|
52 |
+
|
53 |
+
|
54 |
+
class NecklaceTryOnIDEntity(BaseModel):
|
55 |
+
necklaceImageId: str
|
56 |
+
necklaceCategory: str
|
57 |
+
storename: str
|
58 |
+
offset_x: float
|
59 |
+
offset_y: float
|
60 |
+
|
61 |
+
|
62 |
+
@nto_cto_router.post("/clothingTryOnV2")
|
63 |
+
async def clothing_try_on_v2(image: UploadFile = File(...), clothing_type: str = Form(...)):
|
64 |
+
logger.info("-" * 50)
|
65 |
+
logger.info(">>> CLOTHING TRY ON V2 STARTED <<<")
|
66 |
+
logger.info(f"Parameters: clothing_type={clothing_type}")
|
67 |
+
start_time = time.time()
|
68 |
+
|
69 |
+
try:
|
70 |
+
image_bytes = await image.read()
|
71 |
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
72 |
+
logger.info(">>> IMAGE LOADED SUCCESSFULLY <<<")
|
73 |
+
except Exception as e:
|
74 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
75 |
+
return JSONResponse(status_code=500, content={"error": f"Error reading image", "code": 500})
|
76 |
+
|
77 |
+
try:
|
78 |
+
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=image)
|
79 |
+
logger.info(">>> MASK GENERATION COMPLETED <<<")
|
80 |
+
except Exception as e:
|
81 |
+
logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
|
82 |
+
return JSONResponse(status_code=500,
|
83 |
+
content={"error": f"Error generating mask", "code": 500})
|
84 |
+
|
85 |
+
try:
|
86 |
+
mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
|
87 |
+
mask.save(mask_img_base_64, format="WEBP")
|
88 |
+
image.save(act_img_base_64, format="WEBP")
|
89 |
+
mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
|
90 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
91 |
+
|
92 |
+
mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
|
93 |
+
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
94 |
+
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
95 |
+
except Exception as e:
|
96 |
+
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
97 |
+
return JSONResponse(status_code=500,
|
98 |
+
content={"error": f"Error converting images to base64", "code": 500})
|
99 |
+
|
100 |
+
input = {
|
101 |
+
"mask": mask_data_uri,
|
102 |
+
"image": image_data_uri,
|
103 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
104 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
105 |
+
"num_inference_steps": 25
|
106 |
+
}
|
107 |
+
|
108 |
+
try:
|
109 |
+
output = replicate_run_cto(input)
|
110 |
+
logger.info(">>> REPLICATE PROCESSING COMPLETED <<<")
|
111 |
+
except Exception as e:
|
112 |
+
logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
|
113 |
+
return JSONResponse(content={"error": f"Error running CTO Replicate", "code": 500}, status_code=500)
|
114 |
+
|
115 |
+
total_inference_time = round((time.time() - start_time), 2)
|
116 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
117 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
118 |
+
logger.info("-" * 50)
|
119 |
+
|
120 |
+
response = {
|
121 |
+
"code": 200,
|
122 |
+
"output": f"{output[0]}",
|
123 |
+
"inference_time": total_inference_time
|
124 |
+
}
|
125 |
+
|
126 |
+
return JSONResponse(content=response, status_code=200)
|
127 |
+
|
128 |
+
|
129 |
+
@nto_cto_router.post("/clothingTryOn")
|
130 |
+
async def clothing_try_on(image: UploadFile = File(...),
|
131 |
+
mask: UploadFile = File(...), clothing_type: str = Form(...)):
|
132 |
+
logger.info("-" * 50)
|
133 |
+
logger.info(">>> CLOTHING TRY ON STARTED <<<")
|
134 |
+
logger.info(f"Parameters: clothing_type={clothing_type}")
|
135 |
+
start_time = time.time()
|
136 |
+
|
137 |
+
try:
|
138 |
+
image_bytes = await image.read()
|
139 |
+
mask_bytes = await mask.read()
|
140 |
+
image, mask = Image.open(BytesIO(image_bytes)).convert("RGB"), Image.open(BytesIO(mask_bytes)).convert("RGB")
|
141 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
142 |
+
except Exception as e:
|
143 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
144 |
+
return JSONResponse(status_code=500, content={"error": f"Error reading image or mask", "code": 500})
|
145 |
+
|
146 |
+
try:
|
147 |
+
actual_image = image.copy()
|
148 |
+
jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
|
149 |
+
arr_orig = np.array(grayscale(mask))
|
150 |
+
|
151 |
+
image = cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
|
152 |
+
image = Image.fromarray(image).resize((512, 512))
|
153 |
+
|
154 |
+
arr = arr_orig.copy()
|
155 |
+
mask_y = np.where(arr == arr[arr != 0][0])[0][0]
|
156 |
+
arr[mask_y:, :] = 255
|
157 |
+
|
158 |
+
mask = Image.fromarray(arr).resize((512, 512))
|
159 |
+
logger.info(">>> IMAGE PROCESSING COMPLETED <<<")
|
160 |
+
except Exception as e:
|
161 |
+
logger.error(f">>> IMAGE PROCESSING ERROR: {str(e)} <<<")
|
162 |
+
return JSONResponse(status_code=500,
|
163 |
+
content={"error": f"Error processing image or mask", "code": 500})
|
164 |
+
|
165 |
+
try:
|
166 |
+
mask_img_base_64, act_img_base_64 = BytesIO(), BytesIO()
|
167 |
+
mask.save(mask_img_base_64, format="WEBP")
|
168 |
+
image.save(act_img_base_64, format="WEBP")
|
169 |
+
mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
|
170 |
+
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
171 |
+
|
172 |
+
mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
|
173 |
+
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
174 |
+
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
177 |
+
return JSONResponse(status_code=500,
|
178 |
+
content={"error": f"Error encoding images", "code": 500})
|
179 |
+
|
180 |
+
input = {
|
181 |
+
"mask": mask_data_uri,
|
182 |
+
"image": image_data_uri,
|
183 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
184 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
185 |
+
"num_inference_steps": 25
|
186 |
+
}
|
187 |
+
|
188 |
+
try:
|
189 |
+
output = replicate_run_cto(input)
|
190 |
+
logger.info(">>> REPLICATE PROCESSING COMPLETED <<<")
|
191 |
+
except Exception as e:
|
192 |
+
logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
|
193 |
+
return JSONResponse(content={"error": f"Error running clothing try on", "code": 500}, status_code=500)
|
194 |
+
|
195 |
+
try:
|
196 |
+
response = requests.get(output[0])
|
197 |
+
output_image = Image.open(BytesIO(response.content)).resize(actual_image.size)
|
198 |
+
output_image = np.bitwise_and(np.array(output_image),
|
199 |
+
np.bitwise_not(np.array(Image.fromarray(arr_orig).convert("RGB"))))
|
200 |
+
result = Image.fromarray(np.bitwise_or(np.array(output_image), np.array(jewellery_mask)))
|
201 |
+
|
202 |
+
in_mem_file = BytesIO()
|
203 |
+
result.save(in_mem_file, format="WEBP", quality=85)
|
204 |
+
base_64_output = base64.b64encode(in_mem_file.getvalue()).decode('utf-8')
|
205 |
+
total_inference_time = round((time.time() - start_time), 2)
|
206 |
+
logger.info(">>> OUTPUT IMAGE PROCESSING COMPLETED <<<")
|
207 |
+
|
208 |
+
response = {
|
209 |
+
"output": f"data:image/WEBP;base64,{base_64_output}",
|
210 |
+
"code": 200,
|
211 |
+
"inference_time": total_inference_time
|
212 |
+
}
|
213 |
+
except Exception as e:
|
214 |
+
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
|
215 |
+
return JSONResponse(status_code=500, content={"error": f"Error processing output image", "code": 500})
|
216 |
+
|
217 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
218 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
219 |
+
logger.info("-" * 50)
|
220 |
+
|
221 |
+
return JSONResponse(content=response, status_code=200)
|
222 |
+
|
223 |
+
|
224 |
+
async def parse_necklace_try_on_id(necklaceImageId: str = Form(...),
|
225 |
+
necklaceCategory: str = Form(...),
|
226 |
+
storename: str = Form(...),
|
227 |
+
offset_x: float = Form(...),
|
228 |
+
offset_y: float = Form(...)
|
229 |
+
) -> NecklaceTryOnIDEntity:
|
230 |
+
return NecklaceTryOnIDEntity(
|
231 |
+
necklaceImageId=necklaceImageId,
|
232 |
+
necklaceCategory=necklaceCategory,
|
233 |
+
storename=storename,
|
234 |
+
offset_x=offset_x,
|
235 |
+
offset_y=offset_y
|
236 |
+
)
|
237 |
+
|
238 |
+
|
239 |
+
async def supabase_upload_and_return_url(prefix: str, necklace_id: str, image: Image.Image, quality: int = 85):
|
240 |
+
try:
|
241 |
+
filename = f"{prefix}_{necklace_id}_{secrets.token_hex(8)}.webp"
|
242 |
+
|
243 |
+
loop = asyncio.get_event_loop()
|
244 |
+
image_bytes = await loop.run_in_executor(
|
245 |
+
None,
|
246 |
+
process_image,
|
247 |
+
image,
|
248 |
+
quality
|
249 |
+
)
|
250 |
+
|
251 |
+
async with aiohttp.ClientSession() as session:
|
252 |
+
headers = {
|
253 |
+
"Authorization": f"Bearer {key}",
|
254 |
+
"Content-Type": "image/webp"
|
255 |
+
}
|
256 |
+
|
257 |
+
upload_url = f"{url}/storage/v1/object/JewelMirrorOutputs/{filename}"
|
258 |
+
|
259 |
+
async with session.post(
|
260 |
+
upload_url,
|
261 |
+
data=image_bytes,
|
262 |
+
headers=headers
|
263 |
+
) as response:
|
264 |
+
if response.status != 200:
|
265 |
+
raise Exception(f"Upload failed with status {response.status}")
|
266 |
+
|
267 |
+
return bucket.get_public_url(filename)
|
268 |
+
|
269 |
+
except Exception as e:
|
270 |
+
logger.error(f"Failed to upload image: {str(e)}")
|
271 |
+
return None
|
272 |
+
|
273 |
+
|
274 |
+
def process_image(image: Image.Image, quality: int) -> bytes:
|
275 |
+
try:
|
276 |
+
if image.mode in ['RGBA', 'P']:
|
277 |
+
image = image.convert('RGB')
|
278 |
+
|
279 |
+
max_size = 3000
|
280 |
+
if image.width > max_size or image.height > max_size:
|
281 |
+
ratio = min(max_size / image.width, max_size / image.height)
|
282 |
+
new_size = (int(image.width * ratio), int(image.height * ratio))
|
283 |
+
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
284 |
+
|
285 |
+
with BytesIO() as buffer:
|
286 |
+
image.save(
|
287 |
+
buffer,
|
288 |
+
format='WEBP',
|
289 |
+
quality=quality,
|
290 |
+
optimize=True,
|
291 |
+
method=6
|
292 |
+
)
|
293 |
+
return buffer.getvalue()
|
294 |
+
except Exception as e:
|
295 |
+
logger.error(f"Image processing failed: {str(e)}")
|
296 |
+
raise
|
297 |
+
|
298 |
+
|
299 |
+
@nto_cto_router.post("/necklaceTryOnID")
|
300 |
+
async def necklace_try_on_id(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
301 |
+
image: UploadFile = File(...)):
|
302 |
+
logger.info("-" * 50)
|
303 |
+
logger.info(">>> NECKLACE TRY ON ID STARTED <<<")
|
304 |
+
logger.info(f"Parameters: storename={necklace_try_on_id.storename}, "
|
305 |
+
f"necklaceCategory={necklace_try_on_id.necklaceCategory}, "
|
306 |
+
f"necklaceImageId={necklace_try_on_id.necklaceImageId}")
|
307 |
+
start_time = time.time()
|
308 |
+
|
309 |
+
try:
|
310 |
+
image_loading_start = time.time()
|
311 |
+
imageBytes = await image.read()
|
312 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
313 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
314 |
+
image_loading_time = round(time.time() - image_loading_start, 2)
|
315 |
+
logger.info(f">>> IMAGES LOADED SUCCESSFULLY in {image_loading_time}s <<<")
|
316 |
+
except Exception as e:
|
317 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
318 |
+
return JSONResponse(content={
|
319 |
+
"error": f"The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again",
|
320 |
+
"code": 404}, status_code=404)
|
321 |
+
|
322 |
+
try:
|
323 |
+
nto_start_time = time.time()
|
324 |
+
result, headerText, mask = await pipeline.necklaceTryOnDynamicOffset_(
|
325 |
+
image=image,
|
326 |
+
jewellery=jewellery,
|
327 |
+
storename=necklace_try_on_id.storename,
|
328 |
+
offset=[necklace_try_on_id.offset_x, necklace_try_on_id.offset_y]
|
329 |
+
)
|
330 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
331 |
+
logger.info(f">>> NECKLACE TRY ON PROCESSING COMPLETED in {nto_time}s <<<")
|
332 |
+
|
333 |
+
if result is None:
|
334 |
+
logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
|
335 |
+
return JSONResponse(
|
336 |
+
content={"error": "No face detected in the image please try again with a different image",
|
337 |
+
"code": 400}, status_code=400)
|
338 |
+
|
339 |
+
except Exception as e:
|
340 |
+
logger.error(f">>> NECKLACE TRY ON PROCESSING ERROR: {str(e)} <<<")
|
341 |
+
return JSONResponse(content={"error": f"Error during necklace try-on process", "code": 500},
|
342 |
+
status_code=500)
|
343 |
+
|
344 |
+
try:
|
345 |
+
upload_start_time = time.time()
|
346 |
+
upload_tasks = [
|
347 |
+
supabase_upload_and_return_url(prefix="NTO", image=result, necklace_id=necklace_try_on_id.necklaceImageId),
|
348 |
+
supabase_upload_and_return_url(prefix="NTO", image=mask, necklace_id=necklace_try_on_id.necklaceImageId)
|
349 |
+
]
|
350 |
+
result_url, mask_url = await asyncio.gather(*upload_tasks)
|
351 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
352 |
+
|
353 |
+
if not result_url or not mask_url:
|
354 |
+
raise Exception("Failed to upload one or both images")
|
355 |
+
|
356 |
+
logger.info(f">>> RESULT IMAGES SAVED IN {upload_time}s <<<")
|
357 |
+
except Exception as e:
|
358 |
+
logger.error(f">>> RESULT SAVING ERROR: {str(e)} <<<")
|
359 |
+
return JSONResponse(content={"error": f"Error saving result images", "code": 500}, status_code=500)
|
360 |
+
|
361 |
+
try:
|
362 |
+
total_time = round(time.time() - start_time, 2)
|
363 |
+
response = {
|
364 |
+
"code": 200,
|
365 |
+
"output": f"{result_url}",
|
366 |
+
"mask": f"{mask_url}",
|
367 |
+
"timing": {
|
368 |
+
"image_loading": image_loading_time,
|
369 |
+
"nto_processing": nto_time,
|
370 |
+
"upload": upload_time,
|
371 |
+
"total": total_time
|
372 |
+
}
|
373 |
+
}
|
374 |
+
|
375 |
+
logger.info(f">>> TIMING BREAKDOWN <<<")
|
376 |
+
logger.info(f"Image Loading: {image_loading_time}s")
|
377 |
+
logger.info(f"NTO Processing: {nto_time}s")
|
378 |
+
logger.info(f"Upload Time: {upload_time}s")
|
379 |
+
logger.info(f"Total Time: {total_time}s")
|
380 |
+
logger.info(">>> NECKLACE TRY ON COMPLETED <<<")
|
381 |
+
logger.info("-" * 50)
|
382 |
+
|
383 |
+
return JSONResponse(content=response, status_code=200)
|
384 |
+
|
385 |
+
except Exception as e:
|
386 |
+
logger.error(f">>> RESPONSE GENERATION ERROR: {str(e)} <<<")
|
387 |
+
return JSONResponse(content={"error": f"Error generating response", "code": 500}, status_code=500)
|
388 |
+
|
389 |
+
finally:
|
390 |
+
if 'result' in locals(): del result
|
391 |
+
gc.collect()
|
392 |
+
|
393 |
+
|
394 |
+
#
|
395 |
+
# @nto_cto_router.post("/canvasPoints")
|
396 |
+
# async def canvas_points(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
397 |
+
# image: UploadFile = File(...)):
|
398 |
+
# logger.info("-" * 50)
|
399 |
+
# logger.info(">>> CANVAS POINTS STARTED <<<")
|
400 |
+
# logger.info(f"Parameters: storename={necklace_try_on_id.storename}, "
|
401 |
+
# f"necklaceCategory={necklace_try_on_id.necklaceCategory}, "
|
402 |
+
# f"necklaceImageId={necklace_try_on_id.necklaceImageId}")
|
403 |
+
# start_time = time.time()
|
404 |
+
#
|
405 |
+
# try:
|
406 |
+
# imageBytes = await image.read()
|
407 |
+
# jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
408 |
+
# image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
409 |
+
# logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
410 |
+
# except Exception as e:
|
411 |
+
# logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
412 |
+
# return JSONResponse(content={
|
413 |
+
# "error": f"The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again. Error",
|
414 |
+
# "code": 404}, status_code=404)
|
415 |
+
#
|
416 |
+
# try:
|
417 |
+
# response = await pipeline.canvasPoint(image=image, jewellery=jewellery, storename=necklace_try_on_id.storename)
|
418 |
+
# response = {"code": 200, "output": response}
|
419 |
+
# logger.info(">>> CANVAS POINTS PROCESSING COMPLETED <<<")
|
420 |
+
# except Exception as e:
|
421 |
+
# logger.error(f">>> CANVAS POINTS PROCESSING ERROR: {str(e)} <<<")
|
422 |
+
# return JSONResponse(content={"error": f"Error during canvas point process", "code": 500},
|
423 |
+
# status_code=500)
|
424 |
+
#
|
425 |
+
# try:
|
426 |
+
# creditResponse = deductAndTrackCredit(storename=necklace_try_on_id.storename, endpoint="/necklaceTryOnID")
|
427 |
+
# if creditResponse == "No Credits Available":
|
428 |
+
# logger.error(">>> NO CREDITS REMAINING <<<")
|
429 |
+
# return JSONResponse(content={"error": "No Credits Remaining", "code": 402}, status_code=402)
|
430 |
+
# logger.info(">>> CREDITS DEDUCTED SUCCESSFULLY <<<")
|
431 |
+
# except Exception as e:
|
432 |
+
# logger.error(f">>> CREDIT DEDUCTION ERROR: {str(e)} <<<")
|
433 |
+
# return JSONResponse(content={"error": f"Error deducting credits", "code": 500}, status_code=500)
|
434 |
+
#
|
435 |
+
# total_inference_time = round((time.time() - start_time), 2)
|
436 |
+
# logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
437 |
+
# logger.info(f">>> CANVAS POINTS COMPLETED <<<")
|
438 |
+
# logger.info("-" * 50)
|
439 |
+
#
|
440 |
+
# return JSONResponse(status_code=200, content=response)
|
441 |
+
|
442 |
+
|
443 |
+
@nto_cto_router.post("/necklaceTryOnWithPoints")
|
444 |
+
async def necklace_try_on_with_points(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
445 |
+
image: UploadFile = File(...),
|
446 |
+
left_x: int = Form(...),
|
447 |
+
left_y: int = Form(...),
|
448 |
+
right_x: int = Form(...),
|
449 |
+
right_y: int = Form(...)):
|
450 |
+
logger.info("-" * 50)
|
451 |
+
logger.info(">>> NECKLACE TRY ON WITH POINTS STARTED <<<")
|
452 |
+
logger.info(f"Parameters: storename={necklace_try_on_id.storename}, "
|
453 |
+
f"necklaceCategory={necklace_try_on_id.necklaceCategory}, "
|
454 |
+
f"necklaceImageId={necklace_try_on_id.necklaceImageId}, "
|
455 |
+
f"left_point=({left_x}, {left_y}), right_point=({right_x}, {right_y})")
|
456 |
+
start_time = time.time()
|
457 |
+
|
458 |
+
try:
|
459 |
+
imageBytes = await image.read()
|
460 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
461 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
462 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
463 |
+
except Exception as e:
|
464 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
465 |
+
return JSONResponse(content={
|
466 |
+
"error": f"The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again. Error: {str(e)}",
|
467 |
+
"code": 404}, status_code=404)
|
468 |
+
|
469 |
+
try:
|
470 |
+
result, headerText, mask = await pipeline.necklaceTryOnWithPoints_(
|
471 |
+
image=image, jewellery=jewellery, left_shoulder=(left_x, left_y), right_shoulder=(right_x, right_y),
|
472 |
+
storename=necklace_try_on_id.storename
|
473 |
+
)
|
474 |
+
logger.info(">>> NECKLACE TRY ON PROCESSING COMPLETED <<<")
|
475 |
+
except Exception as e:
|
476 |
+
logger.error(f">>> NECKLACE TRY ON PROCESSING ERROR: {str(e)} <<<")
|
477 |
+
return JSONResponse(content={"error": f"Error during necklace try-on process", "code": 500},
|
478 |
+
status_code=500)
|
479 |
+
|
480 |
+
try:
|
481 |
+
inMemFile = BytesIO()
|
482 |
+
inMemFileMask = BytesIO()
|
483 |
+
result.save(inMemFile, format="WEBP", quality=85)
|
484 |
+
mask.save(inMemFileMask, format="WEBP", quality=85)
|
485 |
+
outputBytes = inMemFile.getvalue()
|
486 |
+
maskBytes = inMemFileMask.getvalue()
|
487 |
+
logger.info(">>> RESULT IMAGES SAVED <<<")
|
488 |
+
except Exception as e:
|
489 |
+
logger.error(f">>> RESULT SAVING ERROR: {str(e)} <<<")
|
490 |
+
return JSONResponse(content={"error": f"Error saving result images", "code": 500}, status_code=500)
|
491 |
+
|
492 |
+
try:
|
493 |
+
creditResponse = deductAndTrackCredit(storename=necklace_try_on_id.storename, endpoint="/necklaceTryOnID")
|
494 |
+
total_inference_time = round((time.time() - start_time), 2)
|
495 |
+
response = {
|
496 |
+
"code": 200,
|
497 |
+
"output": f"data:image/WEBP;base64,{base64.b64encode(outputBytes).decode('utf-8')}",
|
498 |
+
"mask": f"data:image/WEBP;base64,{base64.b64encode(maskBytes).decode('utf-8')}",
|
499 |
+
"inference_time": total_inference_time
|
500 |
+
}
|
501 |
+
if creditResponse == "No Credits Available":
|
502 |
+
logger.error(">>> NO CREDITS REMAINING <<<")
|
503 |
+
response = {"error": "No Credits Remaining"}
|
504 |
+
return JSONResponse(content=response, status_code=402)
|
505 |
+
logger.info(">>> CREDITS DEDUCTED SUCCESSFULLY <<<")
|
506 |
+
except Exception as e:
|
507 |
+
logger.error(f">>> CREDIT DEDUCTION ERROR: {str(e)} <<<")
|
508 |
+
return JSONResponse(content={"error": f"Error deducting credits", "code": 500}, status_code=500)
|
509 |
+
|
510 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
511 |
+
logger.info(f">>> NECKLACE TRY ON WITH POINTS COMPLETED <<<")
|
512 |
+
logger.info("-" * 50)
|
513 |
+
|
514 |
+
return JSONResponse(content=response, status_code=200)
|
515 |
+
|
516 |
+
|
517 |
+
@nto_cto_router.post("/clothingAndNecklaceTryOn")
|
518 |
+
async def clothing_and_necklace_try_on(
|
519 |
+
image: UploadFile = File(...),
|
520 |
+
necklaceImageId: str = Form(...),
|
521 |
+
necklaceCategory: str = Form(...),
|
522 |
+
storename: str = Form(...),
|
523 |
+
clothing_type: str = Form(...)
|
524 |
+
):
|
525 |
+
logger.info("-" * 50)
|
526 |
+
logger.info(">>> CLOTHING AND NECKLACE TRY ON STARTED <<<")
|
527 |
+
logger.info(f"Parameters: storename={storename}, "
|
528 |
+
f"necklaceCategory={necklaceCategory}, "
|
529 |
+
f"necklaceImageId={necklaceImageId}, "
|
530 |
+
f"clothing_type={clothing_type}")
|
531 |
+
start_time = time.time()
|
532 |
+
|
533 |
+
def image_to_base64(img: Image.Image) -> str:
|
534 |
+
buffer = BytesIO()
|
535 |
+
img.save(buffer, format="WEBP", quality=85, optimize=True)
|
536 |
+
return f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
|
537 |
+
|
538 |
+
try:
|
539 |
+
person_bytes = await image.read()
|
540 |
+
person_image = Image.open(BytesIO(person_bytes)).convert("RGB").resize((512, 512))
|
541 |
+
|
542 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{necklaceCategory}/image/{necklaceImageId}.png"
|
543 |
+
necklace_image = Image.open(returnBytesData(url=jewellery_url)).convert("RGBA")
|
544 |
+
|
545 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
546 |
+
|
547 |
+
mask, left_point, right_point = await pipeline.shoulderPointMaskGeneration_(image=person_image)
|
548 |
+
logger.info(">>> MASK AND POINTS GENERATION COMPLETED <<<")
|
549 |
+
|
550 |
+
mask_data_uri, image_data_uri = await asyncio.gather(
|
551 |
+
asyncio.to_thread(image_to_base64, mask),
|
552 |
+
asyncio.to_thread(image_to_base64, person_image)
|
553 |
+
)
|
554 |
+
|
555 |
+
cto_output = replicate_run_cto({
|
556 |
+
"mask": mask_data_uri,
|
557 |
+
"image": image_data_uri,
|
558 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
559 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
560 |
+
"num_inference_steps": 20
|
561 |
+
})
|
562 |
+
|
563 |
+
if not cto_output or not isinstance(cto_output, (list, tuple)) or not cto_output[0]:
|
564 |
+
raise ValueError("Invalid output from clothing try-on")
|
565 |
+
|
566 |
+
async with aiohttp.ClientSession() as session:
|
567 |
+
async with session.get(str(cto_output[0])) as response:
|
568 |
+
if response.status != 200:
|
569 |
+
raise HTTPException(status_code=response.status, detail="Failed to fetch CTO output")
|
570 |
+
cto_result_bytes = await response.read()
|
571 |
+
|
572 |
+
with BytesIO(cto_result_bytes) as buf:
|
573 |
+
cto_result_image = Image.open(buf).convert("RGB")
|
574 |
+
|
575 |
+
result, headerText, _ = await pipeline.necklaceTryOnWithPoints_(
|
576 |
+
image=cto_result_image,
|
577 |
+
jewellery=necklace_image,
|
578 |
+
left_shoulder=left_point,
|
579 |
+
right_shoulder=right_point,
|
580 |
+
storename=storename
|
581 |
+
)
|
582 |
+
|
583 |
+
if result is None:
|
584 |
+
raise ValueError("Failed to process necklace try-on")
|
585 |
+
|
586 |
+
result_url = await supabase_upload_and_return_url(prefix="NTOCTO", image=result, necklace_id=necklaceImageId)
|
587 |
+
|
588 |
+
if not result_url:
|
589 |
+
raise ValueError("Failed to upload result image")
|
590 |
+
|
591 |
+
response = {
|
592 |
+
"code": 200,
|
593 |
+
"output": result_url,
|
594 |
+
"inference_time": round((time.time() - start_time), 2)
|
595 |
+
}
|
596 |
+
|
597 |
+
except ValueError as ve:
|
598 |
+
logger.error(f">>> PROCESSING ERROR: {str(ve)} <<<")
|
599 |
+
return JSONResponse(status_code=400, content={"error": str(ve), "code": 400})
|
600 |
+
except Exception as e:
|
601 |
+
logger.error(f">>> PROCESSING ERROR: {str(e)} <<<")
|
602 |
+
return JSONResponse(status_code=500, content={"error": "Error during image processing", "code": 500})
|
603 |
+
finally:
|
604 |
+
gc.collect()
|
605 |
+
|
606 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {response['inference_time']}s <<<")
|
607 |
+
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
608 |
+
logger.info("-" * 50)
|
609 |
+
|
610 |
+
return JSONResponse(content=response, status_code=200)
|
611 |
+
|
612 |
+
|
613 |
+
@nto_cto_router.post("/m_nto")
|
614 |
+
async def mannequin_nto(necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
615 |
+
image: UploadFile = File(...)):
|
616 |
+
logger.info("-" * 50)
|
617 |
+
logger.info(">>> MANNEQUIN NTO STARTED <<<")
|
618 |
+
logger.info(f"Parameters: storename={necklace_try_on_id.storename}, "
|
619 |
+
f"necklaceCategory={necklace_try_on_id.necklaceCategory}, "
|
620 |
+
f"necklaceImageId={necklace_try_on_id.necklaceImageId}")
|
621 |
+
start_time = time.time()
|
622 |
+
|
623 |
+
try:
|
624 |
+
imageBytes = await image.read()
|
625 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
626 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
627 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
628 |
+
except Exception as e:
|
629 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
630 |
+
return JSONResponse(content={
|
631 |
+
"error": f"The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again",
|
632 |
+
"code": 404}, status_code=404)
|
633 |
+
|
634 |
+
try:
|
635 |
+
result, resized_img = await pipeline.necklaceTryOnMannequin_(image=image, jewellery=jewellery)
|
636 |
+
|
637 |
+
if result is None:
|
638 |
+
logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
|
639 |
+
return JSONResponse(
|
640 |
+
content={"error": "No face detected in the image please try again with a different image",
|
641 |
+
"code": 400}, status_code=400)
|
642 |
+
logger.info(">>> NECKLACE TRY ON PROCESSING COMPLETED <<<")
|
643 |
+
except Exception as e:
|
644 |
+
logger.error(f">>> NECKLACE TRY ON PROCESSING ERROR: {str(e)} <<<")
|
645 |
+
return JSONResponse(content={"error": f"Error during necklace try-on process", "code": 500},
|
646 |
+
status_code=500)
|
647 |
+
|
648 |
+
try:
|
649 |
+
logger.info(">>> SAVING RESULT IMAGES <<<")
|
650 |
+
start_time_saving = time.time()
|
651 |
+
|
652 |
+
# Upload both images concurrently
|
653 |
+
upload_tasks = supabase_upload_and_return_url(prefix="MNTO", image=result, necklace_id=necklace_try_on_id.necklaceImageId)
|
654 |
+
result_url = await asyncio.gather(upload_tasks)
|
655 |
+
|
656 |
+
if result_url[0] is None:
|
657 |
+
raise Exception("Failed to upload one or both images")
|
658 |
+
|
659 |
+
logger.info(f">>> RESULT IMAGES SAVED IN {round((time.time() - start_time_saving), 2)}s <<<")
|
660 |
+
logger.info(">>> RESULT IMAGES SAVED <<<")
|
661 |
+
except Exception as e:
|
662 |
+
logger.error(f">>> RESULT SAVING ERROR: {str(e)} <<<")
|
663 |
+
return JSONResponse(content={"error": f"Error saving result images", "code": 500}, status_code=500)
|
664 |
+
try:
|
665 |
+
try:
|
666 |
+
total_backend_time = round((time.time() - start_time), 2)
|
667 |
+
response = {
|
668 |
+
"code": 200,
|
669 |
+
"output": f"{result_url[0]}",
|
670 |
+
"inference_time": total_backend_time
|
671 |
+
}
|
672 |
+
|
673 |
+
|
674 |
+
|
675 |
+
except Exception as e:
|
676 |
+
logger.error(f">>> RESPONSE GENERATION ERROR: {str(e)} <<<")
|
677 |
+
return JSONResponse(content={"error": f"Error generating response", "code": 500}, status_code=500)
|
678 |
+
|
679 |
+
logger.info(f">>> TOTAL INFERENCE TIME: {total_backend_time}s <<<")
|
680 |
+
logger.info(f">>> NECKLACE TRY ON COMPLETED :: {necklace_try_on_id.storename} <<<")
|
681 |
+
logger.info("-" * 50)
|
682 |
+
|
683 |
+
return JSONResponse(content=response, status_code=200)
|
684 |
+
|
685 |
+
finally:
|
686 |
+
if 'result' in locals(): del result
|
687 |
+
gc.collect()
|
688 |
+
|
689 |
+
|
690 |
+
@nto_cto_router.post("/nto_mto_combined")
|
691 |
+
async def combined_cto_nto(
|
692 |
+
image: UploadFile = File(...),
|
693 |
+
clothing_type: str = Form(...),
|
694 |
+
necklace_id: str = Form(...),
|
695 |
+
necklace_category: str = Form(...),
|
696 |
+
storename: str = Form(...),
|
697 |
+
offset_x: float = Form(...),
|
698 |
+
offset_y: float = Form(...)
|
699 |
+
):
|
700 |
+
logger.info("-" * 50)
|
701 |
+
logger.info(">>> COMBINED CTO-NTO STARTED <<<")
|
702 |
+
logger.info(f"Parameters: storename={storename}, necklace_category={necklace_category}, "
|
703 |
+
f"necklace_id={necklace_id}, clothing_type={clothing_type}")
|
704 |
+
start_time = time.time()
|
705 |
+
|
706 |
+
def image_to_base64(img: Image.Image) -> str:
|
707 |
+
buffer = BytesIO()
|
708 |
+
img.save(buffer, format="WEBP", quality=85, optimize=True)
|
709 |
+
return f"data:image/webp;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
|
710 |
+
|
711 |
+
try:
|
712 |
+
# Load source image and necklace
|
713 |
+
image_bytes = await image.read()
|
714 |
+
source_image = Image.open(BytesIO(image_bytes)).convert("RGB").resize((512, 512))
|
715 |
+
|
716 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{storename}/{necklace_category}/image/{necklace_id}.png"
|
717 |
+
necklace_image = Image.open(returnBytesData(url=jewellery_url)).convert("RGBA")
|
718 |
+
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
719 |
+
except Exception as e:
|
720 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
721 |
+
return JSONResponse(content={
|
722 |
+
"error": "Error loading images. Please verify the image and necklace availability.",
|
723 |
+
"code": 404
|
724 |
+
}, status_code=404)
|
725 |
+
|
726 |
+
try:
|
727 |
+
# Generate mask and shoulder points
|
728 |
+
mask_start_time = time.time()
|
729 |
+
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
|
730 |
+
mask_time = round(time.time() - mask_start_time, 2)
|
731 |
+
logger.info(f">>> MASK GENERATION COMPLETED in {mask_time}s <<<")
|
732 |
+
|
733 |
+
# Convert images to base64
|
734 |
+
encoding_start_time = time.time()
|
735 |
+
mask_data_uri, image_data_uri = await asyncio.gather(
|
736 |
+
asyncio.to_thread(image_to_base64, mask),
|
737 |
+
asyncio.to_thread(image_to_base64, source_image)
|
738 |
+
)
|
739 |
+
encoding_time = round(time.time() - encoding_start_time, 2)
|
740 |
+
logger.info(f">>> IMAGE ENCODING COMPLETED in {encoding_time}s <<<")
|
741 |
+
|
742 |
+
# Perform CTO
|
743 |
+
cto_start_time = time.time()
|
744 |
+
cto_output = replicate_run_cto({
|
745 |
+
"mask": mask_data_uri,
|
746 |
+
"image": image_data_uri,
|
747 |
+
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
748 |
+
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
749 |
+
"num_inference_steps": 20
|
750 |
+
})
|
751 |
+
cto_time = round(time.time() - cto_start_time, 2)
|
752 |
+
logger.info(f">>> CTO COMPLETED in {cto_time}s <<<")
|
753 |
+
|
754 |
+
if not cto_output or not isinstance(cto_output, (list, tuple)) or not cto_output[0]:
|
755 |
+
raise ValueError("Invalid output from clothing try-on")
|
756 |
+
|
757 |
+
# Get CTO result image
|
758 |
+
async with aiohttp.ClientSession() as session:
|
759 |
+
async with session.get(str(cto_output[0])) as response:
|
760 |
+
if response.status != 200:
|
761 |
+
raise HTTPException(status_code=response.status, detail="Failed to fetch CTO output")
|
762 |
+
cto_result_bytes = await response.read()
|
763 |
+
|
764 |
+
# Perform NTO
|
765 |
+
nto_start_time = time.time()
|
766 |
+
with BytesIO(cto_result_bytes) as buf:
|
767 |
+
cto_result_image = Image.open(buf).convert("RGB")
|
768 |
+
result, headerText, _ = await pipeline.necklaceTryOnDynamicOffset_(
|
769 |
+
image=cto_result_image,
|
770 |
+
jewellery=necklace_image,
|
771 |
+
storename=storename,
|
772 |
+
offset=[offset_x, offset_y]
|
773 |
+
)
|
774 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
775 |
+
logger.info(f">>> NTO COMPLETED in {nto_time}s <<<")
|
776 |
+
|
777 |
+
if result is None:
|
778 |
+
raise ValueError("Failed to process necklace try-on")
|
779 |
+
|
780 |
+
upload_start_time = time.time()
|
781 |
+
result_url = await supabase_upload_and_return_url(
|
782 |
+
prefix="NTOCTO",
|
783 |
+
image=result,
|
784 |
+
necklace_id=necklace_id
|
785 |
+
)
|
786 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
787 |
+
logger.info(f">>> RESULT UPLOADED in {upload_time}s <<<")
|
788 |
+
|
789 |
+
if not result_url:
|
790 |
+
raise ValueError("Failed to upload result image")
|
791 |
+
|
792 |
+
total_time = round(time.time() - start_time, 2)
|
793 |
+
response = {
|
794 |
+
"code": 200,
|
795 |
+
"output": result_url,
|
796 |
+
"timing": {
|
797 |
+
"mask_generation": mask_time,
|
798 |
+
"encoding": encoding_time,
|
799 |
+
"cto_inference": cto_time,
|
800 |
+
"nto_inference": nto_time,
|
801 |
+
"upload": upload_time,
|
802 |
+
"total": total_time
|
803 |
+
}
|
804 |
+
}
|
805 |
+
|
806 |
+
except ValueError as ve:
|
807 |
+
logger.error(f">>> PROCESSING ERROR: {str(ve)} <<<")
|
808 |
+
return JSONResponse(status_code=400, content={"error": str(ve), "code": 400})
|
809 |
+
except Exception as e:
|
810 |
+
logger.error(f">>> PROCESSING ERROR: {str(e)} <<<")
|
811 |
+
return JSONResponse(status_code=500, content={"error": "Error during image processing", "code": 500})
|
812 |
+
finally:
|
813 |
+
if 'result' in locals(): del result
|
814 |
+
gc.collect()
|
815 |
+
|
816 |
+
logger.info(f">>> TOTAL PROCESSING TIME: {total_time}s <<<")
|
817 |
+
logger.info(">>> COMBINED CTO-NTO COMPLETED SUCCESSFULLY <<<")
|
818 |
+
logger.info("-" * 50)
|
819 |
+
|
820 |
+
return JSONResponse(content=response, status_code=200)
|
821 |
+
|
822 |
+
|
823 |
+
@nto_cto_router.post("/nto-dynamic-offset")
|
824 |
+
async def necklace_try_on_id_dynamic_offset(
|
825 |
+
necklace_try_on_id: NecklaceTryOnIDEntity = Depends(parse_necklace_try_on_id),
|
826 |
+
image: UploadFile = File(...)):
|
827 |
+
logger.info("-" * 50)
|
828 |
+
logger.info(">>> NECKLACE TRY ON ID STARTED <<<")
|
829 |
+
logger.info(f"Parameters: storename={necklace_try_on_id.storename}, "
|
830 |
+
f"necklaceCategory={necklace_try_on_id.necklaceCategory}, "
|
831 |
+
f"necklaceImageId={necklace_try_on_id.necklaceImageId}")
|
832 |
+
start_time = time.time()
|
833 |
+
|
834 |
+
try:
|
835 |
+
image_loading_start = time.time()
|
836 |
+
imageBytes = await image.read()
|
837 |
+
jewellery_url = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{necklace_try_on_id.storename}/{necklace_try_on_id.necklaceCategory}/image/{necklace_try_on_id.necklaceImageId}.png"
|
838 |
+
image, jewellery = Image.open(BytesIO(imageBytes)), Image.open(returnBytesData(url=jewellery_url))
|
839 |
+
image_loading_time = round(time.time() - image_loading_start, 2)
|
840 |
+
logger.info(f">>> IMAGES LOADED SUCCESSFULLY in {image_loading_time}s <<<")
|
841 |
+
except Exception as e:
|
842 |
+
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
843 |
+
return JSONResponse(content={
|
844 |
+
"error": f"The requested resource (Image, necklace category, or store) is not available. Please verify the availability and try again",
|
845 |
+
"code": 404}, status_code=404)
|
846 |
+
|
847 |
+
try:
|
848 |
+
nto_start_time = time.time()
|
849 |
+
result, headerText, mask = await pipeline.necklaceTryOnDynamicOffset_(
|
850 |
+
image=image,
|
851 |
+
jewellery=jewellery,
|
852 |
+
storename=necklace_try_on_id.storename,
|
853 |
+
offset=[necklace_try_on_id.offset_x, necklace_try_on_id.offset_y]
|
854 |
+
)
|
855 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
856 |
+
logger.info(f">>> NECKLACE TRY ON PROCESSING COMPLETED in {nto_time}s <<<")
|
857 |
+
|
858 |
+
if result is None:
|
859 |
+
logger.error(">>> NO FACE DETECTED IN THE IMAGE <<<")
|
860 |
+
return JSONResponse(
|
861 |
+
content={"error": "No face detected in the image please try again with a different image",
|
862 |
+
"code": 400}, status_code=400)
|
863 |
+
|
864 |
+
except Exception as e:
|
865 |
+
logger.error(f">>> NECKLACE TRY ON PROCESSING ERROR: {str(e)} <<<")
|
866 |
+
return JSONResponse(content={"error": f"Error during necklace try-on process", "code": 500},
|
867 |
+
status_code=500)
|
868 |
+
|
869 |
+
try:
|
870 |
+
upload_start_time = time.time()
|
871 |
+
upload_tasks = [
|
872 |
+
supabase_upload_and_return_url(prefix="NTO", image=result, necklace_id=necklace_try_on_id.necklaceImageId)
|
873 |
+
]
|
874 |
+
result_url = await asyncio.gather(*upload_tasks)
|
875 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
876 |
+
|
877 |
+
logger.info(f">>> RESULT IMAGES SAVED IN {upload_time}s <<<")
|
878 |
+
except Exception as e:
|
879 |
+
logger.error(f">>> RESULT SAVING ERROR: {str(e)} <<<")
|
880 |
+
return JSONResponse(content={"error": f"Error saving result images", "code": 500}, status_code=500)
|
881 |
+
|
882 |
+
try:
|
883 |
+
total_time = round(time.time() - start_time, 2)
|
884 |
+
response = {
|
885 |
+
"code": 200,
|
886 |
+
"output": f"{result_url[0]}",
|
887 |
+
"timing": {
|
888 |
+
"image_loading": image_loading_time,
|
889 |
+
"nto_processing": nto_time,
|
890 |
+
"upload": upload_time,
|
891 |
+
"total": total_time
|
892 |
+
}
|
893 |
+
}
|
894 |
+
|
895 |
+
logger.info(f">>> TIMING BREAKDOWN <<<")
|
896 |
+
logger.info(f"Image Loading: {image_loading_time}s")
|
897 |
+
logger.info(f"NTO Processing: {nto_time}s")
|
898 |
+
logger.info(f"Upload Time: {upload_time}s")
|
899 |
+
logger.info(f"Total Time: {total_time}s")
|
900 |
+
logger.info(">>> NECKLACE TRY ON COMPLETED <<<")
|
901 |
+
logger.info("-" * 50)
|
902 |
+
|
903 |
+
return JSONResponse(content=response, status_code=200)
|
904 |
+
|
905 |
+
except Exception as e:
|
906 |
+
logger.error(f">>> RESPONSE GENERATION ERROR: {str(e)} <<<")
|
907 |
+
return JSONResponse(content={"error": f"Error generating response", "code": 500}, status_code=500)
|
908 |
+
|
909 |
+
finally:
|
910 |
+
if 'result' in locals(): del result
|
911 |
+
gc.collect()
|
src/components/__init__.py
ADDED
File without changes
|
src/components/auto_crop.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def crop_transparent_image(image_data: bytes) -> tuple[bytes, dict]:
|
12 |
+
try:
|
13 |
+
image = Image.open(BytesIO(image_data))
|
14 |
+
|
15 |
+
if image.format != 'PNG':
|
16 |
+
raise ValueError("Only PNG images are supported")
|
17 |
+
|
18 |
+
width = image.size[0]
|
19 |
+
height = image.size[1]
|
20 |
+
pixels = image.load()
|
21 |
+
|
22 |
+
top = height
|
23 |
+
bottom = 0
|
24 |
+
left = width
|
25 |
+
right = 0
|
26 |
+
|
27 |
+
# Find boundaries of non-transparent pixels
|
28 |
+
for y in range(height):
|
29 |
+
for x in range(width):
|
30 |
+
pixel = pixels[x, y]
|
31 |
+
if isinstance(pixel, tuple) and len(pixel) == 4:
|
32 |
+
if pixel[3] != 0:
|
33 |
+
left = min(left, x)
|
34 |
+
top = min(top, y)
|
35 |
+
right = max(right, x)
|
36 |
+
bottom = max(bottom, y)
|
37 |
+
|
38 |
+
left = max(0, left)
|
39 |
+
top = max(0, top)
|
40 |
+
right = min(width, right + 1)
|
41 |
+
bottom = min(height, bottom + 1)
|
42 |
+
|
43 |
+
if left >= right or top >= bottom:
|
44 |
+
left, top, right, bottom = 0, 0, width, height
|
45 |
+
|
46 |
+
# Crop image
|
47 |
+
cropped_image = image.crop((left, top, right, bottom))
|
48 |
+
|
49 |
+
output_buffer = BytesIO()
|
50 |
+
cropped_image.save(output_buffer, format='PNG')
|
51 |
+
output_buffer.seek(0)
|
52 |
+
|
53 |
+
metadata = {
|
54 |
+
"original_size": f"{width}x{height}",
|
55 |
+
"cropped_size": f"{cropped_image.width}x{cropped_image.height}"
|
56 |
+
}
|
57 |
+
|
58 |
+
return output_buffer.getvalue(), metadata
|
59 |
+
|
60 |
+
except Exception as e:
|
61 |
+
raise ValueError(f"Error processing image: {str(e)}")
|
src/components/color_extraction.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-28
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class ColorExtractionRMBG:
|
12 |
+
def __init__(self):
|
13 |
+
self.HSV_LOWER = None
|
14 |
+
self.HSV_UPPER = None
|
15 |
+
|
16 |
+
def hex_to_rgb(self, hex_color: str):
|
17 |
+
hex_color = hex_color.lstrip("#")
|
18 |
+
return tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
|
19 |
+
|
20 |
+
def rgb_to_hsv(self, rgb_color):
|
21 |
+
rgb_array = np.uint8([[rgb_color]])
|
22 |
+
hsv_array = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2HSV)
|
23 |
+
return hsv_array[0][0]
|
24 |
+
|
25 |
+
def set_thresholds(self, hex_color: str, threshold: int):
|
26 |
+
hsv_color = self.rgb_to_hsv(self.hex_to_rgb(hex_color))
|
27 |
+
lower_bound = np.clip([hsv_color[0] - threshold, 50, 50], 0, 255)
|
28 |
+
upper_bound = np.clip([hsv_color[0] + threshold, 255, 255], 0, 255)
|
29 |
+
return lower_bound, upper_bound
|
30 |
+
|
31 |
+
def extract_color(self, image: np.ndarray, hex_color: str, threshold: int):
|
32 |
+
self.HSV_LOWER, self.HSV_UPPER = self.set_thresholds(hex_color, threshold)
|
33 |
+
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
34 |
+
mask = cv2.inRange(hsv, self.HSV_LOWER, self.HSV_UPPER)
|
35 |
+
result = cv2.bitwise_and(image, image, mask=mask)
|
36 |
+
result = cv2.cvtColor(result, cv2.COLOR_BGR2BGRA)
|
37 |
+
result[:, :, 3] = mask
|
38 |
+
return result
|
src/components/makeup_try_on.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-12-05
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import cv2
|
7 |
+
import mediapipe as mp
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
class MakeupTryOn:
|
12 |
+
def __init__(self):
|
13 |
+
self.mp_face_mesh = mp.solutions.face_mesh
|
14 |
+
self.mp_drawing = mp.solutions.drawing_utils
|
15 |
+
self.face_points = {
|
16 |
+
"LIP_UPPER": [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291, 308, 415, 310, 312, 13, 82, 81, 80, 191, 78],
|
17 |
+
"LIP_LOWER": [61, 146, 91, 181, 84, 17, 314, 405, 321, 375, 291, 308, 324, 402, 317, 14, 87, 178, 88, 95,
|
18 |
+
78, 61],
|
19 |
+
"EYESHADOW_LEFT": [226, 247, 30, 29, 27, 28, 56, 190, 243, 173, 157, 158, 159, 160, 161, 246, 33, 130, 226],
|
20 |
+
"EYESHADOW_RIGHT": [463, 414, 286, 258, 257, 259, 260, 467, 446, 359, 263, 466, 388, 387, 386, 385, 384,
|
21 |
+
398,
|
22 |
+
362,
|
23 |
+
463],
|
24 |
+
"EYELINER_LEFT": [243, 112, 26, 22, 23, 24, 110, 25, 226, 130, 33, 7, 163, 144, 145, 153, 154, 155, 133,
|
25 |
+
243],
|
26 |
+
"EYELINER_RIGHT": [463, 362, 382, 381, 380, 374, 373, 390, 249, 263, 359, 446, 255, 339, 254, 253, 252, 256,
|
27 |
+
341,
|
28 |
+
463],
|
29 |
+
}
|
30 |
+
|
31 |
+
# Read landmarks
|
32 |
+
def read_landmarks(self, image: np.array):
|
33 |
+
with self.mp_face_mesh.FaceMesh(refine_landmarks=True) as face_mesh:
|
34 |
+
results = face_mesh.process(image)
|
35 |
+
if not results.multi_face_landmarks:
|
36 |
+
raise ValueError("No face detected")
|
37 |
+
face_landmarks = results.multi_face_landmarks[0].landmark
|
38 |
+
|
39 |
+
landmark_coordinates = {}
|
40 |
+
for idx, landmark in enumerate(face_landmarks):
|
41 |
+
landmark_px = self.mp_drawing._normalized_to_pixel_coordinates(
|
42 |
+
landmark.x, landmark.y, image.shape[1], image.shape[0]
|
43 |
+
)
|
44 |
+
if landmark_px:
|
45 |
+
landmark_coordinates[idx] = landmark_px
|
46 |
+
|
47 |
+
return landmark_coordinates
|
48 |
+
|
49 |
+
# Add makeup
|
50 |
+
def add_makeup(self, image, mask, landmarks, feature, color):
|
51 |
+
# Create mask for facial features with color
|
52 |
+
points = np.array([landmarks[idx] for idx in self.face_points[feature] if idx in landmarks])
|
53 |
+
if points.size > 0:
|
54 |
+
cv2.fillPoly(mask, [points], color)
|
55 |
+
return mask
|
56 |
+
|
57 |
+
# Apply makeup functions
|
58 |
+
def apply_lipstick(self, image, landmarks, lipstick_color):
|
59 |
+
mask = np.zeros_like(image)
|
60 |
+
mask = self.add_makeup(image, mask, landmarks, "LIP_UPPER", lipstick_color)
|
61 |
+
mask = self.add_makeup(image, mask, landmarks, "LIP_LOWER", lipstick_color)
|
62 |
+
return mask
|
63 |
+
|
64 |
+
def apply_eyeshadow(self, image, landmarks, eyeshadow_color):
|
65 |
+
mask = np.zeros_like(image)
|
66 |
+
mask = self.add_makeup(image, mask, landmarks, "EYESHADOW_LEFT", eyeshadow_color)
|
67 |
+
mask = self.add_makeup(image, mask, landmarks, "EYESHADOW_RIGHT", eyeshadow_color)
|
68 |
+
return mask
|
69 |
+
|
70 |
+
def apply_eyeliner(self, image, landmarks, eyeliner_color):
|
71 |
+
mask = np.zeros_like(image)
|
72 |
+
mask = self.add_makeup(image, mask, landmarks, "EYELINER_LEFT", eyeliner_color)
|
73 |
+
mask = self.add_makeup(image, mask, landmarks, "EYELINER_RIGHT", eyeliner_color)
|
74 |
+
return mask
|
75 |
+
|
76 |
+
def realistic_makeup(self, image, mask):
|
77 |
+
# Convert to float32 for better precision in calculations
|
78 |
+
image = image.astype(np.float32) / 255.0
|
79 |
+
mask = mask.astype(np.float32) / 255.0
|
80 |
+
|
81 |
+
# Apply Gaussian blur for natural blending
|
82 |
+
blurred_mask = cv2.GaussianBlur(mask, (7, 7), 4)
|
83 |
+
|
84 |
+
# Create alpha channel for better color blending
|
85 |
+
alpha = blurred_mask[:, :, -1] if mask.shape[-1] == 4 else np.mean(blurred_mask, axis=2)
|
86 |
+
alpha = np.expand_dims(alpha, axis=2)
|
87 |
+
|
88 |
+
# Blend colors using alpha compositing
|
89 |
+
output = image * (1.0 - alpha) + mask * alpha
|
90 |
+
|
91 |
+
# Ensure values are in valid range and convert back to uint8
|
92 |
+
output = np.clip(output * 255.0, 0, 255).astype(np.uint8)
|
93 |
+
|
94 |
+
return output
|
95 |
+
|
96 |
+
def full_makeup(self, image, lipstick_color, eyeliner_color, eyeshadow_color):
|
97 |
+
landmarks = self.read_landmarks(image)
|
98 |
+
|
99 |
+
mask_lipstick = self.apply_lipstick(image, landmarks, lipstick_color)
|
100 |
+
mask_eyeliner = self.apply_eyeliner(image, landmarks, eyeliner_color)
|
101 |
+
mask_eyeshadow = self.apply_eyeshadow(image, landmarks, eyeshadow_color)
|
102 |
+
final_mask = mask_lipstick + mask_eyeliner + mask_eyeshadow
|
103 |
+
output = self.realistic_makeup(image, final_mask)
|
104 |
+
return output
|
src/components/necklaceTryOn.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import math
|
3 |
+
import time
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import cvzone
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
from cvzone.FaceMeshModule import FaceMeshDetector
|
12 |
+
from cvzone.PoseModule import PoseDetector
|
13 |
+
|
14 |
+
from src.api.nto_api import supabase
|
15 |
+
from src.utils import add_watermark_store, add_watermark_jewelmirror, returnBytesData, addWatermark
|
16 |
+
from src.utils.exceptions import CustomException
|
17 |
+
from src.utils.logger import logger
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class NecklaceTryOnConfig:
|
22 |
+
logoURL: str = "https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/MagicMirror/FullImages/{}.png"
|
23 |
+
|
24 |
+
|
25 |
+
class NecklaceTryOn:
|
26 |
+
def __init__(self) -> None:
|
27 |
+
self.detector = PoseDetector()
|
28 |
+
self.necklaceTryOnConfig = NecklaceTryOnConfig()
|
29 |
+
self.meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
30 |
+
self.logo_cache = {}
|
31 |
+
self.jewelmirror_logo_cache = {}
|
32 |
+
# self.BODY_THRESHOLD = 0.36
|
33 |
+
|
34 |
+
def offset_calculate(self, necklace_id, body_type):
|
35 |
+
try:
|
36 |
+
start_time = time.time()
|
37 |
+
row = supabase.table("MagicMirror").select("*").eq("Id", necklace_id).execute()
|
38 |
+
data = row.data
|
39 |
+
|
40 |
+
if data and len(data) > 0:
|
41 |
+
x_offset = float(data[0][f"x_{body_type}"])
|
42 |
+
y_offset = float(data[0][f"y_{body_type}"])
|
43 |
+
logger.info(
|
44 |
+
f">>> NECKLACE TRY ON :: OFFSET CALCULATION TAKEN: {round(time.time() - start_time, 2)}s <<<")
|
45 |
+
return [x_offset, y_offset]
|
46 |
+
else:
|
47 |
+
return [0.2, 0.5]
|
48 |
+
|
49 |
+
except Exception as e:
|
50 |
+
logger.error(f"Error in offset_calculate: {str(e)}")
|
51 |
+
return [0.2, 0.5]
|
52 |
+
|
53 |
+
# def calculate_body_type(self, lmList, faces, necklace_id):
|
54 |
+
# try:
|
55 |
+
# start_time = time.time()
|
56 |
+
# face = faces[0]
|
57 |
+
# left_face = np.array(face[234])
|
58 |
+
# right_face = np.array(face[454])
|
59 |
+
#
|
60 |
+
# left_shoulder = np.array([lmList[11][1], lmList[11][2]])
|
61 |
+
# right_shoulder = np.array([lmList[12][1], lmList[12][2]])
|
62 |
+
#
|
63 |
+
# face_width = np.linalg.norm(right_face - left_face)
|
64 |
+
# shoulder_width = np.linalg.norm(right_shoulder - left_shoulder)
|
65 |
+
#
|
66 |
+
# body_ratio = shoulder_width / face_width
|
67 |
+
#
|
68 |
+
# body_type = "lean" if body_ratio > self.BODY_THRESHOLD else "chubby"
|
69 |
+
# logger.info(f">>> NECKLACE TRY ON :: BODY TYPE: {body_type} <<<")
|
70 |
+
# logger.info(f">>> NECKLACE TRY ON :: Time taken to calculate body type: {round(time.time() - start_time, 2)}s <<<")
|
71 |
+
# offset = self.offset_calculate(body_type=body_type, necklace_id=necklace_id)
|
72 |
+
#
|
73 |
+
# measurements = {
|
74 |
+
# 'face_width': round(face_width, 2),
|
75 |
+
# 'shoulder_width': round(shoulder_width, 2),
|
76 |
+
# 'ratio': round(body_ratio, 2)
|
77 |
+
# }
|
78 |
+
# logger.info(f">>> NECKLACE TRY ON :: OFFSET CALCULATED: {offset}")
|
79 |
+
# logger.info(f">>> NECKLACE TRY ON :: BODY MEASUREMENTS: {measurements} <<<")
|
80 |
+
# logger.info(f">>> NECKLACE TRY ON :: TOTAL OFFSET CALCULATION TAKEN: {round(time.time() - start_time, 2)}s <<<")
|
81 |
+
# return offset
|
82 |
+
#
|
83 |
+
# except Exception as e:
|
84 |
+
# logger.error(f">>> NECKLACE TRY ON :: ERROR IN CALCULATE BODY TYPE: {str(e)} <<<")
|
85 |
+
# return [0.2, 0.5]
|
86 |
+
|
87 |
+
def necklaceTryOnOffsetBody(self, image: Image.Image, jewellery: Image.Image, storename: str,
|
88 |
+
offset: list[float]) -> list[
|
89 |
+
Union[Image.Image, str]]:
|
90 |
+
try:
|
91 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
92 |
+
|
93 |
+
image = np.array(image.convert("RGB").resize((3000, 3000)))
|
94 |
+
copy_image = image.copy()
|
95 |
+
jewellery = np.array(jewellery.convert("RGBA"))
|
96 |
+
|
97 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
98 |
+
image = self.detector.findPose(image)
|
99 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
100 |
+
|
101 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
102 |
+
leftLandmark, rightLandmark = faces[0][172], faces[0][397]
|
103 |
+
# offset = self.calculate_body_type(lmList, faces, necklace_id)
|
104 |
+
logger.info(f"NECKLACE TRY ON :: offset: {offset}")
|
105 |
+
|
106 |
+
landmarksDistance = np.linalg.norm(np.array(leftLandmark) - np.array(rightLandmark))
|
107 |
+
|
108 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
109 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * offset[0])
|
110 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * offset[0])
|
111 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * offset[1])
|
112 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * offset[1])
|
113 |
+
|
114 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
115 |
+
angle = math.ceil(self.detector.findAngle((avg_x2, avg_y2), (avg_x1, avg_y1), (avg_x2, avg_y1))[0])
|
116 |
+
if avg_y2 >= avg_y1:
|
117 |
+
angle *= -1
|
118 |
+
|
119 |
+
xdist = avg_x2 - avg_x1
|
120 |
+
origImgRatio = xdist / jewellery.shape[1]
|
121 |
+
ydist = jewellery.shape[0] * origImgRatio
|
122 |
+
|
123 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
124 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
125 |
+
offset = int(0.8 * xdist * (np.argmax(image_gray[0, :] != 255) / jewellery.shape[1]))
|
126 |
+
|
127 |
+
jewellery = cv2.resize(jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA)
|
128 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
129 |
+
y_coordinate = avg_y1 - offset
|
130 |
+
available_space = copy_image.shape[0] - y_coordinate
|
131 |
+
extra = jewellery.shape[0] - available_space
|
132 |
+
|
133 |
+
headerText = "To see more of the necklace, please step back slightly." if extra > 0 else "success"
|
134 |
+
|
135 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
136 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
137 |
+
result = Image.fromarray(result.astype(np.uint8))
|
138 |
+
|
139 |
+
# Add JewelMirror logo
|
140 |
+
if "default" not in self.jewelmirror_logo_cache:
|
141 |
+
self.jewelmirror_logo_cache["default"] = Image.open(
|
142 |
+
returnBytesData(url=self.necklaceTryOnConfig.logoURL.format("default")))
|
143 |
+
|
144 |
+
# Try to add store logo if it exists
|
145 |
+
store_logo_exists = False
|
146 |
+
if storename != "default":
|
147 |
+
try:
|
148 |
+
if storename not in self.logo_cache:
|
149 |
+
self.logo_cache[storename] = Image.open(
|
150 |
+
returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
151 |
+
result = add_watermark_store(background=result, logo=self.logo_cache[storename])
|
152 |
+
store_logo_exists = True
|
153 |
+
except Exception as e:
|
154 |
+
logger.warning(f"Store logo not found for {storename}: {str(e)}")
|
155 |
+
|
156 |
+
# Add JewelMirror logo based on store logo existence
|
157 |
+
result = add_watermark_jewelmirror(
|
158 |
+
background=result,
|
159 |
+
logo=self.jewelmirror_logo_cache["default"],
|
160 |
+
position="right" if store_logo_exists else "left"
|
161 |
+
)
|
162 |
+
|
163 |
+
# Create binary mask
|
164 |
+
blackedNecklace = np.zeros_like(copy_image)
|
165 |
+
cvzone.overlayPNG(blackedNecklace, jewellery, (avg_x1, y_coordinate))
|
166 |
+
binaryMask = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
167 |
+
binaryMask = (binaryMask > 5).astype(np.uint8) * 255
|
168 |
+
mask = Image.fromarray(binaryMask).convert("RGB")
|
169 |
+
|
170 |
+
logger.info(f"NECKLACE TRY ON :: output generated successfully :: {storename}")
|
171 |
+
|
172 |
+
gc.collect()
|
173 |
+
return [result, headerText, mask]
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
logger.error(f">>> NECKLACE TRY ON ERROR: {str(e)} <<<")
|
177 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
178 |
+
return [None, "error", None]
|
179 |
+
|
180 |
+
def necklaceTryOn(self, image: Image.Image, jewellery: Image.Image, storename: str) -> list[
|
181 |
+
Union[Image.Image, str]]:
|
182 |
+
try:
|
183 |
+
logger.info(f">>> NECKLACE TRY ON STARTED :: {storename} <<<")
|
184 |
+
|
185 |
+
image = np.array(image.convert("RGB").resize((3000, 3000)))
|
186 |
+
copy_image = image.copy()
|
187 |
+
jewellery = np.array(jewellery.convert("RGBA"))
|
188 |
+
|
189 |
+
logger.info(f"NECKLACE TRY ON :: detecting pose and landmarks :: {storename}")
|
190 |
+
image = self.detector.findPose(image)
|
191 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
192 |
+
|
193 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
194 |
+
leftLandmark, rightLandmark = faces[0][172], faces[0][397]
|
195 |
+
|
196 |
+
landmarksDistance = np.linalg.norm(np.array(leftLandmark) - np.array(rightLandmark))
|
197 |
+
|
198 |
+
logger.info(f"NECKLACE TRY ON :: estimating neck points :: {storename}")
|
199 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
|
200 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
|
201 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
|
202 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)
|
203 |
+
|
204 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image :: {storename}")
|
205 |
+
angle = math.ceil(self.detector.findAngle((avg_x2, avg_y2), (avg_x1, avg_y1), (avg_x2, avg_y1))[0])
|
206 |
+
if avg_y2 >= avg_y1:
|
207 |
+
angle *= -1
|
208 |
+
|
209 |
+
xdist = avg_x2 - avg_x1
|
210 |
+
origImgRatio = xdist / jewellery.shape[1]
|
211 |
+
ydist = jewellery.shape[0] * origImgRatio
|
212 |
+
|
213 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
214 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
215 |
+
offset = int(0.8 * xdist * (np.argmax(image_gray[0, :] != 255) / jewellery.shape[1]))
|
216 |
+
|
217 |
+
jewellery = cv2.resize(jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA)
|
218 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
219 |
+
y_coordinate = avg_y1 - offset
|
220 |
+
available_space = copy_image.shape[0] - y_coordinate
|
221 |
+
extra = jewellery.shape[0] - available_space
|
222 |
+
|
223 |
+
headerText = "To see more of the necklace, please step back slightly." if extra > 0 else "success"
|
224 |
+
|
225 |
+
logger.info(f"NECKLACE TRY ON :: generating output :: {storename}")
|
226 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
227 |
+
image = Image.fromarray(result.astype(np.uint8))
|
228 |
+
|
229 |
+
if storename not in self.logo_cache:
|
230 |
+
self.logo_cache[storename] = Image.open(
|
231 |
+
returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
232 |
+
result = addWatermark(background=image, logo=self.logo_cache[storename])
|
233 |
+
|
234 |
+
# Create binary mask
|
235 |
+
blackedNecklace = np.zeros_like(copy_image)
|
236 |
+
cvzone.overlayPNG(blackedNecklace, jewellery, (avg_x1, y_coordinate))
|
237 |
+
binaryMask = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
238 |
+
binaryMask = (binaryMask > 5).astype(np.uint8) * 255
|
239 |
+
mask = Image.fromarray(binaryMask).convert("RGB")
|
240 |
+
|
241 |
+
gc.collect()
|
242 |
+
return [result, headerText, mask]
|
243 |
+
|
244 |
+
except Exception as e:
|
245 |
+
logger.error(f">>> NECKLACE TRY ON ERROR: {str(e)} <<<")
|
246 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
247 |
+
|
248 |
+
raise [None, "error", None]
|
249 |
+
|
250 |
+
def necklaceTryOnMannequin(self, image, jewellery) -> list[
|
251 |
+
Union[Image.Image, str]]:
|
252 |
+
try:
|
253 |
+
image_np = np.array(image.convert("RGB"))
|
254 |
+
jewellery = np.array(jewellery.convert("RGBA"))
|
255 |
+
|
256 |
+
height = image_np.shape[0]
|
257 |
+
middle_point = height // 3
|
258 |
+
upper_half = image_np[:middle_point, :]
|
259 |
+
lower_half = image_np[middle_point:, :]
|
260 |
+
|
261 |
+
upper_half = cv2.resize(upper_half, (upper_half.shape[1] * 3, upper_half.shape[0] * 3))
|
262 |
+
copy_upper = upper_half.copy()
|
263 |
+
|
264 |
+
# Apply pose detection on upper half
|
265 |
+
upper_half = self.detector.findPose(upper_half)
|
266 |
+
lmList, _ = self.detector.findPosition(upper_half, bboxWithHands=False, draw=True)
|
267 |
+
|
268 |
+
img, faces = self.meshDetector.findFaceMesh(upper_half, draw=True)
|
269 |
+
|
270 |
+
if not faces:
|
271 |
+
return Exception("No face detected in the image")
|
272 |
+
|
273 |
+
leftLandmark, rightLandmark = faces[0][172], faces[0][397]
|
274 |
+
|
275 |
+
landmarksDistance = np.linalg.norm(np.array(leftLandmark) - np.array(rightLandmark))
|
276 |
+
|
277 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
|
278 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
|
279 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
|
280 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)
|
281 |
+
|
282 |
+
angle = math.ceil(self.detector.findAngle((avg_x2, avg_y2), (avg_x1, avg_y1), (avg_x2, avg_y1))[0])
|
283 |
+
if avg_y2 >= avg_y1:
|
284 |
+
angle *= -1
|
285 |
+
|
286 |
+
xdist = avg_x2 - avg_x1
|
287 |
+
origImgRatio = xdist / jewellery.shape[1]
|
288 |
+
ydist = jewellery.shape[0] * origImgRatio
|
289 |
+
|
290 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
291 |
+
offset = int(0.8 * xdist * (np.argmax(image_gray[0, :] != 255) / jewellery.shape[1]))
|
292 |
+
|
293 |
+
jewellery = cv2.resize(jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA)
|
294 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
295 |
+
|
296 |
+
y_coordinate = avg_y1 - offset
|
297 |
+
|
298 |
+
result_upper = cvzone.overlayPNG(copy_upper, jewellery, (avg_x1, y_coordinate))
|
299 |
+
|
300 |
+
final_result = cv2.resize(result_upper, (image_np.shape[1], image_np.shape[0] - middle_point))
|
301 |
+
|
302 |
+
final_result = np.vstack((final_result, lower_half))
|
303 |
+
|
304 |
+
gc.collect()
|
305 |
+
return Image.fromarray(final_result.astype(np.uint8)), Image.fromarray(result_upper.astype(np.uint8))
|
306 |
+
|
307 |
+
except Exception as e:
|
308 |
+
print(f"Error: {e}")
|
309 |
+
return e
|
310 |
+
|
311 |
+
def shoulderPointMaskGeneration(self, image: Image.Image) -> tuple[Image.Image, tuple[int, int], tuple[int, int]]:
|
312 |
+
try:
|
313 |
+
image = np.array(image)
|
314 |
+
copy_image = image.copy()
|
315 |
+
|
316 |
+
logger.info("SHOULDER POINT MASK GENERATION :: detecting pose and landmarks")
|
317 |
+
image = self.detector.findPose(image)
|
318 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
319 |
+
|
320 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
321 |
+
leftLandmark, rightLandmark = faces[0][172], faces[0][397]
|
322 |
+
|
323 |
+
landmarksDistance = np.linalg.norm(np.array(leftLandmark) - np.array(rightLandmark))
|
324 |
+
|
325 |
+
logger.info("SHOULDER POINT MASK GENERATION :: estimating neck points")
|
326 |
+
# Using the same point calculation logic as necklaceTryOn
|
327 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
|
328 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
|
329 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
|
330 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)
|
331 |
+
|
332 |
+
logger.info("SHOULDER POINT MASK GENERATION :: generating shoulder point mask")
|
333 |
+
mask = np.zeros_like(image[:, :, 0])
|
334 |
+
mask[avg_y1:, :] = 255
|
335 |
+
pts = np.array([[0, 0], [image.shape[1], 0], [avg_x2, avg_y2], [avg_x1, avg_y1]], np.int32)
|
336 |
+
pts = pts.reshape((-1, 1, 2))
|
337 |
+
cv2.fillPoly(mask, [pts], 0)
|
338 |
+
|
339 |
+
black_n_white_mask = np.zeros_like(image[:, :, 0])
|
340 |
+
black_n_white_mask[avg_y1:, :] = 255
|
341 |
+
logger.info("SHOULDER POINT MASK GENERATION :: mask generated successfully")
|
342 |
+
|
343 |
+
left_point = (avg_x1, avg_y1)
|
344 |
+
right_point = (avg_x2, avg_y2)
|
345 |
+
|
346 |
+
return Image.fromarray(black_n_white_mask.astype(np.uint8)), left_point, right_point
|
347 |
+
|
348 |
+
except Exception as e:
|
349 |
+
logger.error(f"SHOULDER POINT MASK GENERATION ERROR: {str(e)}")
|
350 |
+
raise CustomException(e)
|
351 |
+
|
352 |
+
def necklaceTryOnWithPoints(self, image: Image.Image, jewellery: Image.Image, storename: str,
|
353 |
+
left_point: tuple[int, int], right_point: tuple[int, int]) -> list[
|
354 |
+
Union[Image.Image, str]]:
|
355 |
+
try:
|
356 |
+
logger.info(f">>> NECKLACE TRY ON WITH POINTS STARTED :: {storename} <<<")
|
357 |
+
|
358 |
+
image = np.array(image.convert("RGB"))
|
359 |
+
# .resize((3000, 3000)))
|
360 |
+
copy_image = image.copy()
|
361 |
+
jewellery = np.array(jewellery.convert("RGBA"))
|
362 |
+
|
363 |
+
logger.info(f"NECKLACE TRY ON :: scaling the necklace image based on given points :: {storename}")
|
364 |
+
|
365 |
+
avg_x1 = left_point[0]
|
366 |
+
avg_x2 = right_point[0]
|
367 |
+
avg_y1 = left_point[1]
|
368 |
+
avg_y2 = right_point[1]
|
369 |
+
|
370 |
+
angle = math.ceil(self.detector.findAngle((avg_x2, avg_y2), (avg_x1, avg_y1), (avg_x2, avg_y1))[0])
|
371 |
+
if avg_y2 >= avg_y1:
|
372 |
+
angle *= -1
|
373 |
+
|
374 |
+
xdist = avg_x2 - avg_x1
|
375 |
+
origImgRatio = xdist / jewellery.shape[1]
|
376 |
+
ydist = jewellery.shape[0] * origImgRatio
|
377 |
+
|
378 |
+
logger.info(f"NECKLACE TRY ON :: adding offset based on the necklace shape :: {storename}")
|
379 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
380 |
+
offset = int(0.8 * xdist * (np.argmax(image_gray[0, :] != 255) / jewellery.shape[1]))
|
381 |
+
|
382 |
+
jewellery = cv2.resize(jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA)
|
383 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
384 |
+
y_coordinate = avg_y1 - offset
|
385 |
+
available_space = copy_image.shape[0] - y_coordinate
|
386 |
+
extra = jewellery.shape[0] - available_space
|
387 |
+
|
388 |
+
headerText = "To see more of the necklace, please step back slightly." if extra > 0 else "success"
|
389 |
+
|
390 |
+
logger.info(f"NECKLACE TRY ON :: generating output with given points :: {storename}")
|
391 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
392 |
+
image = Image.fromarray(result.astype(np.uint8))
|
393 |
+
|
394 |
+
if storename not in self.logo_cache:
|
395 |
+
self.logo_cache[storename] = Image.open(
|
396 |
+
returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
397 |
+
result = addWatermark(background=image, logo=self.logo_cache[storename])
|
398 |
+
|
399 |
+
# Create binary mask
|
400 |
+
blackedNecklace = np.zeros_like(copy_image)
|
401 |
+
cvzone.overlayPNG(blackedNecklace, jewellery, (avg_x1, y_coordinate))
|
402 |
+
binaryMask = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
403 |
+
binaryMask = (binaryMask > 5).astype(np.uint8) * 255
|
404 |
+
mask = Image.fromarray(binaryMask).convert("RGB")
|
405 |
+
|
406 |
+
gc.collect()
|
407 |
+
return [result, headerText, mask]
|
408 |
+
|
409 |
+
except Exception as e:
|
410 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
411 |
+
raise CustomException(e)
|
412 |
+
|
413 |
+
def necklaceTryOnOffset(self, image: Image.Image, jewellery: Image.Image, storename: str, offset_x: float = 0.12,
|
414 |
+
offset_y: float = 0.5) -> list[Union[Image.Image, str]]:
|
415 |
+
try:
|
416 |
+
image = np.array(image.convert("RGB").resize((3000, 3000)))
|
417 |
+
copy_image = image.copy()
|
418 |
+
jewellery = np.array(jewellery.convert("RGBA"))
|
419 |
+
|
420 |
+
image = self.detector.findPose(image)
|
421 |
+
lmList, _ = self.detector.findPosition(image, bboxWithHands=False, draw=False)
|
422 |
+
|
423 |
+
img, faces = self.meshDetector.findFaceMesh(image, draw=False)
|
424 |
+
leftLandmark, rightLandmark = faces[0][172], faces[0][397]
|
425 |
+
landmarksDistance = np.linalg.norm(np.array(leftLandmark) - np.array(rightLandmark))
|
426 |
+
|
427 |
+
avg_x1 = int(leftLandmark[0] - landmarksDistance * offset_x)
|
428 |
+
avg_x2 = int(rightLandmark[0] + landmarksDistance * offset_x)
|
429 |
+
avg_y1 = int(leftLandmark[1] + landmarksDistance * offset_y)
|
430 |
+
avg_y2 = int(rightLandmark[1] + landmarksDistance * offset_y)
|
431 |
+
|
432 |
+
angle = math.ceil(self.detector.findAngle((avg_x2, avg_y2), (avg_x1, avg_y1), (avg_x2, avg_y1))[0])
|
433 |
+
if avg_y2 >= avg_y1:
|
434 |
+
angle *= -1
|
435 |
+
|
436 |
+
xdist = avg_x2 - avg_x1
|
437 |
+
origImgRatio = xdist / jewellery.shape[1]
|
438 |
+
ydist = jewellery.shape[0] * origImgRatio
|
439 |
+
|
440 |
+
image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
|
441 |
+
offset = int(0.8 * xdist * (np.argmax(image_gray[0, :] != 255) / jewellery.shape[1]))
|
442 |
+
|
443 |
+
jewellery = cv2.resize(jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA)
|
444 |
+
jewellery = cvzone.rotateImage(jewellery, angle)
|
445 |
+
|
446 |
+
y_coordinate = avg_y1 - offset
|
447 |
+
available_space = copy_image.shape[0] - y_coordinate
|
448 |
+
|
449 |
+
result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))
|
450 |
+
headerText = "To see more of the necklace, please step back slightly." if available_space < jewellery.shape[
|
451 |
+
0] else "success"
|
452 |
+
image = Image.fromarray(result.astype(np.uint8))
|
453 |
+
|
454 |
+
if storename not in self.logo_cache:
|
455 |
+
self.logo_cache[storename] = Image.open(
|
456 |
+
returnBytesData(url=self.necklaceTryOnConfig.logoURL.format(storename)))
|
457 |
+
result = addWatermark(background=image, logo=self.logo_cache[storename])
|
458 |
+
|
459 |
+
blackedNecklace = np.zeros_like(copy_image)
|
460 |
+
cvzone.overlayPNG(blackedNecklace, jewellery, (avg_x1, y_coordinate))
|
461 |
+
binaryMask = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
|
462 |
+
binaryMask = (binaryMask > 5).astype(np.uint8) * 255
|
463 |
+
mask = Image.fromarray(binaryMask).convert("RGB")
|
464 |
+
|
465 |
+
gc.collect()
|
466 |
+
return [result, headerText, mask]
|
467 |
+
|
468 |
+
except Exception as e:
|
469 |
+
logger.error(f">>> NECKLACE TRY ON ERROR: {str(e)} <<<")
|
470 |
+
logger.error(f"{CustomException(e)}:: {storename}")
|
471 |
+
raise [None, "error", None]
|
src/components/title_des_gen.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
project @ NTO-TCP-HF
|
3 |
+
created @ 2024-10-29
|
4 |
+
author @ github.com/ishworrsubedii
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
|
8 |
+
import google.generativeai as genai
|
9 |
+
from tempfile import NamedTemporaryFile
|
10 |
+
|
11 |
+
|
12 |
+
class NecklaceProductListing:
|
13 |
+
def __init__(self, api_key):
|
14 |
+
genai.configure(api_key=api_key)
|
15 |
+
self.model = genai.GenerativeModel(model_name="gemini-1.5-flash")
|
16 |
+
self.prompt = """Analyze this necklace image and create an e-commerce product listing with:
|
17 |
+
1. An SEO-friendly, compelling product title (3-8 words)
|
18 |
+
2. A persuasive product description for online shoppers (80-120 words)
|
19 |
+
Format your response exactly like this:
|
20 |
+
Title: [Product Title]
|
21 |
+
Description: [Product Description]
|
22 |
+
For the description, include:
|
23 |
+
- Opening hook about the piece's beauty or uniqueness
|
24 |
+
- Key materials and specifications (metal type, gemstones, length, closure type)
|
25 |
+
- Highlight 2-3 standout design features
|
26 |
+
- Styling suggestions (what to wear it with, occasions)
|
27 |
+
- Quality/craftsmanship mentions
|
28 |
+
- One emotional benefit (how it makes the wearer feel)
|
29 |
+
Write in a professional yet warm tone that appeals to online jewelry shoppers. Focus on benefits and value proposition."""
|
30 |
+
|
31 |
+
def save_image_tempfile(self, pil_image):
|
32 |
+
with NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
|
33 |
+
pil_image.save(temp_file, format="PNG")
|
34 |
+
return temp_file.name
|
35 |
+
|
36 |
+
def image_upload(self, image_path):
|
37 |
+
sample_file = genai.upload_file(path=image_path, display_name="necklace_image")
|
38 |
+
return sample_file
|
39 |
+
|
40 |
+
def gen_title_desc(self, image):
|
41 |
+
image_path = self.save_image_tempfile(image)
|
42 |
+
sample_file = self.image_upload(image_path)
|
43 |
+
os.remove(image_path)
|
44 |
+
|
45 |
+
response = self.model.generate_content([sample_file, self.prompt])
|
46 |
+
return response.text
|
src/pipelines/__init__.py
ADDED
File without changes
|
src/pipelines/completePipeline.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
from src.components.makeup_try_on import MakeupTryOn
|
6 |
+
from src.components.necklaceTryOn import NecklaceTryOn
|
7 |
+
|
8 |
+
|
9 |
+
class Pipeline:
|
10 |
+
def __init__(self) -> None:
|
11 |
+
self.necklaceTryOnObj = NecklaceTryOn()
|
12 |
+
self.makeup_tryon_obj = MakeupTryOn()
|
13 |
+
|
14 |
+
async def necklaceTryOn_(self, image: Image.Image, jewellery: Image.Image, storename: str) -> \
|
15 |
+
list[
|
16 |
+
Union[Image.Image, str]]:
|
17 |
+
result, headerText, mask = self.necklaceTryOnObj.necklaceTryOn(image=image, jewellery=jewellery,
|
18 |
+
storename=storename)
|
19 |
+
return [result, headerText, mask]
|
20 |
+
|
21 |
+
async def necklaceTryOnDynamicOffset_(self, image: Image.Image, jewellery: Image.Image, storename: str,
|
22 |
+
offset: list) -> \
|
23 |
+
list[
|
24 |
+
Union[Image.Image, str]]:
|
25 |
+
result, headerText, mask = self.necklaceTryOnObj.necklaceTryOnOffsetBody(image=image, jewellery=jewellery,
|
26 |
+
storename=storename,
|
27 |
+
offset=offset)
|
28 |
+
return [result, headerText, mask]
|
29 |
+
|
30 |
+
async def shoulderPointMaskGeneration_(self, image: Image.Image):
|
31 |
+
mask, left, right = self.necklaceTryOnObj.shoulderPointMaskGeneration(image=image)
|
32 |
+
return mask, left, right
|
33 |
+
|
34 |
+
async def canvasPoint(self, image: Image.Image, jewellery: Image.Image, storename: str) -> dict:
|
35 |
+
points = self.necklaceTryOnObj.canvasPoints(image=image, jewellery=jewellery, storename=storename)
|
36 |
+
return points
|
37 |
+
|
38 |
+
async def necklaceTryOnWithPoints_(self, image: Image.Image, jewellery: Image.Image, left_shoulder: tuple[int, int],
|
39 |
+
right_shoulder: tuple[int, int], storename: str) -> list[
|
40 |
+
Union[Image.Image, str]]:
|
41 |
+
result, headerText, mask = self.necklaceTryOnObj.necklaceTryOnWithPoints(image=image, jewellery=jewellery,
|
42 |
+
left_point=left_shoulder,
|
43 |
+
right_point=right_shoulder,
|
44 |
+
storename=storename)
|
45 |
+
return [result, headerText, mask]
|
46 |
+
|
47 |
+
def necklaceTryOnMannequin_(self, image: Image.Image, jewellery: Image.Image):
|
48 |
+
result, resized_image = self.necklaceTryOnObj.necklaceTryOnMannequin(image, jewellery)
|
49 |
+
|
50 |
+
return result, resized_image
|
51 |
+
|
52 |
+
async def necklace_try_on_offset(self, image: Image.Image, jewellery: Image.Image, offset_x: float = 0.12,
|
53 |
+
offset_y: float = 0.5, storename: str = "default") -> list[
|
54 |
+
Union[Image.Image, str]]:
|
55 |
+
result, header_text, mask = self.necklaceTryOnObj.necklaceTryOnOffset(image, jewellery, offset_x=offset_x,
|
56 |
+
offset_y=offset_y, storename=storename)
|
57 |
+
|
58 |
+
return result, header_text, mask
|
59 |
+
|
60 |
+
async def makeup_tryon(self, image: Image.Image, lipstick_color: set, eyeliner_color: set, eyeshadow_color: set):
|
61 |
+
|
62 |
+
result = self.makeup_tryon_obj.full_makeup(image, lipstick_color=lipstick_color, eyeliner_color=eyeliner_color,
|
63 |
+
eyeshadow_color=eyeshadow_color)
|
64 |
+
return result
|
src/utils/__init__.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from io import BytesIO
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
from supabase import create_client, Client
|
7 |
+
|
8 |
+
|
9 |
+
# function to add watermark to images
|
10 |
+
def addWatermark(background: Image.Image, logo: Image.Image) -> Image.Image:
|
11 |
+
background = background.convert("RGBA")
|
12 |
+
logo = logo.convert("RGBA").resize((int(0.08 * background.size[0]), int(0.08 * background.size[0])))
|
13 |
+
background.paste(logo, (10, background.size[0] - logo.size[0] - 10), logo)
|
14 |
+
return background
|
15 |
+
|
16 |
+
|
17 |
+
def add_watermark_store(background: Image.Image, logo: Image.Image) -> Image.Image:
|
18 |
+
"""Add store logo watermark to top-left corner"""
|
19 |
+
background = background.convert("RGBA")
|
20 |
+
|
21 |
+
# Calculate size based on image dimensions (10% of width for store logo)
|
22 |
+
width = int(0.20 * background.size[0])
|
23 |
+
height = int(width * (logo.size[1] / logo.size[0])) # maintain aspect ratio
|
24 |
+
logo = logo.convert("RGBA").resize((width, height))
|
25 |
+
|
26 |
+
padding = int(0.01 * background.size[0])
|
27 |
+
|
28 |
+
background.paste(logo, (padding, padding), logo)
|
29 |
+
return background
|
30 |
+
|
31 |
+
|
32 |
+
def add_watermark_jewelmirror(background: Image.Image, logo: Image.Image, position: str = "right") -> Image.Image:
|
33 |
+
"""
|
34 |
+
Add JewelMirror logo watermark
|
35 |
+
position: "right" for top-right corner, "left" for top-left corner
|
36 |
+
"""
|
37 |
+
background = background.convert("RGBA")
|
38 |
+
|
39 |
+
# Calculate size based on image dimensions (8% of width for JewelMirror logo)
|
40 |
+
width = int(0.20 * background.size[0])
|
41 |
+
height = int(width * (logo.size[1] / logo.size[0])) # maintain aspect ratio
|
42 |
+
logo = logo.convert("RGBA").resize((width, height))
|
43 |
+
|
44 |
+
# Add padding from corners (1% of image width)
|
45 |
+
padding = int(0.01 * background.size[0])
|
46 |
+
|
47 |
+
# Position based on parameter
|
48 |
+
background.paste(logo, (background.size[0] - logo.size[0] - padding, padding), logo)
|
49 |
+
|
50 |
+
return background
|
51 |
+
|
52 |
+
|
53 |
+
def add_watermark_jewelmirror_bottom_right(background: Image.Image, logo: Image.Image) -> Image.Image:
|
54 |
+
background = background.convert("RGBA")
|
55 |
+
logo = logo.convert("RGBA")
|
56 |
+
|
57 |
+
w, h = int(0.08 * background.size[0]), int(0.06 * background.size[1])
|
58 |
+
logo = logo.resize((w, h))
|
59 |
+
|
60 |
+
padding = 0
|
61 |
+
logo_x = background.size[0] - logo.size[0] - padding
|
62 |
+
logo_y = background.size[1] - logo.size[1] - padding
|
63 |
+
|
64 |
+
background.paste(logo, (logo_x, logo_y), logo)
|
65 |
+
|
66 |
+
return background
|
67 |
+
|
68 |
+
|
69 |
+
# function to download an image from url and return as bytes objects
|
70 |
+
def returnBytesData(url: str) -> BytesIO:
|
71 |
+
response = requests.get(url)
|
72 |
+
return BytesIO(response.content)
|
73 |
+
|
74 |
+
|
75 |
+
# function to get public URLs of paths
|
76 |
+
def supabaseGetPublicURL(path: str) -> str:
|
77 |
+
url_string = f"https://lvuhhlrkcuexzqtsbqyu.supabase.co/storage/v1/object/public/Stores/{path}"
|
78 |
+
return url_string.replace(" ", "%20")
|
79 |
+
|
80 |
+
|
81 |
+
# function to deduct credit
|
82 |
+
def deductAndTrackCredit(storename: str, endpoint: str) -> str:
|
83 |
+
url: str = os.environ["SUPABASE_URL"]
|
84 |
+
key: str = os.environ["SUPABASE_KEY"]
|
85 |
+
supabase: Client = create_client(url, key)
|
86 |
+
current, _ = supabase.table('ClientConfig').select('CreditBalance').eq("StoreName", f"{storename}").execute()
|
87 |
+
if current[1] == []:
|
88 |
+
return "Not Found"
|
89 |
+
else:
|
90 |
+
current = current[1][0]["CreditBalance"]
|
91 |
+
if current > 0:
|
92 |
+
data, _ = supabase.table('ClientConfig').update({'CreditBalance': current - 1}).eq("StoreName",
|
93 |
+
f"{storename}").execute()
|
94 |
+
data, _ = supabase.table('UsageHistory').insert(
|
95 |
+
{'StoreName': f"{storename}", 'APIEndpoint': f"{endpoint}"}).execute()
|
96 |
+
return "Success"
|
97 |
+
else:
|
98 |
+
return "No Credits Available"
|
src/utils/backgroundEnhancerArchitecture.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from huggingface_hub import PyTorchModelHubMixin
|
5 |
+
|
6 |
+
|
7 |
+
class REBNCONV(nn.Module):
|
8 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
9 |
+
super(REBNCONV, self).__init__()
|
10 |
+
|
11 |
+
self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
|
12 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
13 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
hx = x
|
17 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
18 |
+
|
19 |
+
return xout
|
20 |
+
|
21 |
+
|
22 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
23 |
+
def _upsample_like(src, tar):
|
24 |
+
src = F.interpolate(src, size=tar.shape[2:], mode='bilinear')
|
25 |
+
|
26 |
+
return src
|
27 |
+
|
28 |
+
|
29 |
+
### RSU-7 ###
|
30 |
+
class RSU7(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
33 |
+
super(RSU7, self).__init__()
|
34 |
+
|
35 |
+
self.in_ch = in_ch
|
36 |
+
self.mid_ch = mid_ch
|
37 |
+
self.out_ch = out_ch
|
38 |
+
|
39 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
40 |
+
|
41 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
42 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
43 |
+
|
44 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
45 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
46 |
+
|
47 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
48 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
49 |
+
|
50 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
51 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
52 |
+
|
53 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
54 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
55 |
+
|
56 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
57 |
+
|
58 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
59 |
+
|
60 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
61 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
62 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
63 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
64 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
65 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
b, c, h, w = x.shape
|
69 |
+
|
70 |
+
hx = x
|
71 |
+
hxin = self.rebnconvin(hx)
|
72 |
+
|
73 |
+
hx1 = self.rebnconv1(hxin)
|
74 |
+
hx = self.pool1(hx1)
|
75 |
+
|
76 |
+
hx2 = self.rebnconv2(hx)
|
77 |
+
hx = self.pool2(hx2)
|
78 |
+
|
79 |
+
hx3 = self.rebnconv3(hx)
|
80 |
+
hx = self.pool3(hx3)
|
81 |
+
|
82 |
+
hx4 = self.rebnconv4(hx)
|
83 |
+
hx = self.pool4(hx4)
|
84 |
+
|
85 |
+
hx5 = self.rebnconv5(hx)
|
86 |
+
hx = self.pool5(hx5)
|
87 |
+
|
88 |
+
hx6 = self.rebnconv6(hx)
|
89 |
+
|
90 |
+
hx7 = self.rebnconv7(hx6)
|
91 |
+
|
92 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
93 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
94 |
+
|
95 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
96 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
97 |
+
|
98 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
99 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
100 |
+
|
101 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
102 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
103 |
+
|
104 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
105 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
106 |
+
|
107 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
108 |
+
|
109 |
+
return hx1d + hxin
|
110 |
+
|
111 |
+
|
112 |
+
### RSU-6 ###
|
113 |
+
class RSU6(nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
116 |
+
super(RSU6, self).__init__()
|
117 |
+
|
118 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
119 |
+
|
120 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
121 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
122 |
+
|
123 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
124 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
125 |
+
|
126 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
127 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
128 |
+
|
129 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
130 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
131 |
+
|
132 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
133 |
+
|
134 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
135 |
+
|
136 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
137 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
138 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
139 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
140 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
141 |
+
|
142 |
+
def forward(self, x):
|
143 |
+
hx = x
|
144 |
+
|
145 |
+
hxin = self.rebnconvin(hx)
|
146 |
+
|
147 |
+
hx1 = self.rebnconv1(hxin)
|
148 |
+
hx = self.pool1(hx1)
|
149 |
+
|
150 |
+
hx2 = self.rebnconv2(hx)
|
151 |
+
hx = self.pool2(hx2)
|
152 |
+
|
153 |
+
hx3 = self.rebnconv3(hx)
|
154 |
+
hx = self.pool3(hx3)
|
155 |
+
|
156 |
+
hx4 = self.rebnconv4(hx)
|
157 |
+
hx = self.pool4(hx4)
|
158 |
+
|
159 |
+
hx5 = self.rebnconv5(hx)
|
160 |
+
|
161 |
+
hx6 = self.rebnconv6(hx5)
|
162 |
+
|
163 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
164 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
165 |
+
|
166 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
167 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
168 |
+
|
169 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
170 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
171 |
+
|
172 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
173 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
174 |
+
|
175 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
176 |
+
|
177 |
+
return hx1d + hxin
|
178 |
+
|
179 |
+
|
180 |
+
### RSU-5 ###
|
181 |
+
class RSU5(nn.Module):
|
182 |
+
|
183 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
184 |
+
super(RSU5, self).__init__()
|
185 |
+
|
186 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
187 |
+
|
188 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
189 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
190 |
+
|
191 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
192 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
193 |
+
|
194 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
195 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
196 |
+
|
197 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
198 |
+
|
199 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
200 |
+
|
201 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
202 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
203 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
204 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
205 |
+
|
206 |
+
def forward(self, x):
|
207 |
+
hx = x
|
208 |
+
|
209 |
+
hxin = self.rebnconvin(hx)
|
210 |
+
|
211 |
+
hx1 = self.rebnconv1(hxin)
|
212 |
+
hx = self.pool1(hx1)
|
213 |
+
|
214 |
+
hx2 = self.rebnconv2(hx)
|
215 |
+
hx = self.pool2(hx2)
|
216 |
+
|
217 |
+
hx3 = self.rebnconv3(hx)
|
218 |
+
hx = self.pool3(hx3)
|
219 |
+
|
220 |
+
hx4 = self.rebnconv4(hx)
|
221 |
+
|
222 |
+
hx5 = self.rebnconv5(hx4)
|
223 |
+
|
224 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
225 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
226 |
+
|
227 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
228 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
229 |
+
|
230 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
231 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
232 |
+
|
233 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
234 |
+
|
235 |
+
return hx1d + hxin
|
236 |
+
|
237 |
+
|
238 |
+
### RSU-4 ###
|
239 |
+
class RSU4(nn.Module):
|
240 |
+
|
241 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
242 |
+
super(RSU4, self).__init__()
|
243 |
+
|
244 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
245 |
+
|
246 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
247 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
248 |
+
|
249 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
250 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
251 |
+
|
252 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
253 |
+
|
254 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
255 |
+
|
256 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
257 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
258 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
259 |
+
|
260 |
+
def forward(self, x):
|
261 |
+
hx = x
|
262 |
+
|
263 |
+
hxin = self.rebnconvin(hx)
|
264 |
+
|
265 |
+
hx1 = self.rebnconv1(hxin)
|
266 |
+
hx = self.pool1(hx1)
|
267 |
+
|
268 |
+
hx2 = self.rebnconv2(hx)
|
269 |
+
hx = self.pool2(hx2)
|
270 |
+
|
271 |
+
hx3 = self.rebnconv3(hx)
|
272 |
+
|
273 |
+
hx4 = self.rebnconv4(hx3)
|
274 |
+
|
275 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
276 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
277 |
+
|
278 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
279 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
280 |
+
|
281 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
282 |
+
|
283 |
+
return hx1d + hxin
|
284 |
+
|
285 |
+
|
286 |
+
### RSU-4F ###
|
287 |
+
class RSU4F(nn.Module):
|
288 |
+
|
289 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
290 |
+
super(RSU4F, self).__init__()
|
291 |
+
|
292 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
293 |
+
|
294 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
295 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
296 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
297 |
+
|
298 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
299 |
+
|
300 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
301 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
302 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
303 |
+
|
304 |
+
def forward(self, x):
|
305 |
+
hx = x
|
306 |
+
|
307 |
+
hxin = self.rebnconvin(hx)
|
308 |
+
|
309 |
+
hx1 = self.rebnconv1(hxin)
|
310 |
+
hx2 = self.rebnconv2(hx1)
|
311 |
+
hx3 = self.rebnconv3(hx2)
|
312 |
+
|
313 |
+
hx4 = self.rebnconv4(hx3)
|
314 |
+
|
315 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
316 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
317 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
318 |
+
|
319 |
+
return hx1d + hxin
|
320 |
+
|
321 |
+
|
322 |
+
class myrebnconv(nn.Module):
|
323 |
+
def __init__(self, in_ch=3,
|
324 |
+
out_ch=1,
|
325 |
+
kernel_size=3,
|
326 |
+
stride=1,
|
327 |
+
padding=1,
|
328 |
+
dilation=1,
|
329 |
+
groups=1):
|
330 |
+
super(myrebnconv, self).__init__()
|
331 |
+
|
332 |
+
self.conv = nn.Conv2d(in_ch,
|
333 |
+
out_ch,
|
334 |
+
kernel_size=kernel_size,
|
335 |
+
stride=stride,
|
336 |
+
padding=padding,
|
337 |
+
dilation=dilation,
|
338 |
+
groups=groups)
|
339 |
+
self.bn = nn.BatchNorm2d(out_ch)
|
340 |
+
self.rl = nn.ReLU(inplace=True)
|
341 |
+
|
342 |
+
def forward(self, x):
|
343 |
+
return self.rl(self.bn(self.conv(x)))
|
344 |
+
|
345 |
+
|
346 |
+
class BackgroundEnhancerArchitecture(nn.Module, PyTorchModelHubMixin):
|
347 |
+
|
348 |
+
def __init__(self, config: dict = {"in_ch": 3, "out_ch": 1}):
|
349 |
+
super(BackgroundEnhancerArchitecture, self).__init__()
|
350 |
+
in_ch = config["in_ch"]
|
351 |
+
out_ch = config["out_ch"]
|
352 |
+
self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
|
353 |
+
self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
354 |
+
|
355 |
+
self.stage1 = RSU7(64, 32, 64)
|
356 |
+
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
357 |
+
|
358 |
+
self.stage2 = RSU6(64, 32, 128)
|
359 |
+
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
360 |
+
|
361 |
+
self.stage3 = RSU5(128, 64, 256)
|
362 |
+
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
363 |
+
|
364 |
+
self.stage4 = RSU4(256, 128, 512)
|
365 |
+
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
366 |
+
|
367 |
+
self.stage5 = RSU4F(512, 256, 512)
|
368 |
+
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
369 |
+
|
370 |
+
self.stage6 = RSU4F(512, 256, 512)
|
371 |
+
|
372 |
+
# decoder
|
373 |
+
self.stage5d = RSU4F(1024, 256, 512)
|
374 |
+
self.stage4d = RSU4(1024, 128, 256)
|
375 |
+
self.stage3d = RSU5(512, 64, 128)
|
376 |
+
self.stage2d = RSU6(256, 32, 64)
|
377 |
+
self.stage1d = RSU7(128, 16, 64)
|
378 |
+
|
379 |
+
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
380 |
+
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
381 |
+
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
382 |
+
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
383 |
+
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
384 |
+
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
385 |
+
|
386 |
+
# self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
hx = x
|
390 |
+
|
391 |
+
hxin = self.conv_in(hx)
|
392 |
+
# hx = self.pool_in(hxin)
|
393 |
+
|
394 |
+
# stage 1
|
395 |
+
hx1 = self.stage1(hxin)
|
396 |
+
hx = self.pool12(hx1)
|
397 |
+
|
398 |
+
# stage 2
|
399 |
+
hx2 = self.stage2(hx)
|
400 |
+
hx = self.pool23(hx2)
|
401 |
+
|
402 |
+
# stage 3
|
403 |
+
hx3 = self.stage3(hx)
|
404 |
+
hx = self.pool34(hx3)
|
405 |
+
|
406 |
+
# stage 4
|
407 |
+
hx4 = self.stage4(hx)
|
408 |
+
hx = self.pool45(hx4)
|
409 |
+
|
410 |
+
# stage 5
|
411 |
+
hx5 = self.stage5(hx)
|
412 |
+
hx = self.pool56(hx5)
|
413 |
+
|
414 |
+
# stage 6
|
415 |
+
hx6 = self.stage6(hx)
|
416 |
+
hx6up = _upsample_like(hx6, hx5)
|
417 |
+
|
418 |
+
# -------------------- decoder --------------------
|
419 |
+
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
420 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
421 |
+
|
422 |
+
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
423 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
424 |
+
|
425 |
+
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
426 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
427 |
+
|
428 |
+
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
429 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
430 |
+
|
431 |
+
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
432 |
+
|
433 |
+
# side output
|
434 |
+
d1 = self.side1(hx1d)
|
435 |
+
d1 = _upsample_like(d1, x)
|
436 |
+
|
437 |
+
d2 = self.side2(hx2d)
|
438 |
+
d2 = _upsample_like(d2, x)
|
439 |
+
|
440 |
+
d3 = self.side3(hx3d)
|
441 |
+
d3 = _upsample_like(d3, x)
|
442 |
+
|
443 |
+
d4 = self.side4(hx4d)
|
444 |
+
d4 = _upsample_like(d4, x)
|
445 |
+
|
446 |
+
d5 = self.side5(hx5d)
|
447 |
+
d5 = _upsample_like(d5, x)
|
448 |
+
|
449 |
+
d6 = self.side6(hx6)
|
450 |
+
d6 = _upsample_like(d6, x)
|
451 |
+
|
452 |
+
return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d,
|
453 |
+
hx3d, hx4d,
|
454 |
+
hx5d, hx6]
|
src/utils/exceptions.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
def error_message_detail(error):
|
4 |
+
_, _, exc_info = sys.exc_info()
|
5 |
+
filename = exc_info.tb_frame.f_code.co_filename
|
6 |
+
lineno = exc_info.tb_lineno
|
7 |
+
error_message = "Error encountered in line no [{}], filename : [{}], saying [{}]".format(lineno, filename, error)
|
8 |
+
return error_message
|
9 |
+
|
10 |
+
class CustomException(Exception):
|
11 |
+
def __init__(self, error_message):
|
12 |
+
super().__init__(error_message)
|
13 |
+
self.error_message = error_message_detail(error_message)
|
14 |
+
|
15 |
+
def __str__(self) -> str:
|
16 |
+
return self.error_message
|
src/utils/logger.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
|
4 |
+
logger = logging.getLogger(__name__)
|
5 |
+
logger.setLevel(logging.INFO)
|
6 |
+
|
7 |
+
log_dir = os.path.join(os.getcwd(), "logs")
|
8 |
+
os.makedirs(log_dir, exist_ok=True)
|
9 |
+
|
10 |
+
LOG_FILE = os.path.join(log_dir, "running_logs.log")
|
11 |
+
|
12 |
+
logFormat = "[%(asctime)s: %(levelname)s: %(module)s: %(message)s]"
|
13 |
+
logFormatter = logging.Formatter(fmt=logFormat, style="%")
|
14 |
+
|
15 |
+
streamHandler = logging.StreamHandler()
|
16 |
+
streamHandler.setFormatter(logFormatter)
|
17 |
+
|
18 |
+
fileHandler = logging.FileHandler(filename=LOG_FILE)
|
19 |
+
fileHandler.setFormatter(logFormatter)
|
20 |
+
|
21 |
+
logger.addHandler(streamHandler)
|
22 |
+
logger.addHandler(fileHandler)
|