GabrielML committed
Commit 234009d · 1 Parent(s): 5d723ef

Init gradio repo
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,153 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.10-slim
+
+ WORKDIR /workspace
+ COPY requirements.txt /workspace/requirements.txt
+ RUN pip install -U pip
+ RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+ RUN pip install -r /workspace/requirements.txt
+
+ COPY src /workspace/src
+
+ ENV HOME=/workspace
+ CMD ["python", "src/app.py"]
+ # ENTRYPOINT python src/app.py
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ --find-links https://download.pytorch.org/whl/torch_stable.html
+ torch==2.0.1+cpu
+ torchvision==0.15.2+cpu
+ efficientnet-pytorch==0.7.1
+ gradio==3.44.4
+ Markdown==3.4.4
+ Pillow==10.0.1
+ tqdm==4.66.1
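After installing with these pins (the Dockerfile installs the CPU wheels first, then this file), a quick sanity check of the resolved versions might look like the following; a minimal sketch, not part of the commit:

```python
# Hypothetical sanity check that the pinned CPU wheels resolved as expected.
import torch
import torchvision
import gradio

print(torch.__version__)        # expected: 2.0.1+cpu
print(torchvision.__version__)  # expected: 0.15.2+cpu
print(gradio.__version__)       # expected: 3.44.4
```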
src/CustomModels.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+
+ class DinoVisionClassifier(torch.nn.Module):
+     def __init__(self, dinov2, num_classes=5):
+         super(DinoVisionClassifier, self).__init__()
+         self.transformer = dinov2
+         self.classifier = torch.nn.Sequential(
+             torch.nn.Linear(384, 64),
+             torch.nn.ReLU(),
+             torch.nn.Dropout(0.2),
+             torch.nn.Linear(64, num_classes)
+         )
+
+     def forward(self, x):
+         x = self.transformer(x)
+         x = self.transformer.norm(x)
+         x = self.classifier(x)
+         return x
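A minimal smoke test for this wrapper, assuming network access so torch.hub can fetch the dinov2_vits14 backbone (whose 384-dimensional output matches the first Linear layer above); the dummy tensor and expected shape are illustrative, not part of the commit:

```python
# Hypothetical smoke test for DinoVisionClassifier.
import torch
from CustomModels import DinoVisionClassifier

backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model = DinoVisionClassifier(backbone, num_classes=5).eval()

# DINOv2 ViT-S/14 expects image sides divisible by the patch size 14.
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)  # torch.Size([1, 5])
```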
src/__init__.py ADDED
File without changes
src/app.py ADDED
@@ -0,0 +1,80 @@
+ import gradio as gr
+ from utils import load_specific_model, inference
+ import markdown
+
+ current_model = None  # Initialize the current model as None
+
+ # Define a set of example images
+ example_images = [
+     ("Beispielbild Glas", "src/examples/Glas.jpg"),
+     ("Beispielbild Organic", "src/examples/Organic.jpg"),
+     ("Beispielbild Papier", "src/examples/Papier.jpg"),
+     ("Beispielbild Restmüll", "src/examples/Restmuell.jpg"),
+     ("Beispielbild Wertstoff", "src/examples/Wertstoff.jpg")
+ ]
+
+ def load_model(model_name):
+     global current_model
+     if model_name is None:
+         raise gr.Error("No model selected!")
+     if current_model is not None:
+         current_model = None  # drop the old model so it can be garbage-collected
+
+     current_model = load_specific_model(model_name)
+     current_model.eval()
+
+ def predict(inp):
+     global current_model
+     if current_model is None:
+         raise gr.Error("No model loaded!")
+
+     confidences = inference(current_model, inp)
+     return confidences
+
+ with gr.Blocks() as demo:
+     with open('src/app_template.md', 'r') as f:
+         markdown_string = f.read()
+     header = gr.Markdown(markdown_string)
+
+     with gr.Row(variant="panel", equal_height=True):
+
+         user_image = gr.Image(
+             type="pil",
+             label="Upload Your Own Image",
+             scale=2,
+             height=350,
+         )
+
+         with gr.Column():
+             output = gr.Label(
+                 num_top_classes=3,
+                 label="Output",
+                 scale=2,
+             )
+
+             model_dropdown = gr.Dropdown(
+                 ["EfficientNet-B3", "EfficientNet-B4", "vgg19", "resnet50", "dinov2_vits14"],
+                 label="Model",
+                 info="Select a model to use.",
+                 scale=1,
+             )
+             model_dropdown.change(load_model, model_dropdown, show_progress=True, queue=True)
+             predict_button = gr.Button(value="Predict", scale=1)
+             predict_button.click(fn=predict, inputs=user_image, outputs=output, queue=True)
+
+     gr.Markdown("## Example Images")
+     gr.Markdown("You can just drag and drop these images into the image uploader above!")
+
+     with gr.Row():
+         for name, image_path in example_images:
+             example_image = gr.Image(
+                 value=image_path,
+                 label=name,
+                 type="pil",
+                 height=220,
+                 interactive=False,
+             )
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(server_name="0.0.0.0", server_port=7860)
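Once the app is running, the same two handlers can in principle be driven headlessly with the separate gradio_client package (not pinned in requirements.txt); a hedged sketch under those assumptions, with fn_index values inferred from the registration order above:

```python
# Hypothetical remote calls against a locally running instance of this app.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
# fn_index=0 is the dropdown's load_model handler, fn_index=1 the Predict click.
client.predict("resnet50", fn_index=0)  # load a model first
result = client.predict("src/examples/Glas.jpg", fn_index=1)
print(result)
```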
src/app_template.md ADDED
@@ -0,0 +1,6 @@
+ # Waste Classification Demo
+ This interactive demo lets you classify waste items with several deep learning models. Choose a model, upload an image of a waste item, and click the "Predict" button to view the top three predicted classes and their confidences. The project was developed by [Ilyesse Hettenbach](https://github.com/ilyii) and [Gabriel Schurr](https://github.com/Gabriel9753) as part of a student project at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
+ Enjoy using the Waste Classification Demo to classify waste items and explore the capabilities of different deep learning models!
+
+ ## Models
+ The demo currently supports the following models: [EfficientNet-B3](https://arxiv.org/abs/1905.11946), [EfficientNet-B4](https://arxiv.org/abs/1905.11946), [VGG19](https://arxiv.org/abs/1409.1556), [ResNet50](https://arxiv.org/abs/1512.03385), and [DINOv2](https://arxiv.org/abs/2304.07193) in its smallest variant. The models were trained primarily on data we collected ourselves.
src/download_models.py ADDED
@@ -0,0 +1,41 @@
+ '''
+ This file is just used to download the models from the internet.
+ '''
+ from torchvision import models
+ from efficientnet_pytorch import EfficientNet
+ import torch
+
+ def main():
+     try:
+         print("Downloading EfficientNet-B3...")
+         _ = EfficientNet.from_pretrained("efficientnet-b3")
+     except Exception as e:
+         print(f"Error while downloading EfficientNet-B3: {e}")
+
+     try:
+         print("Downloading EfficientNet-B4...")
+         _ = EfficientNet.from_pretrained("efficientnet-b4")
+     except Exception as e:
+         print(f"Error while downloading EfficientNet-B4: {e}")
+
+     try:
+         print("Downloading vgg19...")
+         _ = models.vgg19()
+     except Exception as e:
+         print(f"Error while downloading vgg19: {e}")
+
+     try:
+         print("Downloading resnet50...")
+         _ = models.resnet50()
+     except Exception as e:
+         print(f"Error while downloading resnet50: {e}")
+
+     try:
+         print("Downloading dinov2_vits14...")
+         _ = torch.hub.load('facebookresearch/dinov2', "dinov2_vits14")
+     except Exception as e:
+         print(f"Error while downloading dinov2_vits14: {e}")
+
+
+ if __name__ == "__main__":
+     main()
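Note that this script warms local caches rather than producing files in the repo: torch.hub and efficientnet_pytorch download under the home directory, which is why the Dockerfile sets ENV HOME=/workspace. A small hedged check of that location:

```python
# Hypothetical check of the cache location used by the downloads above.
import torch
print(torch.hub.get_dir())  # defaults to ~/.cache/torch/hub
```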
src/examples/Glas.jpg ADDED

Git LFS Details

  • SHA256: 97f497c65cf0133d72346771c1d943ffb42a8fc33ca07dfa62a3434c8d0c9d0d
  • Pointer size: 131 Bytes
  • Size of remote file: 865 kB
src/examples/Organic.jpg ADDED

Git LFS Details

  • SHA256: c5f3f9e7e9760aea75ef4fe75e95581b82737c6ca45d24b75cac2f00e3c51ac4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.89 MB
src/examples/Papier.jpg ADDED

Git LFS Details

  • SHA256: ee064e0fc1d7f326fc43f010893dcad27fb5309811a9209df28972a9b6f29613
  • Pointer size: 131 Bytes
  • Size of remote file: 972 kB
src/examples/Restmuell.jpg ADDED

Git LFS Details

  • SHA256: 29edba7f596fdcbf60339a106cd6a62815d39a22d567423ce8dca480b7c9789c
  • Pointer size: 131 Bytes
  • Size of remote file: 914 kB
src/examples/Wertstoff.jpg ADDED

Git LFS Details

  • SHA256: 611b44325a56785b9dbb4051ca593fd0c75fa0b0f1126d4f363c158c6f486427
  • Pointer size: 132 Bytes
  • Size of remote file: 1.02 MB
src/models/dinov2_info.txt ADDED
@@ -0,0 +1,31 @@
+ Model: dinov2_vits14
+ Timestamp: 2023-09-25_17-27-57
+ Batch Size: 128
+ Learning Rate: 2e-06
+ Number of Epochs: 20
+ Linear Layer: [384, 64, 5]
+ Dropout: 0.2
+ Train length: 18108
+ Validation length: 4656
+
+ Train History:
+ Epoch 1: Acc 91.000 | Loss 0.283
+ Epoch 2: Acc 94.000 | Loss 0.174
+ Epoch 3: Acc 95.000 | Loss 0.137
+ Epoch 4: Acc 96.000 | Loss 0.115
+ Epoch 5: Acc 96.000 | Loss 0.095
+ Epoch 6: Acc 96.000 | Loss 0.088
+ Epoch 7: Acc 97.000 | Loss 0.087
+ Epoch 8: Acc 97.000 | Loss 0.080
+ Epoch 9: Acc 97.000 | Loss 0.077
+ Epoch 10: Acc 97.000 | Loss 0.076
+ Epoch 11: Acc 97.000 | Loss 0.075
+ Epoch 12: Acc 97.000 | Loss 0.069
+ Epoch 13: Acc 97.000 | Loss 0.081
+ Epoch 14: Acc 97.000 | Loss 0.068
+ Epoch 15: Acc 98.000 | Loss 0.059
+ Epoch 16: Acc 97.000 | Loss 0.063
+ Epoch 17: Acc 98.000 | Loss 0.054
+ Epoch 18: Acc 97.000 | Loss 0.058
+ Epoch 19: Acc 97.000 | Loss 0.067
+ Epoch 20: Acc 98.000 | Loss 0.060
src/models/dinov2_vits14_0.054_98.00.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d638a6b7b0619cd6154b26f1a85d640367b9235fdfe741d00d6c830d66f1f318
+ size 88398037
src/models/eff_b3_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7896d02c1746ff84bbc9d4d9fe6c891d1a74a688708a486e0dece1f42bf1580
+ size 43361093
src/models/eff_b4.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9bb6e5fbc6adcc94b12c39c4c868acddb86d45413195cdd7760d5ca939135c1
+ size 70974461
src/models/resnet50.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0c0b8388c2d4bf11f86396930db7c11896e71b9abb71de3e145ad163c5505a59
+ size 94389825
src/models/vgg19.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:08f8a40e0add29e781ed7cf4593abedfeca41fa8aa31970660190804005bbf08
+ size 80610559
src/utils.py ADDED
@@ -0,0 +1,67 @@
+ from torchvision import transforms
+ from torchvision import models
+ from efficientnet_pytorch import EfficientNet
+ import torch
+ from CustomModels import DinoVisionClassifier
+
+ classes = {0: 'Glas', 1: 'Organic', 2: 'Papier', 3: 'Restmüll', 4: 'Wertstoff'}
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ transform = transforms.Compose(
+     [transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
+      transforms.ToTensor(),
+      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+      ]
+ )
+
+ transform_dinov2 = transforms.Compose(
+     [transforms.Resize(256),
+      transforms.CenterCrop(224),
+      transforms.ToTensor(),
+      transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+      ]
+ )
+
+ def load_specific_model(model_name):
+     current_model = None
+     if model_name == "EfficientNet-B3":
+         current_model = EfficientNet.from_pretrained("efficientnet-b3", num_classes=len(classes.keys()))
+         current_model.load_state_dict(torch.load("src/models/eff_b3_model.pt", map_location="cpu"))
+     elif model_name == "EfficientNet-B4":
+         current_model = EfficientNet.from_pretrained("efficientnet-b4", num_classes=len(classes.keys()))
+         current_model.load_state_dict(torch.load("src/models/eff_b4.pt", map_location="cpu"))
+     elif model_name == "vgg19":
+         current_model = models.vgg19()
+         in_features = current_model.classifier[0].in_features
+         current_model.classifier = torch.nn.Linear(in_features, len(classes.keys()))
+         current_model.load_state_dict(torch.load("src/models/vgg19.pt", map_location="cpu"))
+     elif model_name == "resnet50":
+         current_model = models.resnet50()
+         in_features = current_model.fc.in_features
+         current_model.fc = torch.nn.Linear(in_features, len(classes.keys()))
+         current_model.load_state_dict(torch.load("src/models/resnet50.pt", map_location="cpu"))
+     elif model_name == "dinov2_vits14":
+         current_model = torch.hub.load('facebookresearch/dinov2', "dinov2_vits14")
+         current_model = DinoVisionClassifier(current_model, num_classes=len(classes.keys()))
+         current_model.load_state_dict(torch.load("src/models/dinov2_vits14_0.054_98.00.pth", map_location="cpu"))
+     else:
+         raise ValueError(f"Unknown model name: {model_name}")
+
+     print(f"Loaded model {model_name}")
+     return current_model.eval().to(device)
+
+ def inference(model, inp):
+     model.eval()
+     inp = transform(inp) if model.__class__.__name__ != "DinoVisionClassifier" else transform_dinov2(inp)
+     inp = inp.unsqueeze(0).to(device)
+     if torch.cuda.is_available():
+         with torch.no_grad(), torch.cuda.amp.autocast():
+             prediction = torch.nn.functional.softmax(model(inp)[0], dim=0).cpu().numpy()
+     else:
+         with torch.no_grad():
+             prediction = torch.nn.functional.softmax(model(inp)[0], dim=0).cpu().numpy()
+
+     confidences = {classes[i]: float(prediction[i]) for i in range(len(classes.keys()))}
+     return confidences
+
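Putting utils.py together end to end, a hedged usage sketch (run from the repository root so the src/models/... and src/examples/... paths resolve, mirroring how `python src/app.py` puts src/ on sys.path):

```python
# Hypothetical end-to-end inference, not part of the commit.
import sys
sys.path.insert(0, "src")  # so `from utils import ...` works from the repo root

from PIL import Image
from utils import load_specific_model, inference

model = load_specific_model("resnet50")
image = Image.open("src/examples/Glas.jpg").convert("RGB")
confidences = inference(model, image)
print(max(confidences, key=confidences.get), confidences)
```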