gosha6037 commited on
Commit
776e43c
·
1 Parent(s): aae4195

Added bigscience/bloom-petals

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -1,21 +1,26 @@
1
  import sys
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
  import torch
4
  import transformers
5
- import gradio as gr
6
 
7
  sys.path.insert(0, './petals/')
8
 
9
- from src.client.remote_model import DistributedBloomForCausalLM
10
-
11
- MODEL_NAME = "bigscience/test-bloomd-6b3" # select model you like
12
- INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
13
 
14
- tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
15
- model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3",
16
- initial_peers=INITIAL_PEERS,
 
 
17
  low_cpu_mem_usage=True, torch_dtype=torch.float32)
18
 
 
 
 
 
 
19
 
20
  tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
21
  model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
@@ -47,8 +52,11 @@ def predict(
47
  model = model_DialoGPT_large
48
  tokenizer = tokenizer_DialoGPT_large
49
  elif model_name == 'test-bloomd-6b3':
50
- model = tokenizer_bloomd_6b3
51
- tokenizer = model_bloomd_6b3
 
 
 
52
  else:
53
  model = model_DialoGPT_medium
54
  tokenizer = tokenizer_DialoGPT_medium
@@ -81,7 +89,8 @@ gr.Interface(
81
  'DialoGPT-small',
82
  'DialoGPT-medium',
83
  'DialoGPT-large',
84
- 'test-bloomd-6b3'
 
85
  ]
86
  ),
87
  gr.Radio(
 
1
  import sys
2
+
3
+ import gradio as gr
4
  import torch
5
  import transformers
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
  sys.path.insert(0, './petals/')
9
 
10
+ from petals.client.remote_model import DistributedBloomForCausalLM
 
 
 
11
 
12
+ MODEL_NAME = "bigscience/test-bloomd-6b3"
13
+ # INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
14
+ tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
15
+ model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME,
16
+ # initial_peers=INITIAL_PEERS,
17
  low_cpu_mem_usage=True, torch_dtype=torch.float32)
18
 
19
+ MODEL_NAME = "bigscience/bloom-petals"
20
+ tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
21
+ model_bloomd = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME,
22
+ low_cpu_mem_usage=True, torch_dtype=torch.float32)
23
+
24
 
25
  tokenizer_DialoGPT_small = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
26
  model_DialoGPT_small = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
 
52
  model = model_DialoGPT_large
53
  tokenizer = tokenizer_DialoGPT_large
54
  elif model_name == 'test-bloomd-6b3':
55
+ model = model_bloomd_6b3
56
+ tokenizer = tokenizer_bloomd_6b3
57
+ elif model_name == 'bloom-petals':
58
+ model = model_bloomd
59
+ tokenizer = tokenizer_bloomd
60
  else:
61
  model = model_DialoGPT_medium
62
  tokenizer = tokenizer_DialoGPT_medium
 
89
  'DialoGPT-small',
90
  'DialoGPT-medium',
91
  'DialoGPT-large',
92
+ 'test-bloomd-6b3',
93
+ 'bloom-petals',
94
  ]
95
  ),
96
  gr.Radio(