acecalisto3 committed on
Commit 18caca1
1 Parent(s): ca75e75

Create handler.py

Files changed (1):
handler.py +88 -0
handler.py ADDED
import json
import logging
import datetime

from transformers import AutoModelForCausalLM, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load configuration settings from a separate file (config.json).
# Example configuration file:
# {
#     "model_name": "acecalisto3/InstructiPhi",
#     "max_length": 50,
#     "logging_level": "INFO"
# }
try:
    with open('config.json') as f:
        config = json.load(f)
except FileNotFoundError:
    logger.error("Configuration file 'config.json' not found. Using default settings.")
    config = {
        "model_name": "acecalisto3/InstructiPhi",  # Default model name
        "max_length": 50,                          # Default max length
        "logging_level": "INFO"                    # Default logging level
    }

# Apply the configured logging level. basicConfig() is a no-op once the root
# logger is configured, so set the level on this module's logger directly.
logger.setLevel(config["logging_level"])

# Load the model and tokenizer once at module scope so they are reused
# across invocations instead of being reloaded on every request.
model_name = config["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def handle_request(event, context):
    """Handles incoming requests to the deployed model.

    Args:
        event: The event data from the deployment platform.
        context: The context data from the deployment platform.

    Returns:
        A dictionary containing the response status code and body.
    """
    input_text = None  # Defined up front so the except block can log it safely.
    try:
        # Extract input text from the event.
        input_text = event.get('body')
        if not input_text:
            return {
                'statusCode': 400,
                'body': json.dumps({'error': 'Missing input text'})
            }

        # Input validation: reject overly long inputs.
        if len(input_text) > 1000:  # A reasonable upper limit
            return {
                'statusCode': 400,
                'body': json.dumps({'error': 'Input text is too long'})
            }

        # Tokenize the input text.
        input_ids = tokenizer(input_text, return_tensors="pt").input_ids

        # Generate the response using the model.
        output = model.generate(input_ids, max_length=config["max_length"])

        # Decode the generated token IDs back into text.
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Return a successful response with structured output.
        return {
            'statusCode': 200,
            'body': json.dumps({
                'response': generated_text,
                'model': model_name,  # Include the model name in the output
                'timestamp': datetime.datetime.now().isoformat()
            })
        }

    except Exception as e:
        # Log the error together with the offending input for easier debugging.
        logger.error(f"Error processing request: {e}, input: {input_text}")
        return {
            'statusCode': 500,
            'body': json.dumps({'error': 'Internal server error'})
        }
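
For a quick check outside the deployment platform, the handler can be invoked directly. A minimal smoke-test sketch, assuming an AWS-Lambda-style event dict whose 'body' key carries the raw input text (the commit itself does not pin down the event shape, so adjust this to whatever your platform actually delivers):

# Hypothetical local smoke test; the event shape is an assumption,
# not something specified by the commit.
if __name__ == "__main__":
    sample_event = {'body': 'Write a one-line greeting.'}
    result = handle_request(sample_event, None)
    print(result['statusCode'])
    print(json.loads(result['body']))

Note that because the model and tokenizer load at module scope rather than inside handle_request, the expensive load happens once per process and warm invocations reuse the weights.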