wiklif committed
Commit 8ad8716 · 1 Parent(s): f1cb75e

Added accelerate and better error logging

Files changed (2)
  1. app.py +46 -19
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,31 +1,62 @@
+import os
 import spaces
 import gradio as gr
 import transformers
 import torch
-import os
 from huggingface_hub import login
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 model_id = "meta-llama/Meta-Llama-3.1-8B"
 
 @spaces.GPU(duration=60)
 def load_pipeline():
-    # Log in using the token
-    login(token=os.environ.get("MY_API_LLAMA_3_1"))
+    try:
+        # Log in using the token
+        login(token=os.environ.get("MY_API_LLAMA_3_1"))
+        logger.info("Login successful")
 
-    return transformers.pipeline(
-        "text-generation",
-        model=model_id,
-        model_kwargs={"torch_dtype": torch.bfloat16},
-        device_map="auto"
-    )
+        if torch.cuda.is_available():
+            logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
+            device_map = "auto"
+            torch_dtype = torch.bfloat16
+        else:
+            logger.warning("No GPU available, using CPU")
+            device_map = "cpu"
+            torch_dtype = torch.float32
+
+        pipeline = transformers.pipeline(
+            "text-generation",
+            model=model_id,
+            model_kwargs={"torch_dtype": torch_dtype},
+            device_map=device_map
+        )
+        logger.info("Model loaded successfully")
+        return pipeline
+    except Exception as e:
+        logger.error(f"Error loading model: {str(e)}")
+        raise
 
-pipeline = load_pipeline()
+try:
+    pipeline = load_pipeline()
+except Exception as e:
+    logger.error(f"Failed to load pipeline: {str(e)}")
+    pipeline = None
 
 def generate_response(chat, kwargs):
-    output = pipeline(chat, **kwargs)[0]['generated_text']
-    if output.endswith("</s>"):
-        output = output[:-4]
-    return output
+    if pipeline is None:
+        return "Model nie został załadowany poprawnie. Proszę spróbować później."
+
+    try:
+        output = pipeline(chat, **kwargs)[0]['generated_text']
+        if output.endswith("</s>"):
+            output = output[:-4]
+        return output
+    except Exception as e:
+        logger.error(f"Error generating response: {str(e)}")
+        return f"Wystąpił błąd podczas generowania odpowiedzi: {str(e)}"
 
 def function(prompt, history=[]):
     chat = "<s>"
@@ -42,11 +73,7 @@ def function(prompt, history=[]):
         seed=1337
     )
 
-    try:
-        output = generate_response(chat, kwargs)
-        return output
-    except:
-        return ''
+    return generate_response(chat, kwargs)
 
 # Gradio interface
 interface = gr.ChatInterface(
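The reworked load_pipeline() picks device_map and torch_dtype from torch.cuda.is_available(): bfloat16 with device_map="auto" when a GPU is present, float32 on CPU otherwise. Below is a minimal sketch of just that fallback, pulled out of the Space so it can be run locally without downloading the model; the helper name pick_device_and_dtype is illustrative and not part of this commit.

# Hedged sketch of the GPU/CPU fallback added above; pick_device_and_dtype
# is a hypothetical helper, not a function defined in app.py.
import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def pick_device_and_dtype():
    # Same policy as the commit: bfloat16 with an "auto" device map on GPU,
    # float32 on CPU otherwise.
    if torch.cuda.is_available():
        logger.info(f"GPU available: {torch.cuda.get_device_name(0)}")
        return "auto", torch.bfloat16
    logger.warning("No GPU available, using CPU")
    return "cpu", torch.float32

if __name__ == "__main__":
    device_map, torch_dtype = pick_device_and_dtype()
    print(device_map, torch_dtype)

On a ZeroGPU Space the @spaces.GPU-decorated call should normally see a GPU, so the CPU branch mainly guards local runs and failed allocations.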
requirements.txt CHANGED
@@ -4,3 +4,4 @@ numpy<2
 torch
 transformers
 bitsandbytes
+accelerate
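transformers relies on accelerate to honor a device_map argument when building a pipeline, so without this new requirement the device_map="auto" call in app.py would typically fail at model-load time with an error asking for Accelerate. A small, hedged startup check (not part of the commit) that logs whether accelerate is importable before the pipeline is built:

# Hedged sketch, not part of the commit: confirm accelerate is importable
# before app.py calls transformers.pipeline(..., device_map="auto").
import importlib.util
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

if importlib.util.find_spec("accelerate") is None:
    logger.error("accelerate is missing; device_map='auto' will fail to load the model")
else:
    import accelerate
    logger.info(f"accelerate {accelerate.__version__} is available")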