Wauplin HF staff commited on
Commit
02afadf
·
0 Parent(s):

first commit

Browse files
Files changed (7) hide show
  1. .gitignore +140 -0
  2. Makefile +10 -0
  3. app.py +53 -0
  4. gallery_history.py +122 -0
  5. pyproject.toml +19 -0
  6. requirements.txt +7 -0
  7. setup.cfg +16 -0
.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ pip-wheel-metadata/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+
53
+ # Translations
54
+ *.mo
55
+ *.pot
56
+
57
+ # Django stuff:
58
+ *.log
59
+ local_settings.py
60
+ db.sqlite3
61
+ db.sqlite3-journal
62
+
63
+ # Flask stuff:
64
+ instance/
65
+ .webassets-cache
66
+
67
+ # Scrapy stuff:
68
+ .scrapy
69
+
70
+ # Sphinx documentation
71
+ docs/_build/
72
+
73
+ # PyBuilder
74
+ target/
75
+
76
+ # Jupyter Notebook
77
+ .ipynb_checkpoints
78
+
79
+ # IPython
80
+ profile_default/
81
+ ipython_config.py
82
+
83
+ # pyenv
84
+ .python-version
85
+
86
+ # pipenv
87
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
89
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
90
+ # install all needed dependencies.
91
+ #Pipfile.lock
92
+
93
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94
+ __pypackages__/
95
+
96
+ # Celery stuff
97
+ celerybeat-schedule
98
+ celerybeat.pid
99
+
100
+ # SageMath parsed files
101
+ *.sage.py
102
+
103
+ # Environments
104
+ .env
105
+ .venv
106
+ .venv*
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+ .venv*
113
+
114
+ # Spyder project settings
115
+ .spyderproject
116
+ .spyproject
117
+
118
+ # Rope project settings
119
+ .ropeproject
120
+
121
+ # mkdocs documentation
122
+ /site
123
+
124
+ # mypy
125
+ .mypy_cache/
126
+ .dmypy.json
127
+ dmypy.json
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+ .vscode/
132
+ .idea/
133
+
134
+ .DS_Store
135
+
136
+ # Ruff
137
+ .ruff_cache
138
+
139
+ # Spell checker config
140
+ cspell.json
Makefile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: quality style
2
+
3
+ quality:
4
+ black --check .
5
+ ruff .
6
+ mypy .
7
+
8
+ style:
9
+ black .
10
+ ruff . --fix
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import json
4
+ import pathlib
5
+ import tempfile
6
+
7
+ import gradio as gr
8
+ from gradio_client import Client
9
+
10
+
11
+ client = Client("runwayml/stable-diffusion-v1-5")
12
+
13
+
14
+ def generate(prompt: str) -> tuple[str, list[str]]:
15
+ negative_prompt = ""
16
+ guidance_scale = 9.0
17
+ out_dir = client.predict(prompt, fn_index=1)
18
+
19
+ config = {
20
+ "prompt": prompt,
21
+ "negative_prompt": negative_prompt,
22
+ "guidance_scale": guidance_scale,
23
+ }
24
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as config_file:
25
+ json.dump(config, config_file)
26
+
27
+ with (pathlib.Path(out_dir) / "captions.json").open() as f:
28
+ paths = list(json.load(f).keys())
29
+ return paths
30
+
31
+
32
+ with gr.Blocks(css="style.css") as demo:
33
+ with gr.Group():
34
+ prompt = gr.Text(show_label=False, placeholder="Prompt")
35
+ gallery = gr.Gallery(
36
+ show_label=False,
37
+ columns=2,
38
+ rows=2,
39
+ height="600px",
40
+ object_fit="scale-down",
41
+ )
42
+
43
+ prompt.submit(
44
+ fn=generate,
45
+ inputs=prompt,
46
+ outputs=gallery,
47
+ )
48
+
49
+ with gr.Tab("Past generations"):
50
+ gr.Markdown("building...")
51
+
52
+ if __name__ == "__main__":
53
+ demo.launch()
gallery_history.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ How to use:
3
+ 1. Create a Space with a Persistent Storage attached. Filesystem will be available under `/data`.
4
+ 2. Add `hf_oauth: true` to the Space metadata (README.md). Make sure to have Gradio>=3.41.0 configured.
5
+ 3. Add `HISTORY_FOLDER` as a Space variable (example. `"/data/history"`).
6
+ 4. Add `filelock` as dependency in `requirements.txt`.
7
+ 5. Add history gallery to your Gradio app:
8
+ a. Add imports: `from gallery_history import fetch_gallery_history, show_gallery_history`
9
+ a. Add `history = show_gallery_history()` within `gr.Blocks` context.
10
+ b. Add `.then(fn=fetch_gallery_history, inputs=[prompt, result], outputs=history)` on the generate event.
11
+ """
12
+ import json
13
+ import os
14
+ import shutil
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Tuple
17
+ from uuid import uuid4
18
+
19
+ import gradio as gr
20
+ from filelock import FileLock
21
+
22
+
23
+ _folder = os.environ.get("HISTORY_FOLDER")
24
+ if _folder is None:
25
+ print(
26
+ "'HISTORY_FOLDER' environment variable not set. User history will be saved "
27
+ "locally and will be lost when the Space instance is restarted."
28
+ )
29
+ _folder = Path(__file__).parent / "history"
30
+ HISTORY_FOLDER_PATH = Path(_folder)
31
+
32
+ IMAGES_FOLDER_PATH = HISTORY_FOLDER_PATH / "images"
33
+ IMAGES_FOLDER_PATH.mkdir(parents=True, exist_ok=True)
34
+
35
+
36
+ def show_gallery_history():
37
+ gr.Markdown(
38
+ "## Your past generations\n\n(Log in to keep a gallery of your previous generations."
39
+ " Your history will be saved and available on your next visit.)"
40
+ )
41
+ with gr.Column():
42
+ with gr.Row():
43
+ gr.LoginButton(min_width=250)
44
+ gr.LogoutButton(min_width=250)
45
+ gallery = gr.Gallery(
46
+ label="Past images",
47
+ show_label=True,
48
+ elem_id="gallery",
49
+ object_fit="contain",
50
+ columns=3,
51
+ height=300,
52
+ preview=False,
53
+ show_share_button=False,
54
+ show_download_button=False,
55
+ )
56
+ gr.Markdown("Make sure to save your images from time to time, this gallery may be deleted in the future.")
57
+ gallery.attach_load_event(fetch_gallery_history, every=None)
58
+ return gallery
59
+
60
+
61
+ def fetch_gallery_history(
62
+ prompt: Optional[str] = None,
63
+ result: Optional[Dict] = None,
64
+ user: Optional[gr.OAuthProfile] = None,
65
+ ):
66
+ if user is None:
67
+ return []
68
+ try:
69
+ if prompt is not None and result is not None: # None values means no new images
70
+ return _update_user_history(user["preferred_username"], [(item["name"], prompt) for item in result])
71
+ else:
72
+ return _read_user_history(user["preferred_username"])
73
+ except Exception as e:
74
+ raise gr.Error(f"Error while fetching history: {e}") from e
75
+
76
+
77
+ ####################
78
+ # Internal helpers #
79
+ ####################
80
+
81
+
82
+ def _read_user_history(username: str) -> List[Tuple[str, str]]:
83
+ """Return saved history for that user."""
84
+ with _user_lock(username):
85
+ path = _user_history_path(username)
86
+ if path.exists():
87
+ return json.loads(path.read_text())
88
+ return [] # No history yet
89
+
90
+
91
+ def _update_user_history(username: str, new_images: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
92
+ """Update history for that user and return it."""
93
+ with _user_lock(username):
94
+ # Read existing
95
+ path = _user_history_path(username)
96
+ if path.exists():
97
+ images = json.loads(path.read_text())
98
+ else:
99
+ images = [] # No history yet
100
+
101
+ # Copy images to persistent folder
102
+ images = [(_copy_image(src_path), prompt) for src_path, prompt in new_images] + images
103
+
104
+ # Save and return
105
+ path.write_text(json.dumps(images))
106
+ return images
107
+
108
+
109
+ def _user_history_path(username: str) -> Path:
110
+ return HISTORY_FOLDER_PATH / f"{username}.json"
111
+
112
+
113
+ def _user_lock(username: str) -> FileLock:
114
+ """Ensure history is not corrupted if concurrent calls."""
115
+ return FileLock(f"{_user_history_path(username)}.lock")
116
+
117
+
118
+ def _copy_image(src: str) -> str:
119
+ """Copy image to the persistent storage."""
120
+ dst = IMAGES_FOLDER_PATH / f"{uuid4().hex}_{Path(src).name}" # keep file ext
121
+ shutil.copyfile(src, dst)
122
+ return str(dst)
pyproject.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 119
3
+ target_version = ['py37', 'py38', 'py39', 'py310']
4
+ preview = true
5
+
6
+ [tool.mypy]
7
+ ignore_missing_imports = true
8
+ no_implicit_optional = true
9
+ scripts_are_modules = true
10
+
11
+ [tool.ruff]
12
+ # Ignored rules:
13
+ # "E501" -> line length violation
14
+ ignore = ["E501"]
15
+ select = ["E", "F", "I", "W"]
16
+ line-length = 119
17
+
18
+ [tool.ruff.isort]
19
+ lines-after-imports = 2
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=3.44
2
+ huggingface_hub>=0.17
3
+
4
+ # dev-deps
5
+ ruff
6
+ black
7
+ mypy
setup.cfg ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [isort]
2
+ default_section = FIRSTPARTY
3
+ ensure_newline_before_comments = True
4
+ force_grid_wrap = 0
5
+ include_trailing_comma = True
6
+ known_third_party = gradio
7
+
8
+ line_length = 119
9
+ lines_after_imports = 2
10
+ multi_line_output = 3
11
+ use_parentheses = True
12
+
13
+ [flake8]
14
+ exclude = .git,__pycache__,old,build,dist,.venv*
15
+ ignore = B028, E203, E501, E741, W503
16
+ max-line-length = 119