DoctorSlimm committed on
Commit 16b4096
1 Parent(s): a8ec507

Update app.py

Files changed (1)
  1. app.py +23 -4
app.py CHANGED
@@ -13,7 +13,7 @@ print(zero.device) # <-- 'cpu' 🤔
 # gpu
 
 @spaces.GPU
-def greet(user):
+def greet(prompts, separator):
     # print(zero.device) # <-- 'cuda:0' 🤗
     from vllm import SamplingParams, LLM
     from transformers.utils import move_cache
@@ -30,18 +30,37 @@ def greet(user):
         max_tokens = int(512 * 2)
     )
     sampling_params = SamplingParams(**sampling_params)
-
-    prompts = [user]
+
+    multi_prompt = False
+    if separator in prompts:
+        multi_prompt = True
+        prompts = prompts.split(separator)
     model_outputs = model.generate(prompts, sampling_params)
     generations = []
     for output in model_outputs:
         for outputs in output.outputs:
             generations.append(outputs.text)
+    if multi_prompt:
+        return generations
     return generations[0]
 
 
 ## make predictions via api ##
 # https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
 
-demo = gr.Interface(fn=greet, inputs=gr.Text(), outputs=gr.Text())
+demo = gr.Interface(
+    fn=greet,
+    inputs=[
+        gr.Text(
+            value='hello sir!<SEP>bonjour madame...',
+            placeholder='hello sir!<SEP>bonjour madame...',
+            label='list of prompts separated by separator'
+        ),
+        gr.Text(
+            value='<SEP>',
+            placeholder='<SEP>',
+            label='separator for your prompts'
+        )],
+    outputs=gr.Text()
+)
 demo.launch(share=True)
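
With this change the endpoint takes two text inputs (the prompt string and the separator) instead of one, so callers using the Gradio Python client from the guide linked above need to pass both. A minimal sketch, assuming a placeholder Space id and the default "/predict" endpoint name (check the Space's "Use via API" panel for the real values):

# Hypothetical sketch: calling the updated two-input endpoint via the Gradio Python client.
# The Space id and api_name below are placeholders, not taken from this commit.
from gradio_client import Client

client = Client("DoctorSlimm/your-space-name")   # placeholder Space id
result = client.predict(
    "hello sir!<SEP>bonjour madame...",          # prompts joined by the separator
    "<SEP>",                                     # separator
    api_name="/predict",                         # assumed default endpoint name
)
print(result)

With the default values above, the separator appears in the prompt string, so greet() splits it into two prompts and returns a list of generations; a prompt without the separator returns a single string.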
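
The multi-prompt return path relies on vLLM producing one RequestOutput per prompt, in input order, each with its own list of completions. A minimal local sketch of that pattern, using a placeholder model id rather than the one this Space actually loads:

# Minimal sketch of the vLLM batch-generation pattern greet() relies on.
# The model id is a placeholder, not necessarily what the Space loads.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")              # placeholder model
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

prompts = ["hello sir!", "bonjour madame..."]
outputs = llm.generate(prompts, sampling_params)  # one RequestOutput per prompt, in order
for output in outputs:
    for completion in output.outputs:             # a single completion unless n > 1
        print(repr(output.prompt), "->", completion.text)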