tezuesh committed on
Commit 85dd38b · verified · 1 Parent(s): f2cde80

Upload folder using huggingface_hub

Files changed (2)
  1. inference.py +19 -71
  2. inference_util.py +87 -0
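
The commit message points to huggingface_hub's folder-upload API. For context, a minimal sketch of how such an upload is typically issued; the repo id, local path, and authentication setup here are placeholders, not taken from this page:

    from huggingface_hub import HfApi

    api = HfApi()  # assumes prior `huggingface-cli login` or an HF_TOKEN env var
    api.upload_folder(
        folder_path=".",                 # local folder holding inference.py, inference_util.py, ...
        repo_id="tezuesh/example-repo",  # placeholder; the actual repo id is not shown in this view
        commit_message="Upload folder using huggingface_hub",
    )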
inference.py CHANGED
@@ -1,87 +1,35 @@
-# inference.py
 import torch
-from torchvision import transforms, datasets
-from PIL import Image
-import json
+from torchvision import transforms
 from pathlib import Path
-from model import MNISTModel
+import json
 import os
 import sys
+from model import MNISTModel
+from inference_util import Inferencer
 
-class Inferencer:
-    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.model, _ = self._load_model()
-        self.input_dir = Path(input_dir)
-        self.output_dir = Path(output_dir)
-        self.transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize((0.1307,), (0.3081,))
-        ])
-
-    def _load_model(self, model_path='best_model.pth'):
-        """Load the trained model."""
-        model = MNISTModel().to(self.device)
-        model.load_state_dict(
-            torch.load(model_path, map_location=self.device, weights_only=True)
-        )
-        model.eval()
-        return model, self.device
-
-    def predict(self, input_tensor: torch.Tensor):
-        """Make prediction on the input tensor."""
-        with torch.no_grad():
-            if input_tensor.dim() == 3:
-                input_tensor = input_tensor.unsqueeze(0)
-
-            input_tensor = input_tensor.to(self.device)
-            output = self.model(input_tensor)
-            probs = torch.softmax(output, dim=1)
-            prediction = output.argmax(1).item()
-            confidence = probs[0][prediction].item()
-            return prediction, confidence
-
-    def process_input(self):
-        """Process all images in input directory."""
-        # Create output directory if it doesn't exist
-        os.makedirs(self.output_dir, exist_ok=True)
-
-        results = []
-        # Process each file in input directory
-        for file_path in sorted(self.input_dir.glob('*.pt')):  # For tensor files
-            try:
-                # Load tensor
-                input_tensor = torch.load(file_path)
-
-                # Get prediction
-                prediction, confidence = self.predict(input_tensor)
-
-                results.append({
-                    "filename": file_path.name,
-                    "prediction": prediction,
-                    "confidence": confidence
-                })
-
-            except Exception as e:
-                print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
-
-        # Save results
-        with open(self.output_dir / 'results.json', 'w') as f:
-            json.dump(results, f, indent=2)
-
-        return results
+
+class InferenceWrapper:
+    def __init__(self, model_path: str, input_dir: str = 'input_data', output_dir: str = 'output_data'):
+        self.model_path = model_path
+        self.inferencer = Inferencer(input_dir, output_dir)
+        # Override the model with our specified model path
+        self.inferencer.model, _ = self.inferencer._load_model(model_path)
+
+    def run_inference(self):
+        """Run inference using the specified model"""
+        return self.inferencer.process_input()
 
 def main():
-    # Accept input/output directories as arguments
    import argparse
    parser = argparse.ArgumentParser()
+    parser.add_argument('--model-path', required=True, help='Path to the model weights')
    parser.add_argument('--input-dir', default='input_data')
    parser.add_argument('--output-dir', default='output_data')
    args = parser.parse_args()
 
-    inferencer = Inferencer(args.input_dir, args.output_dir)
-    results = inferencer.process_input()
-    print(f"Processed {len(results)} inputs")
+    wrapper = InferenceWrapper(args.model_path, args.input_dir, args.output_dir)
+    results = wrapper.run_inference()
+    print(f"Processed {len(results)} inputs using model: {args.model_path}")
 
 if __name__ == "__main__":
     main()
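
After this change, inference.py is a thin CLI wrapper and the weights path must be supplied explicitly. A typical invocation, using the defaults defined above:

    python inference.py --model-path best_model.pth --input-dir input_data --output-dir output_data

Results are written to output_data/results.json by Inferencer.process_input() in the new inference_util.py. Note one quirk of the wrapper design: Inferencer.__init__ still loads the default best_model.pth before the wrapper swaps in the requested weights, so that default file must exist even when --model-path points elsewhere.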
inference_util.py ADDED
@@ -0,0 +1,87 @@
+# inference_util.py
+import torch
+from torchvision import transforms, datasets
+from PIL import Image
+import json
+from pathlib import Path
+from model import MNISTModel
+import os
+import sys
+
+class Inferencer:
+    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model, _ = self._load_model()
+        self.input_dir = Path(input_dir)
+        self.output_dir = Path(output_dir)
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,))
+        ])
+
+    def _load_model(self, model_path='best_model.pth'):
+        """Load the trained model."""
+        model = MNISTModel().to(self.device)
+        model.load_state_dict(
+            torch.load(model_path, map_location=self.device, weights_only=True)
+        )
+        model.eval()
+        return model, self.device
+
+    def predict(self, input_tensor: torch.Tensor):
+        """Make prediction on the input tensor."""
+        with torch.no_grad():
+            if input_tensor.dim() == 3:
+                input_tensor = input_tensor.unsqueeze(0)
+
+            input_tensor = input_tensor.to(self.device)
+            output = self.model(input_tensor)
+            probs = torch.softmax(output, dim=1)
+            prediction = output.argmax(1).item()
+            confidence = probs[0][prediction].item()
+            return prediction, confidence
+
+    def process_input(self):
+        """Process all input tensors in the input directory."""
+        # Create output directory if it doesn't exist
+        os.makedirs(self.output_dir, exist_ok=True)
+
+        results = []
+        # Process each file in input directory
+        for file_path in sorted(self.input_dir.glob('*.pt')):  # For tensor files
+            try:
+                # Load tensor onto the active device
+                input_tensor = torch.load(file_path, map_location=self.device)
+
+                # Get prediction
+                prediction, confidence = self.predict(input_tensor)
+
+                results.append({
+                    "filename": file_path.name,
+                    "prediction": prediction,
+                    "confidence": confidence
+                })
+
+            except Exception as e:
+                print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)
+
+        # Save results
+        with open(self.output_dir / 'results.json', 'w') as f:
+            json.dump(results, f, indent=2)
+
+        return results
+
+def main():
+    # Accept input/output directories as arguments
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input-dir', default='input_data')
+    parser.add_argument('--output-dir', default='output_data')
+    args = parser.parse_args()
+
+    inferencer = Inferencer(args.input_dir, args.output_dir)
+    results = inferencer.process_input()
+    print(f"Processed {len(results)} inputs")
+
+if __name__ == "__main__":
+    main()
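
Inferencer.process_input() only globs *.pt files, so inputs must be pre-serialized tensors rather than raw images. A minimal sketch of preparing one input, reusing the same MNIST normalization constants the class defines; the image filename is hypothetical, any 28x28 grayscale digit works:

    import torch
    from PIL import Image
    from torchvision import transforms

    # Mirror Inferencer.transform: MNIST mean/std normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    img = Image.open('digit.png').convert('L')    # hypothetical 28x28 grayscale digit
    tensor = transform(img)                       # shape (1, 28, 28); predict() adds the batch dim
    torch.save(tensor, 'input_data/sample_0.pt')  # where process_input() will find it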