chainyo commited on
Commit
56db83f
·
1 Parent(s): 5d5f276

add config

Browse files
Files changed (3) hide show
  1. README.md +42 -0
  2. finetuning.ipynb +60 -0
  3. training_loop.py +16 -3
README.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - vision
5
+ - image-segmentation
6
+ datasets:
7
+ - segments/sidewalk-semantic
8
+ ---
9
+
10
+ # SegFormer (b0-sized) model fine-tuned on sidewalk-semantic dataset
11
+
12
+ SegFormer model fine-tuned on segments/sidewalk-semantic at resolution 512x512. It was introduced in the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Xie et al. and first released in [this repository](https://github.com/NVlabs/SegFormer).
13
+
14
+ ## Model description
15
+
16
+ SegFormer consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head to achieve great results on semantic segmentation benchmarks such as ADE20K and Cityscapes. The hierarchical Transformer is first pre-trained on ImageNet-1k, after which a decode head is added and fine-tuned altogether on a downstream dataset.
17
+
18
+ ## Intended uses & limitations
19
+
20
+ You can use the raw model for semantic segmentation. See the [model hub](https://huggingface.co/models?other=segformer) to look for fine-tuned versions on a task that interests you.
21
+
22
+ ### How to use
23
+
24
+ Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
25
+
26
+ ```python
27
+ from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
28
+ from PIL import Image
29
+ import requests
30
+
31
+ feature_extractor = SegformerFeatureExtractor(reduce_labels=True)
32
+ model = SegformerForSemanticSegmentation.from_pretrained("ChainYo/segformer-sidewalk")
33
+
34
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
35
+ image = Image.open(requests.get(url, stream=True).raw)
36
+
37
+ inputs = feature_extractor(images=image, return_tensors="pt")
38
+ outputs = model(**inputs)
39
+ logits = outputs.logits # shape (batch_size, num_labels, height/4, width/4)
40
+ ```
41
+
42
+ For more code examples, we refer to the [documentation](https://huggingface.co/transformers/model_doc/segformer.html#).
finetuning.ipynb CHANGED
@@ -938,6 +938,66 @@
938
  "batch = next(iter(train_dataloader))"
939
  ]
940
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
941
  {
942
  "cell_type": "code",
943
  "execution_count": null,
 
938
  "batch = next(iter(train_dataloader))"
939
  ]
940
  },
941
+ {
942
+ "cell_type": "code",
943
+ "execution_count": 31,
944
+ "metadata": {},
945
+ "outputs": [
946
+ {
947
+ "name": "stderr",
948
+ "output_type": "stream",
949
+ "text": [
950
+ "Cloning https://huggingface.co/ChainYo/segformer-sidewalk into local empty directory.\n",
951
+ "remote: Enforcing permissions... \n",
952
+ "remote: Allowed refs: all \n",
953
+ "To https://huggingface.co/ChainYo/segformer-sidewalk\n",
954
+ " c75c928..5d5f276 main -> main\n",
955
+ "\n"
956
+ ]
957
+ },
958
+ {
959
+ "ename": "OSError",
960
+ "evalue": "It looks like the config file at '/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt' is not a valid JSON file.",
961
+ "output_type": "error",
962
+ "traceback": [
963
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
964
+ "\u001b[0;31mUnicodeDecodeError\u001b[0m Traceback (most recent call last)",
965
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:650\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=647'>648</a>\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=648'>649</a>\u001b[0m \u001b[39m# Load config dict\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=649'>650</a>\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=650'>651</a>\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n",
966
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:733\u001b[0m, in \u001b[0;36mPretrainedConfig._dict_from_json_file\u001b[0;34m(cls, json_file)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=731'>732</a>\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(json_file, \u001b[39m\"\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m, encoding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m reader:\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=732'>733</a>\u001b[0m text \u001b[39m=\u001b[39m reader\u001b[39m.\u001b[39;49mread()\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=733'>734</a>\u001b[0m \u001b[39mreturn\u001b[39;00m json\u001b[39m.\u001b[39mloads(text)\n",
967
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/codecs.py:322\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=320'>321</a>\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuffer \u001b[39m+\u001b[39m \u001b[39minput\u001b[39m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=321'>322</a>\u001b[0m (result, consumed) \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_buffer_decode(data, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49merrors, final)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/codecs.py?line=322'>323</a>\u001b[0m \u001b[39m# keep undecoded input until the next call\u001b[39;00m\n",
968
+ "\u001b[0;31mUnicodeDecodeError\u001b[0m: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte",
969
+ "\nDuring handling of the above exception, another exception occurred:\n",
970
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
971
+ "\u001b[1;32m/home/chainyo/code/segformer-sidewalk/finetuning.ipynb Cell 23'\u001b[0m in \u001b[0;36m<cell line: 11>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=7'>8</a>\u001b[0m config \u001b[39m=\u001b[39m AutoConfig\u001b[39m.\u001b[39mfrom_pretrained(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mnvidia/mit-b0\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=8'>9</a>\u001b[0m config\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39msegformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=10'>11</a>\u001b[0m model \u001b[39m=\u001b[39m SegformerForSemanticSegmentation\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=11'>12</a>\u001b[0m \u001b[39m\"\u001b[39;49m\u001b[39m/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\u001b[39;49m\u001b[39m\"\u001b[39;49m, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=12'>13</a>\u001b[0m num_labels\u001b[39m=\u001b[39;49mnum_labels, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=13'>14</a>\u001b[0m id2label\u001b[39m=\u001b[39;49mid2label, \n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=14'>15</a>\u001b[0m label2id\u001b[39m=\u001b[39;49mid2label,\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=15'>16</a>\u001b[0m )\n\u001b[1;32m <a href='vscode-notebook-cell:/home/chainyo/code/segformer-sidewalk/finetuning.ipynb#ch0000016?line=16'>17</a>\u001b[0m model\u001b[39m.\u001b[39mpush_to_hub(\u001b[39m\"\u001b[39m\u001b[39msegformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m, repo_url\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mhttps://huggingface.co/ChainYo/segformer-sidewalk\u001b[39m\u001b[39m\"\u001b[39m)\n",
972
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py:1764\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1761'>1762</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1762'>1763</a>\u001b[0m config_path \u001b[39m=\u001b[39m config \u001b[39mif\u001b[39;00m config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1763'>1764</a>\u001b[0m config, model_kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mconfig_class\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1764'>1765</a>\u001b[0m config_path,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1765'>1766</a>\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1766'>1767</a>\u001b[0m return_unused_kwargs\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1767'>1768</a>\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1768'>1769</a>\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1769'>1770</a>\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1770'>1771</a>\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1771'>1772</a>\u001b[0m use_auth_token\u001b[39m=\u001b[39;49muse_auth_token,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1772'>1773</a>\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1773'>1774</a>\u001b[0m _from_auto\u001b[39m=\u001b[39;49mfrom_auto_class,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1774'>1775</a>\u001b[0m _from_pipeline\u001b[39m=\u001b[39;49mfrom_pipeline,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1775'>1776</a>\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1776'>1777</a>\u001b[0m )\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1777'>1778</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/modeling_utils.py?line=1778'>1779</a>\u001b[0m model_kwargs \u001b[39m=\u001b[39m kwargs\n",
973
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:526\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=451'>452</a>\u001b[0m \u001b[39m@classmethod\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=452'>453</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mfrom_pretrained\u001b[39m(\u001b[39mcls\u001b[39m, pretrained_model_name_or_path: Union[\u001b[39mstr\u001b[39m, os\u001b[39m.\u001b[39mPathLike], \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mPretrainedConfig\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=453'>454</a>\u001b[0m \u001b[39mr\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=454'>455</a>\u001b[0m \u001b[39m Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=455'>456</a>\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=523'>524</a>\u001b[0m \u001b[39m assert unused_kwargs == {\"foo\": False}\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=524'>525</a>\u001b[0m \u001b[39m ```\"\"\"\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=525'>526</a>\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mget_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=526'>527</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mcls\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mand\u001b[39;00m config_dict[\u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m!=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=527'>528</a>\u001b[0m logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=528'>529</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are using a model of type \u001b[39m\u001b[39m{\u001b[39;00mconfig_dict[\u001b[39m'\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m to instantiate a model of type \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=529'>530</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type\u001b[39m}\u001b[39;00m\u001b[39m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=530'>531</a>\u001b[0m )\n",
974
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:553\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=550'>551</a>\u001b[0m original_kwargs \u001b[39m=\u001b[39m copy\u001b[39m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=551'>552</a>\u001b[0m \u001b[39m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=552'>553</a>\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_get_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=554'>555</a>\u001b[0m \u001b[39m# That config file may point us toward another config file to use.\u001b[39;00m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=555'>556</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mconfiguration_files\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict:\n",
975
+ "File \u001b[0;32m~/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py:652\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=649'>650</a>\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=650'>651</a>\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n\u001b[0;32m--> <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=651'>652</a>\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=652'>653</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIt looks like the config file at \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m is not a valid JSON file.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=653'>654</a>\u001b[0m )\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=655'>656</a>\u001b[0m \u001b[39mif\u001b[39;00m resolved_config_file \u001b[39m==\u001b[39m config_file:\n\u001b[1;32m <a href='file:///home/chainyo/miniconda3/envs/segformer/lib/python3.8/site-packages/transformers/configuration_utils.py?line=656'>657</a>\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mloading configuration file \u001b[39m\u001b[39m{\u001b[39;00mconfig_file\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
976
+ "\u001b[0;31mOSError\u001b[0m: It looks like the config file at '/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt' is not a valid JSON file."
977
+ ]
978
+ }
979
+ ],
980
+ "source": [
981
+ "import json\n",
982
+ "from transformers import AutoConfig\n",
983
+ "\n",
984
+ "id2label_file = json.load(open(\"id2label.json\", \"r\"))\n",
985
+ "id2label = {int(k): v for k, v in id2label_file.items()}\n",
986
+ "num_labels = len(id2label)\n",
987
+ "\n",
988
+ "config = AutoConfig.from_pretrained(f\"nvidia/mit-b0\")\n",
989
+ "config.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")\n",
990
+ "\n",
991
+ "model = SegformerForSemanticSegmentation.from_pretrained(\n",
992
+ " \"/home/chainyo/code/segformer-sidewalk/checkpoints/epoch=44-step=1125.ckpt\", \n",
993
+ " num_labels=num_labels, \n",
994
+ " id2label=id2label, \n",
995
+ " label2id=id2label,\n",
996
+ " config=config,\n",
997
+ ")\n",
998
+ "model.push_to_hub(\".\", repo_url=\"https://huggingface.co/ChainYo/segformer-sidewalk\")"
999
+ ]
1000
+ },
1001
  {
1002
  "cell_type": "code",
1003
  "execution_count": null,
training_loop.py CHANGED
@@ -1,9 +1,9 @@
1
  """
2
  Minimal command:
3
- python training_loop.py --hub_dir "segments/sidewalk-semantic"
4
 
5
  Maximal command:
6
- python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train
7
  """
8
 
9
  import json
@@ -12,6 +12,8 @@ import torch
12
  from pytorch_lightning import Trainer, callbacks, seed_everything
13
  from pytorch_lightning.loggers import WandbLogger
14
 
 
 
15
  from dataloader import SidewalkSegmentationDataLoader
16
  from model import SidewalkSegmentationModel
17
 
@@ -23,6 +25,7 @@ def main(
23
  model_flavor: int = 0,
24
  seed: int = 42,
25
  split: str = "train",
 
26
  ):
27
  seed_everything(seed)
28
  logger = WandbLogger(project="sidewalk-segmentation")
@@ -69,6 +72,15 @@ def main(
69
  )
70
  trainer.fit(model, data_module)
71
 
 
 
 
 
 
 
 
 
 
72
 
73
  if __name__ == "__main__":
74
  import argparse
@@ -80,6 +92,7 @@ if __name__ == "__main__":
80
  parser.add_argument("--model_flavor", type=int, default=0)
81
  parser.add_argument("--seed", type=int, default=42)
82
  parser.add_argument("--split", type=str, default="train")
 
83
  args = parser.parse_args()
84
 
85
  main(
@@ -89,5 +102,5 @@ if __name__ == "__main__":
89
  model_flavor=args.model_flavor,
90
  seed=args.seed,
91
  split=args.split,
 
92
  )
93
-
 
1
  """
2
  Minimal command:
3
+ python training_loop.py --hub_dir "segments/sidewalk-semantic" --push_to_hub
4
 
5
  Maximal command:
6
+ python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train --push_to_hub
7
  """
8
 
9
  import json
 
12
  from pytorch_lightning import Trainer, callbacks, seed_everything
13
  from pytorch_lightning.loggers import WandbLogger
14
 
15
+ from transformers import AutoConfig, SegformerForSemanticSegmentation, SegformerFeatureExtractor
16
+
17
  from dataloader import SidewalkSegmentationDataLoader
18
  from model import SidewalkSegmentationModel
19
 
 
25
  model_flavor: int = 0,
26
  seed: int = 42,
27
  split: str = "train",
28
+ push_to_hub: bool = False,
29
  ):
30
  seed_everything(seed)
31
  logger = WandbLogger(project="sidewalk-segmentation")
 
72
  )
73
  trainer.fit(model, data_module)
74
 
75
+ if push_to_hub:
76
+ config = AutoConfig.from_pretrained(f"nvidia/mit-b{model_flavor}")
77
+ config.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
78
+ checkpoint_path = checkpoint_callback.best_model_filepath
79
+ model = SegformerForSemanticSegmentation.from_pretrained(
80
+ checkpoint_path, num_labels=num_labels, id2label=id2label, label2id=id2label, config=config,
81
+ )
82
+ model.push_to_hub("segformer-sidewalk", repo_url="https://huggingface.co/ChainYo/segformer-sidewalk")
83
+
84
 
85
  if __name__ == "__main__":
86
  import argparse
 
92
  parser.add_argument("--model_flavor", type=int, default=0)
93
  parser.add_argument("--seed", type=int, default=42)
94
  parser.add_argument("--split", type=str, default="train")
95
+ parser.add_argument("--push_to_hub", action="store_true")
96
  args = parser.parse_args()
97
 
98
  main(
 
102
  model_flavor=args.model_flavor,
103
  seed=args.seed,
104
  split=args.split,
105
+ push_to_hub=args.push_to_hub,
106
  )