danilommarano commited on
Commit
25a44a2
·
1 Parent(s): 10d26eb

Script for digit recognition model

Browse files
Files changed (1) hide show
  1. model.py +38 -0
model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import tensorflow as tf
3
+ from tensorflow.keras.datasets import mnist
4
+ from tensorflow.keras.models import Sequential
5
+ from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
6
+
7
+ (x_train, y_train), (x_test, y_test) = mnist.load_data()
8
+
9
+ # Normalize the pixel values to range [0, 1]
10
+ x_train = x_train / 255.0
11
+ x_test = x_test / 255.0
12
+
13
+ # Reshape the data to 4D (number of samples, height, width, channels)
14
+ x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
15
+ x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
16
+
17
+ # Create the model
18
+ model = Sequential([
19
+ Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
20
+ MaxPooling2D((2, 2)),
21
+ Conv2D(64, (3, 3), activation='relu'),
22
+ MaxPooling2D((2, 2)),
23
+ Flatten(),
24
+ Dense(128, activation='relu'),
25
+ Dense(10, activation='softmax')
26
+ ])
27
+
28
+ # Compile the model
29
+ model.compile(optimizer='adam',
30
+ loss='sparse_categorical_crossentropy',
31
+ metrics=['accuracy'])
32
+
33
+ # Train the model
34
+ model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.1)
35
+
36
+ # Save the model
37
+ path = Path(Path(__file__).parent, 'model.h5')
38
+ model.save(path)