wiklif committed on
Commit 7938810
1 Parent(s): f9490b8

first commit

Files changed (3)
  1. a.py +8 -0
  2. app.py +72 -4
  3. requirements.txt +6 -0
a.py ADDED
@@ -0,0 +1,8 @@
+ from gradio_client import Client
+
+ client = Client("wiklif/my-api")
+ result = client.predict(
+     prompt="Jakie są 3 największe kraje? Pisz po polsku.",
+     api_name="/chat"
+ )
+ print(result)
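
a.py is a quick client-side check: it uses gradio_client to call the Space's /chat route (the endpoint registered for the chat function in app.py) and prints the reply. As a minimal sketch under the same assumptions (the Space "wiklif/my-api" is running and /chat is the exposed route), the call can also be made non-blocking with gradio_client's submit()/Job API:

# Sketch only, not part of the commit: the same /chat call submitted
# asynchronously; job.result() blocks only when the reply is actually needed.
from gradio_client import Client

client = Client("wiklif/my-api")   # assumes the Space is up and public
job = client.submit(
    prompt="Jakie są 3 największe kraje? Pisz po polsku.",
    api_name="/chat"
)
print(job.result())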
app.py CHANGED
@@ -1,7 +1,75 @@
+ import spaces
  import gradio as gr
+ import transformers
+ import torch

- def greet(name):
-     return "Hello " + name + "!!"
+ model_id = "meta-llama/Meta-Llama-3.1-8B"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ @spaces.GPU(duration=60)
+ def load_pipeline():
+     return transformers.pipeline(
+         "text-generation",
+         model=model_id,
+         model_kwargs={"torch_dtype": torch.bfloat16},
+         device_map="auto"
+     )
+
+ pipeline = load_pipeline()
+
+ def generate_response(chat, kwargs):
+     output = pipeline(chat, **kwargs)[0]['generated_text']
+     if output.endswith("</s>"):
+         output = output[:-4]
+     return output
+
+ def function(prompt, history=[]):
+     chat = "<s>"
+     for user_prompt, bot_response in history:
+         chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
+     chat += f"[INST] {prompt} [/INST]"
+
+     kwargs = dict(
+         max_new_tokens=4096,
+         do_sample=True,
+         temperature=0.5,
+         top_p=0.95,
+         repetition_penalty=1.0,
+         seed=1337
+     )
+
+     try:
+         output = generate_response(chat, kwargs)
+         return output
+     except:
+         return ''
+
+ # Gradio interface
+ interface = gr.ChatInterface(
+     fn=function,
+     chatbot=gr.Chatbot(
+         avatar_images=None,
+         container=False,
+         show_copy_button=True,
+         layout='bubble',
+         render_markdown=True,
+         line_breaks=True
+     ),
+     css='h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}',
+     autofocus=True,
+     fill_height=True,
+     analytics_enabled=False,
+     submit_btn='Chat',
+     stop_btn=None,
+     retry_btn=None,
+     undo_btn=None,
+     clear_btn=None
+ )
+
+ # API endpoint
+ def api_predict(prompt):
+     return function(prompt)
+
+ interface.launch(show_api=True, share=True)
+
+ # Add the API endpoint
+ gr.Interface(fn=api_predict, inputs="text", outputs="text").launch(share=True)
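
A few things in the committed app.py are worth flagging. The <s>[INST] ... [/INST] template is the Llama-2/Mistral chat format, while meta-llama/Meta-Llama-3.1-8B is a base (non-instruct) checkpoint, so those tags are treated as ordinary text. More importantly, the transformers text-generation pipeline does not take a `seed` generation kwarg, so the pipeline call most likely raises and the bare `except:` silently turns every reply into an empty string; and since the first launch() normally blocks, the second gr.Interface(...).launch(share=True) at the bottom is unlikely to ever execute. Below is a minimal sketch of a reproducible-generation variant under the assumption that the unsupported `seed` kwarg is dropped and the RNGs are seeded globally instead; it is not the committed code.

# Sketch only (assumption, not the committed code): seed globally with
# transformers.set_seed() instead of passing `seed` into the pipeline call.
import torch
import transformers

transformers.set_seed(1337)  # seeds Python, NumPy and Torch RNGs

pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

gen_kwargs = dict(
    max_new_tokens=4096,
    do_sample=True,
    temperature=0.5,
    top_p=0.95,
    repetition_penalty=1.0,
)

def generate_response(chat, kwargs):
    # strip a trailing end-of-sequence token, as app.py does
    output = pipeline(chat, **kwargs)[0]["generated_text"]
    return output.removesuffix("</s>")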
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ huggingface_hub
+ gradio
+ numpy<2
+ torch
+ transformers
+ bitsandbytes
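
requirements.txt pins numpy<2 and pulls in bitsandbytes, even though the committed app.py never requests quantization; presumably it is there for a later quantized load. A hedged sketch of how that could be wired into the same pipeline via BitsAndBytesConfig follows; this is an assumption, not something this commit does.

# Sketch only (assumption): a 4-bit bitsandbytes load of the same model,
# passed to the pipeline through model_kwargs -> from_pretrained.
import torch
import transformers

bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

pipeline = transformers.pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B",
    model_kwargs={"quantization_config": bnb_config},
    device_map="auto",
)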