# beit-sketch-classifier

This model is a version of [microsoft/beit-base-patch16-224-pt22k-ft22k](https://huggingface.co/microsoft/beit-base-patch16-224-pt22k-ft22k) fine-tuned on a dataset of Quick, Draw! sketches ([a 1 percent sample of the 50M sketches](https://huggingface.co/datasets/kmewhort/quickdraw-bins-1pct-sample)).
It achieves the following results on the evaluation set:
- Loss: 1.6083
- Accuracy: 0.7480
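As the checkpoint name indicates, the backbone consumes 224x224 RGB images split into 16x16 patches, so sketch renderings need to match that geometry. A minimal sketch of the expected input, assuming numpy and Pillow are available (the rendering helper later in this card uses the same white 224x224 canvas):

```python
import numpy as np
from PIL import Image

# beit-base-patch16-224 takes 224x224 RGB inputs, split into
# 16x16 patches -> a 14x14 grid of 196 patch tokens.
canvas = np.full((224, 224), 255, np.uint8)        # white canvas
pil_image = Image.fromarray(canvas).convert("RGB")  # what the model expects

patches_per_side = 224 // 16
print(pil_image.size, pil_image.mode, patches_per_side ** 2)
```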

## Intended uses & limitations

It's intended to classify sketches supplied in a line-segment (vector) input format. There was no data augmentation during fine-tuning, so input raster images should be generated from line vectors in the same way as the training images.

You can generate the requisite PIL images from the Quick, Draw! `bin` format with the following:

```python
import io
from struct import unpack

import cv2
import numpy as np
from PIL import Image

# packed bytes -> dict (from https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py)
def unpack_drawing(file_handle):
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for i in range(n_strokes):
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    result = {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image,
    }
    return result

# packed bin -> RGB PIL image
def binToPIL(packed_drawing):
    padding = 8
    radius = 7
    scale = (224.0 - 2 * padding) / 256

    unpacked = unpack_drawing(io.BytesIO(packed_drawing))
    image = np.full((224, 224), 255, np.uint8)
    for stroke in unpacked['image']:
        prevX = round(stroke[0][0] * scale)
        prevY = round(stroke[1][0] * scale)
        for i in range(1, len(stroke[0])):
            x = round(stroke[0][i] * scale)
            y = round(stroke[1][i] * scale)
            cv2.line(image, (padding + prevX, padding + prevY), (padding + x, padding + y), 0, radius, -1)
            prevX = x
            prevY = y
    return Image.fromarray(image).convert("RGB")
```

## Training procedure