zhangxunhui commited on
Commit
5d1a8a6
1 Parent(s): b07a621

share true

Browse files
Files changed (2) hide show
  1. 01-train_a_model.ipynb +83 -0
  2. app.py +1 -1
01-train_a_model.ipynb ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "# 读取本地的图片和标注信息\n",
10
+ "import pandas as pd\n",
11
+ "train_csv = pd.read_csv('../dataset/train.csv')\n",
12
+ "n_inp = len(set(train_csv['label']))\n",
13
+ "train_csv.head()"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "def label_func(item):\n",
23
+ " rel_path = str(item.relative_to('dataset/train'))\n",
24
+ " return train_csv[train_csv['image_ID']==rel_path][\"label\"].values[0]"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "from fastai.data.all import *\n",
34
+ "\n",
35
+ "dataloader = DataBlock(\n",
36
+ " blocks=(ImageBlock, CategoryBlock),\n",
37
+ " get_items=get_image_files,\n",
38
+ " get_y=label_func,\n",
39
+ " splitter=RandomSplitter(valid_pct=0.2, seed=42),\n",
40
+ " item_tfms=Resize(224)\n",
41
+ ").dataloaders('dataset/train')"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "dataloader.show_batch(max_n=6)"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "learn = vision_learner(dataloader, resnet18, metrics=error_rate)\n",
60
+ "learn.fine_tune(3)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "test_csv = pd.read_csv('dataset/test.csv')\n",
70
+ "test_image = PILImage.create('dataset/test/0b84e400d4.jpg')\n",
71
+ "sport,_,probs = learn.predict(test_image)\n",
72
+ "print(f\"This is {sport}\")"
73
+ ]
74
+ }
75
+ ],
76
+ "metadata": {
77
+ "language_info": {
78
+ "name": "python"
79
+ }
80
+ },
81
+ "nbformat": 4,
82
+ "nbformat_minor": 2
83
+ }
app.py CHANGED
@@ -4,4 +4,4 @@ def greet(name):
4
  return "Hello " + name + "!!"
5
 
6
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
4
  return "Hello " + name + "!!"
5
 
6
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+ iface.launch(share=True)