Spaces:
Runtime error
Runtime error
Joshua Sundance Bailey
commited on
Commit
•
171c1a6
0
Parent(s):
initial commit
Browse files- .gitattributes +35 -0
- .github/ISSUE_TEMPLATE/bug_report.md +38 -0
- .github/ISSUE_TEMPLATE/feature_request.md +17 -0
- .github/pull_request_template.md +12 -0
- .github/workflows/check-file-size-limit.yml +14 -0
- .github/workflows/hf-space.yml +21 -0
- .gitignore +94 -0
- .pre-commit-config.yaml +65 -0
- LICENSE +9 -0
- README.md +78 -0
- app.py +109 -0
- requirements.txt +6 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Bug report
|
3 |
+
about: Create a report to help us improve
|
4 |
+
title: ''
|
5 |
+
labels: bug
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Describe the bug**
|
11 |
+
A clear and concise description of what the bug is.
|
12 |
+
|
13 |
+
**To Reproduce**
|
14 |
+
Steps to reproduce the behavior:
|
15 |
+
1. Go to '...'
|
16 |
+
2. Click on '....'
|
17 |
+
3. Scroll down to '....'
|
18 |
+
4. See error
|
19 |
+
|
20 |
+
**Expected behavior**
|
21 |
+
A clear and concise description of what you expected to happen.
|
22 |
+
|
23 |
+
**Screenshots**
|
24 |
+
If applicable, add screenshots to help explain your problem.
|
25 |
+
|
26 |
+
**Desktop (please complete the following information):**
|
27 |
+
- OS: [e.g. iOS]
|
28 |
+
- Browser [e.g. chrome, safari]
|
29 |
+
- Version [e.g. 22]
|
30 |
+
|
31 |
+
**Smartphone (please complete the following information):**
|
32 |
+
- Device: [e.g. iPhone6]
|
33 |
+
- OS: [e.g. iOS8.1]
|
34 |
+
- Browser [e.g. stock browser, safari]
|
35 |
+
- Version [e.g. 22]
|
36 |
+
|
37 |
+
**Additional context**
|
38 |
+
Add any other context about the problem here.
|
.github/ISSUE_TEMPLATE/feature_request.md
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
name: Feature request
|
3 |
+
about: Suggest an idea for this project
|
4 |
+
title: ''
|
5 |
+
labels: enhancement
|
6 |
+
assignees: ''
|
7 |
+
|
8 |
+
---
|
9 |
+
|
10 |
+
**Describe the solution you'd like**
|
11 |
+
A clear and concise description of what you want to happen.
|
12 |
+
|
13 |
+
**Describe alternatives you've considered**
|
14 |
+
A clear and concise description of any alternative solutions or features you've considered.
|
15 |
+
|
16 |
+
**Additional context**
|
17 |
+
Add any other context or screenshots about the feature request here.
|
.github/pull_request_template.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Thank you for contributing!
|
2 |
+
Before submitting this PR, please make sure:
|
3 |
+
|
4 |
+
- [ ] Your code builds clean without any errors or warnings
|
5 |
+
- [ ] Your code doesn't break anything we can't fix
|
6 |
+
- [ ] You have added appropriate tests
|
7 |
+
|
8 |
+
Please check one or more of the following to describe the nature of this PR:
|
9 |
+
- [ ] New feature
|
10 |
+
- [ ] Bug fix
|
11 |
+
- [ ] Documentation
|
12 |
+
- [ ] Other
|
.github/workflows/check-file-size-limit.yml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: 10 MB file size limit
|
2 |
+
on:
|
3 |
+
pull_request:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
jobs:
|
7 |
+
check-file-sizes:
|
8 |
+
runs-on: ubuntu-latest
|
9 |
+
steps:
|
10 |
+
- name: Check large files
|
11 |
+
uses: ActionsDesk/lfs-warning@v2.0
|
12 |
+
with:
|
13 |
+
filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
|
14 |
+
token: ${{ secrets.WORKFLOW_GIT_ACCESS_TOKEN }}
|
.github/workflows/hf-space.yml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Push to HuggingFace Space
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: [main]
|
6 |
+
workflow_dispatch:
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
push-to-huggingface:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
steps:
|
12 |
+
- uses: actions/checkout@v2
|
13 |
+
with:
|
14 |
+
fetch-depth: 0
|
15 |
+
token: ${{ secrets.WORKFLOW_GIT_ACCESS_TOKEN }}
|
16 |
+
|
17 |
+
- name: Push to HuggingFace Space
|
18 |
+
env:
|
19 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
20 |
+
run: |
|
21 |
+
git push https://joshuasundance:$HF_TOKEN@huggingface.co/spaces/joshuasundance/mtg-coloridentity main
|
.gitignore
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hf_cache/
|
2 |
+
govgis-nov2023/
|
3 |
+
*$py.class
|
4 |
+
*.chainlit
|
5 |
+
*.chroma
|
6 |
+
*.cover
|
7 |
+
*.egg
|
8 |
+
*.egg-info/
|
9 |
+
*.env
|
10 |
+
*.langchain.db
|
11 |
+
*.log
|
12 |
+
*.manifest
|
13 |
+
*.mo
|
14 |
+
*.pot
|
15 |
+
*.py,cover
|
16 |
+
*.py[cod]
|
17 |
+
*.sage.py
|
18 |
+
*.so
|
19 |
+
*.spec
|
20 |
+
.DS_STORE
|
21 |
+
.Python
|
22 |
+
.cache
|
23 |
+
.coverage
|
24 |
+
.coverage.*
|
25 |
+
.dmypy.json
|
26 |
+
.eggs/
|
27 |
+
.env
|
28 |
+
.hypothesis/
|
29 |
+
.idea
|
30 |
+
.installed.cfg
|
31 |
+
.ipynb_checkpoints
|
32 |
+
.mypy_cache/
|
33 |
+
.nox/
|
34 |
+
.pyre/
|
35 |
+
.pytest_cache/
|
36 |
+
.python-version
|
37 |
+
.ropeproject
|
38 |
+
.ruff_cache/
|
39 |
+
.scrapy
|
40 |
+
.spyderproject
|
41 |
+
.spyproject
|
42 |
+
.tox/
|
43 |
+
.venv
|
44 |
+
.vscode
|
45 |
+
.webassets-cache
|
46 |
+
/site
|
47 |
+
ENV/
|
48 |
+
MANIFEST
|
49 |
+
__pycache__
|
50 |
+
__pycache__/
|
51 |
+
__pypackages__/
|
52 |
+
build/
|
53 |
+
celerybeat-schedule
|
54 |
+
celerybeat.pid
|
55 |
+
coverage.xml
|
56 |
+
credentials.json
|
57 |
+
data/
|
58 |
+
db.sqlite3
|
59 |
+
db.sqlite3-journal
|
60 |
+
develop-eggs/
|
61 |
+
dist/
|
62 |
+
dmypy.json
|
63 |
+
docs/_build/
|
64 |
+
downloads/
|
65 |
+
eggs/
|
66 |
+
env.bak/
|
67 |
+
env/
|
68 |
+
fly.toml
|
69 |
+
htmlcov/
|
70 |
+
instance/
|
71 |
+
ipython_config.py
|
72 |
+
junk/
|
73 |
+
lib/
|
74 |
+
lib64/
|
75 |
+
local_settings.py
|
76 |
+
models/*.bin
|
77 |
+
nosetests.xml
|
78 |
+
lab/scratch/
|
79 |
+
lab/
|
80 |
+
parts/
|
81 |
+
pip-delete-this-directory.txt
|
82 |
+
pip-log.txt
|
83 |
+
pip-wheel-metadata/
|
84 |
+
profile_default/
|
85 |
+
sdist/
|
86 |
+
share/python-wheels/
|
87 |
+
storage
|
88 |
+
target/
|
89 |
+
token.json
|
90 |
+
var/
|
91 |
+
venv
|
92 |
+
venv.bak/
|
93 |
+
venv/
|
94 |
+
wheels/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Don't know what this file is? See https://pre-commit.com/
|
2 |
+
# pip install pre-commit
|
3 |
+
# pre-commit install
|
4 |
+
# pre-commit autoupdate
|
5 |
+
# Apply to all files without commiting:
|
6 |
+
# pre-commit run --all-files
|
7 |
+
# I recommend running this until you pass all checks, and then commit.
|
8 |
+
# Fix what you need to and then let the pre-commit hooks resolve their conflicts.
|
9 |
+
# You may need to git add -u between runs.
|
10 |
+
exclude: "AI_CHANGELOG.md"
|
11 |
+
repos:
|
12 |
+
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
13 |
+
rev: "v0.1.15"
|
14 |
+
hooks:
|
15 |
+
- id: ruff
|
16 |
+
args: [--fix, --exit-non-zero-on-fix, --ignore, E501]
|
17 |
+
- repo: https://github.com/koalaman/shellcheck-precommit
|
18 |
+
rev: v0.9.0
|
19 |
+
hooks:
|
20 |
+
- id: shellcheck
|
21 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
22 |
+
rev: v4.5.0
|
23 |
+
hooks:
|
24 |
+
- id: check-ast
|
25 |
+
- id: check-builtin-literals
|
26 |
+
- id: check-merge-conflict
|
27 |
+
- id: check-symlinks
|
28 |
+
- id: check-toml
|
29 |
+
- id: check-xml
|
30 |
+
- id: debug-statements
|
31 |
+
- id: check-case-conflict
|
32 |
+
- id: check-docstring-first
|
33 |
+
- id: check-executables-have-shebangs
|
34 |
+
- id: check-json
|
35 |
+
# - id: check-yaml
|
36 |
+
- id: debug-statements
|
37 |
+
- id: fix-byte-order-marker
|
38 |
+
- id: detect-private-key
|
39 |
+
- id: end-of-file-fixer
|
40 |
+
- id: trailing-whitespace
|
41 |
+
- id: mixed-line-ending
|
42 |
+
- id: requirements-txt-fixer
|
43 |
+
- repo: https://github.com/pre-commit/mirrors-mypy
|
44 |
+
rev: v1.8.0
|
45 |
+
hooks:
|
46 |
+
- id: mypy
|
47 |
+
additional_dependencies:
|
48 |
+
- types-PyYAML
|
49 |
+
- repo: https://github.com/asottile/add-trailing-comma
|
50 |
+
rev: v3.1.0
|
51 |
+
hooks:
|
52 |
+
- id: add-trailing-comma
|
53 |
+
#- repo: https://github.com/dannysepler/rm_unneeded_f_str
|
54 |
+
# rev: v0.2.0
|
55 |
+
# hooks:
|
56 |
+
# - id: rm-unneeded-f-str
|
57 |
+
- repo: https://github.com/psf/black
|
58 |
+
rev: 24.1.1
|
59 |
+
hooks:
|
60 |
+
- id: black
|
61 |
+
- repo: https://github.com/PyCQA/bandit
|
62 |
+
rev: 1.7.7
|
63 |
+
hooks:
|
64 |
+
- id: bandit
|
65 |
+
args: ["-x", "tests/*.py"]
|
LICENSE
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Joshua Sundance Bailey
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
6 |
+
|
7 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
8 |
+
|
9 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: mtg-coloridentity
|
3 |
+
emoji: 🧙
|
4 |
+
colorFrom: white
|
5 |
+
colorTo: red
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.30.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
# mtg-coloridentity
|
14 |
+
|
15 |
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
|
16 |
+
[![python](https://img.shields.io/badge/Python-3-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org)
|
17 |
+
|
18 |
+
[![Push to HuggingFace Space](https://github.com/joshuasundance-swca/mtg-coloridentity/actions/workflows/hf-space.yml/badge.svg)](https://github.com/joshuasundance-swca/mtg-coloridentity/actions/workflows/hf-space.yml)
|
19 |
+
[![Open HuggingFace Space](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/joshuasundance/mtg-coloridentity)
|
20 |
+
|
21 |
+
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
|
22 |
+
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v1.json)](https://github.com/charliermarsh/ruff)
|
23 |
+
[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
|
24 |
+
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
|
25 |
+
|
26 |
+
[![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)
|
27 |
+
|
28 |
+
|
29 |
+
# mtg-coloridentity
|
30 |
+
|
31 |
+
🤖 This README was written by GPT-4. 🤖
|
32 |
+
|
33 |
+
## Overview
|
34 |
+
This Streamlit app is designed for the multi-label classification of Magic: The Gathering (MTG) cards,
|
35 |
+
specifically focusing on their color identity.
|
36 |
+
It utilizes a pre-trained model hosted on Hugging Face, `joshuasundance/mtg-coloridentity-multilabel-classification`,
|
37 |
+
to predict the color identity of MTG cards based on their names and descriptions.
|
38 |
+
|
39 |
+
## Features
|
40 |
+
- Interactive UI: Users can input the name and text of any MTG card to get predictions on its color identity.
|
41 |
+
- Color Probabilities: The app displays the probability of each color identity (Black, Green, Red, Blue, White) for the given card.
|
42 |
+
- Random Card Selection: With a "Roll the Dice" feature, users can load the text of a random MTG card from the dataset.
|
43 |
+
|
44 |
+
## How It Works
|
45 |
+
The app fetches a pre-trained `SetFit` model from Hugging Face and uses it to
|
46 |
+
predict the color identities of MTG cards.
|
47 |
+
The model's predictions are displayed as a bar chart,
|
48 |
+
showing the probability of each color identity.
|
49 |
+
|
50 |
+
## Getting Started
|
51 |
+
To run this app locally, clone the repository and ensure you have the following prerequisites installed:
|
52 |
+
|
53 |
+
- Python 3.x
|
54 |
+
- `streamlit`
|
55 |
+
- `pandas`
|
56 |
+
- `seaborn`
|
57 |
+
- `matplotlib`
|
58 |
+
- `datasets` and `setfit` from Hugging Face
|
59 |
+
|
60 |
+
## Contributions, Support, and Contact
|
61 |
+
|
62 |
+
Contributions to this project are welcome! Please feel free to submit issues and pull requests.
|
63 |
+
|
64 |
+
For support, please raise an issue on GitHub or in the HuggingFace space.
|
65 |
+
|
66 |
+
## License
|
67 |
+
|
68 |
+
This project is under the [MIT License](LICENSE.md).
|
69 |
+
|
70 |
+
## Acknowledgments
|
71 |
+
|
72 |
+
Thanks to HuggingFace and `setfit`!
|
73 |
+
|
74 |
+
## TODO
|
75 |
+
- [ ] make a todo list ;)
|
76 |
+
- [ ] improve READMEs
|
77 |
+
- [ ] make better model(s)
|
78 |
+
- [x] learn in public
|
app.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from typing import Sequence
|
4 |
+
|
5 |
+
import datasets
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import pandas as pd
|
8 |
+
import seaborn as sns
|
9 |
+
import streamlit as st
|
10 |
+
from setfit import SetFitModel
|
11 |
+
|
12 |
+
st.set_page_config(
|
13 |
+
page_title="mtg-coloridentity-multilabel-classification",
|
14 |
+
page_icon="🧙",
|
15 |
+
layout="wide",
|
16 |
+
initial_sidebar_state="collapsed",
|
17 |
+
menu_items=None,
|
18 |
+
)
|
19 |
+
|
20 |
+
|
21 |
+
default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
|
22 |
+
HF_HOME = os.environ.get("HF_HOME", default_hf_home)
|
23 |
+
|
24 |
+
coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"
|
25 |
+
|
26 |
+
colors = ["B", "G", "R", "U", "W"]
|
27 |
+
labels = ["black", "green", "red", "blue", "white"]
|
28 |
+
|
29 |
+
sns.set()
|
30 |
+
|
31 |
+
col1, col2 = st.columns(2)
|
32 |
+
|
33 |
+
|
34 |
+
@st.cache_resource
|
35 |
+
def get_model(
|
36 |
+
model_id: str = coloridentity_model,
|
37 |
+
cache_dir: str = HF_HOME,
|
38 |
+
**kwargs,
|
39 |
+
) -> SetFitModel:
|
40 |
+
return SetFitModel.from_pretrained(model_id, cache_dir=cache_dir, **kwargs)
|
41 |
+
|
42 |
+
|
43 |
+
@st.cache_data
|
44 |
+
def get_data(
|
45 |
+
repo_id: str = coloridentity_model,
|
46 |
+
cache_dir: str = HF_HOME,
|
47 |
+
**kwargs,
|
48 |
+
) -> datasets.Dataset:
|
49 |
+
dataset_dict = datasets.load_dataset(repo_id, cache_dir=cache_dir, **kwargs)
|
50 |
+
return datasets.concatenate_datasets(
|
51 |
+
list(dataset_dict.values()),
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
def get_random_text() -> str:
|
56 |
+
return dataset.select([random.randint(0, len(dataset))])[0]["text"] # nosec
|
57 |
+
|
58 |
+
|
59 |
+
@st.cache_data
|
60 |
+
def get_preds(input_text: str) -> Sequence[float]:
|
61 |
+
return model.predict_proba(input_text)
|
62 |
+
|
63 |
+
|
64 |
+
def prob_bars(preds: Sequence[float]) -> None:
|
65 |
+
_preds = (float(p) for p in preds)
|
66 |
+
df = pd.DataFrame(
|
67 |
+
zip(labels, _preds),
|
68 |
+
columns=["Color", "Probability"],
|
69 |
+
)
|
70 |
+
plt.figure(figsize=(8, 6))
|
71 |
+
ax = sns.barplot(x="Color", y="Probability", data=df, palette=labels)
|
72 |
+
|
73 |
+
# Add data labels on each bar
|
74 |
+
for p in ax.patches:
|
75 |
+
ax.annotate(
|
76 |
+
format(p.get_height(), ".4f"),
|
77 |
+
(p.get_x() + p.get_width() / 2.0, p.get_height()),
|
78 |
+
ha="center",
|
79 |
+
va="center",
|
80 |
+
xytext=(0, 9),
|
81 |
+
textcoords="offset points",
|
82 |
+
)
|
83 |
+
|
84 |
+
plt.title("Prediction Probabilities")
|
85 |
+
plt.xlabel("Color")
|
86 |
+
plt.ylabel("Probability")
|
87 |
+
st.pyplot(plt.gcf())
|
88 |
+
|
89 |
+
|
90 |
+
model = get_model()
|
91 |
+
dataset = get_data()
|
92 |
+
default_text = get_random_text()
|
93 |
+
|
94 |
+
if "input_text" not in st.session_state:
|
95 |
+
st.session_state.input_text = default_text
|
96 |
+
|
97 |
+
with col1:
|
98 |
+
if st.button("🎲 Roll the Dice"):
|
99 |
+
st.session_state.input_text = get_random_text()
|
100 |
+
input_text = st.text_area(
|
101 |
+
"Card name and text",
|
102 |
+
st.session_state.input_text,
|
103 |
+
height=400,
|
104 |
+
)
|
105 |
+
|
106 |
+
preds = get_preds(input_text)
|
107 |
+
|
108 |
+
with col2:
|
109 |
+
prob_bars(preds)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets
|
2 |
+
matplotlib
|
3 |
+
pandas
|
4 |
+
seaborn
|
5 |
+
setfit
|
6 |
+
streamlit
|