---
license: apache-2.0
tags:
  - generated_from_trainer
metrics:
  - accuracy
model-index:
  - name: beit-sketch-classifier
    results: []
---

beit-sketch-classifier

This model is a fine-tuned version of microsoft/beit-base-patch16-224-pt22k-ft22k, trained on a dataset of Quick, Draw! sketches (roughly 10% of the 50M sketches in the QuickDraw dataset). It achieves the following results on the evaluation set:

  • Loss: 0.7372
  • Accuracy: 0.8098

Intended uses & limitations

The model is intended to classify sketches that are captured in a line-segment (vector) format. No data augmentation was applied during fine-tuning, so the raster input images should be generated from the line-vector data in the same way as the training images.
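If sketches arrive as stroke lists rather than packed records (for example, the `drawing` field of QuickDraw's simplified ndjson files, where each stroke is a pair of x and y coordinate lists on a 0-255 grid), one way to reuse the renderer below unchanged is to pack them into the binary record layout first. This is a minimal sketch: `pack_drawing` and its default header values are hypothetical, and little-endian byte order is an assumption.

```python
import struct


def pack_drawing(strokes, key_id=0, country_code=b"US", recognized=1, timestamp=0):
    """Pack a list of (xs, ys) strokes into one QuickDraw-style binary record.

    Hypothetical helper: coordinates must be integers in 0-255, as in the
    simplified QuickDraw data. Little-endian byte order is assumed.
    """
    # 17-byte header: key_id (Q), country code (2s), recognized (b),
    # timestamp (I), stroke count (H); '<' means no padding between fields
    out = bytearray(
        struct.pack("<Q2sbIH", key_id, country_code, recognized, timestamp, len(strokes))
    )
    for xs, ys in strokes:
        if len(xs) != len(ys):
            raise ValueError("each stroke needs the same number of x and y values")
        n_points = len(xs)
        # Per stroke: point count, then all x bytes, then all y bytes
        out += struct.pack("<H", n_points)
        out += struct.pack(f"<{n_points}B", *xs)
        out += struct.pack(f"<{n_points}B", *ys)
    return bytes(out)
```

For example, `binToPIL(pack_drawing(strokes))` would render such a stroke list with the renderer below. Because that parser reads counts in native byte order, this round trip assumes a little-endian machine.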

You can generate the required PIL images from the QuickDraw binary (.bin) format with the following code:

```python
import io
from struct import unpack

import cv2
import numpy as np
from PIL import Image


# packed bytes -> dict (from https://github.com/googlecreativelab/quickdraw-dataset/blob/master/examples/binary_file_parser.py)
def unpack_drawing(file_handle):
    # 17-byte header: key_id, country code, recognized flag, timestamp, stroke count
    key_id, = unpack('Q', file_handle.read(8))
    country_code, = unpack('2s', file_handle.read(2))
    recognized, = unpack('b', file_handle.read(1))
    timestamp, = unpack('I', file_handle.read(4))
    n_strokes, = unpack('H', file_handle.read(2))
    image = []
    for _ in range(n_strokes):
        # Each stroke: point count, then all x coordinates, then all y coordinates (one byte each)
        n_points, = unpack('H', file_handle.read(2))
        fmt = str(n_points) + 'B'
        x = unpack(fmt, file_handle.read(n_points))
        y = unpack(fmt, file_handle.read(n_points))
        image.append((x, y))
    return {
        'key_id': key_id,
        'country_code': country_code,
        'recognized': recognized,
        'timestamp': timestamp,
        'image': image,
    }


# packed bin -> RGB PIL
def binToPIL(packed_drawing):
    padding = 8
    radius = 7
    # Map the 0-255 QuickDraw coordinate range onto the 224x224 canvas minus padding
    scale = (224.0 - (2 * padding)) / 256

    unpacked = unpack_drawing(io.BytesIO(packed_drawing))
    image = np.full((224, 224), 255, np.uint8)  # white grayscale canvas
    for stroke in unpacked['image']:
        prevX = round(stroke[0][0] * scale)
        prevY = round(stroke[1][0] * scale)
        for i in range(1, len(stroke[0])):
            x = round(stroke[0][i] * scale)
            y = round(stroke[1][i] * scale)
            # Black segment between consecutive points; radius is cv2.line's thickness
            # and -1 its lineType. Keep both unchanged so inputs match the training images.
            cv2.line(image, (padding + prevX, padding + prevY), (padding + x, padding + y), 0, radius, -1)
            prevX = x
            prevY = y
    # Replicate the grayscale channel to get the 3-channel input the model expects
    pilImage = Image.fromarray(image).convert("RGB")
    return pilImage
```
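`binToPIL` takes the bytes of a single drawing, while a QuickDraw .bin file stores many records back to back. A minimal sketch for splitting a file's contents into per-drawing byte strings follows; `iter_packed_drawings` is a hypothetical helper, and it assumes little-endian counts, which is what the native-order `unpack` calls above read on x86 machines.

```python
import struct


def iter_packed_drawings(data):
    """Yield each packed drawing record from the raw bytes of a QuickDraw .bin file."""
    offset = 0
    while offset < len(data):
        start = offset
        # Skip key_id (8) + country code (2) + recognized (1) + timestamp (4) = 15 bytes
        offset += 15
        n_strokes, = struct.unpack_from("<H", data, offset)
        offset += 2
        for _ in range(n_strokes):
            n_points, = struct.unpack_from("<H", data, offset)
            # Point count (2 bytes), then n_points x bytes and n_points y bytes
            offset += 2 + 2 * n_points
        if offset > len(data):
            raise ValueError("truncated drawing record starting at byte %d" % start)
        yield data[start:offset]
```

Each yielded record can then be passed to `binToPIL`, for example by iterating over `iter_packed_drawings(f.read())` for a file opened in binary mode.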

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-05
  • train_batch_size: 64
  • eval_batch_size: 64
  • seed: 42
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 256
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_ratio: 0.1
  • num_epochs: 3
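The learning-rate schedule implied by these settings can be worked out from the step counts in the results table below: 3 epochs of 12606 optimizer steps give 37818 steps, and a warmup ratio of 0.1 gives about 3782 warmup steps. The sketch below recomputes this in plain Python. It is written to mirror the linear warmup/decay of transformers' `get_linear_schedule_with_warmup` and the round-up in `TrainingArguments.get_warmup_steps`, but it is an illustration, not the library code.

```python
import math


def warmup_steps(num_training_steps, warmup_ratio):
    # Warmup length derived from the ratio, rounded up
    return math.ceil(num_training_steps * warmup_ratio)


def linear_lr(step, base_lr, num_warmup_steps, num_training_steps):
    # Linear warmup from 0 to base_lr, then linear decay back to 0 at the last step
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining)


# Values taken from the hyperparameter list and the results table
BASE_LR = 5e-05
STEPS_PER_EPOCH = 12606
NUM_EPOCHS = 3
TOTAL_STEPS = STEPS_PER_EPOCH * NUM_EPOCHS     # last step in the results table
WARMUP = warmup_steps(TOTAL_STEPS, 0.1)
# Roughly how many sketches one epoch covers at 256 sketches per optimizer step
# (64 per batch x 4 accumulation steps, assuming a single device)
SKETCHES_PER_EPOCH = STEPS_PER_EPOCH * 64 * 4
```

Evaluating `linear_lr` at `WARMUP` returns the full `5e-05`, and at `TOTAL_STEPS` it returns 0.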

Training results

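| Training Loss | Epoch | Step  | Validation Loss | Accuracy |
|:-------------:|:-----:|:-----:|:---------------:|:--------:|
| 0.939         | 1.0   | 12606 | 0.7853          | 0.8275   |
| 0.7312        | 2.0   | 25212 | 0.7587          | 0.8027   |
| 0.6174        | 3.0   | 37818 | 0.7372          | 0.8098   |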

Framework versions

  • Transformers 4.25.1
  • Pytorch 1.13.1+cu117
  • Datasets 2.7.1
  • Tokenizers 0.13.2