htigenai committed
Commit 5654017
1 Parent(s): dcff809

improved error handling and logging

Files changed (1)
  1. app.py +148 -45
app.py CHANGED
@@ -1,54 +1,157 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
+import logging
+import sys
+import os
+import psutil
+import gc
 
-print("Loading model and tokenizer...")
-
-# Initialize model and tokenizer
-model_id = "htigenai/finetune_test" # your model ID
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.float16,
-    device_map="auto",
-    load_in_8bit=True # Use 8-bit quantization to reduce memory usage
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    handlers=[
+        logging.StreamHandler(sys.stdout)
+    ]
 )
+logger = logging.getLogger(__name__)
+
+def log_system_info():
+    """Log system information for debugging"""
+    logger.info(f"Python version: {sys.version}")
+    logger.info(f"PyTorch version: {torch.__version__}")
+    logger.info(f"CUDA available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        logger.info(f"CUDA version: {torch.version.cuda}")
+        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
+    logger.info(f"CPU count: {psutil.cpu_count()}")
+    logger.info(f"Memory available: {psutil.virtual_memory().available / (1024 * 1024 * 1024):.2f} GB")
+
+def cleanup_memory():
+    """Clean up memory"""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
-def generate_text(prompt):
-    """Generate text based on the input prompt."""
+print("Starting application...")
+log_system_info()
+
+try:
+    print("Loading model and tokenizer...")
+
+    # Initialize model and tokenizer with error handling
+    model_id = "htigenai/finetune_test" # your model ID
+
+    # Configure quantization
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4"
+    )
+
+    # Load tokenizer with error handling
     try:
-        # Tokenize input
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-        # Generate
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=200,
-            temperature=0.7,
-            top_p=0.95,
-            do_sample=True,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            trust_remote_code=True
         )
-
-        # Decode and return
-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        logger.info("Tokenizer loaded successfully")
     except Exception as e:
-        return f"Error during generation: {str(e)}"
-
-# Create Gradio interface
-iface = gr.Interface(
-    fn=generate_text,
-    inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here..."),
-    outputs=gr.Textbox(),
-    title="Text Generation Model",
-    description="Enter a prompt and get AI-generated text",
-    examples=[
-        ["What are your thoughts about cats?"],
-        ["Write a short story about a magical forest"],
-        ["Explain quantum computing to a 5-year-old"],
-    ]
-)
+        logger.error(f"Error loading tokenizer: {str(e)}")
+        raise
+
+    # Load model with error handling
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            quantization_config=quantization_config,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+        logger.info("Model loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading model: {str(e)}")
+        raise
+
+    def generate_text(prompt):
+        """Generate text based on the input prompt."""
+        try:
+            logger.info(f"Generating text for prompt: {prompt[:50]}...")
+
+            # Clean up memory before generation
+            cleanup_memory()
+
+            # Tokenize input
+            inputs = tokenizer(
+                prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=512
+            ).to(model.device)
+
+            # Generate
+            with torch.inference_mode():
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=200,
+                    temperature=0.7,
+                    top_p=0.95,
+                    do_sample=True,
+                    pad_token_id=tokenizer.pad_token_id,
+                    eos_token_id=tokenizer.eos_token_id,
+                    repetition_penalty=1.1
+                )
+
+            # Decode and return
+            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            logger.info("Text generation completed successfully")
+
+            # Clean up memory after generation
+            cleanup_memory()
+
+            return generated_text
+
+        except Exception as e:
+            logger.error(f"Error during generation: {str(e)}")
+            return f"Error during generation: {str(e)}"
+
+    # Create Gradio interface
+    iface = gr.Interface(
+        fn=generate_text,
+        inputs=gr.Textbox(
+            lines=3,
+            placeholder="Enter your prompt here...",
+            label="Input Prompt"
+        ),
+        outputs=gr.Textbox(
+            label="Generated Response",
+            lines=5
+        ),
+        title="Text Generation Model",
+        description="Enter a prompt and get AI-generated text. Please be patient as generation may take a few moments.",
+        examples=[
+            ["What are your thoughts about cats?"],
+            ["Write a short story about a magical forest"],
+            ["Explain quantum computing to a 5-year-old"],
+        ],
+        allow_flagging="never",
+        cache_examples=False,
+    )
+
+    # Launch the interface
+    iface.launch(
+        share=False,
+        debug=True,
+        show_error=True,
+        server_name="0.0.0.0"
+    )
 
-# Launch the interface
-iface.launch()
+except Exception as e:
+    logger.error(f"Application startup failed: {str(e)}")
+    raise
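
Since the interface keeps the same shape (a single gr.Interface wrapping generate_text), one quick way to exercise the new error handling and logging is to call the running app through gradio_client. The snippet below is a minimal sketch, assuming the app was launched locally on Gradio's default port 7860 and exposes the default /predict endpoint of a single Interface; neither the URL nor the endpoint name is pinned down by this commit.

# Minimal smoke test for the updated app (assumptions: local launch on the
# default port 7860, default "/predict" endpoint of a single gr.Interface).
from gradio_client import Client

client = Client("http://localhost:7860")
result = client.predict(
    "What are your thoughts about cats?",  # forwarded to generate_text(prompt)
    api_name="/predict",
)
print(result)  # generated text, or the "Error during generation: ..." string on failure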