bofenghuang committed
Commit b107fa5
1 parent: 62deaa4
Files changed (2):
  1. README.md +3 -3
  2. app.py +14 -20
README.md CHANGED
```diff
@@ -1,8 +1,8 @@
 ---
-title: Vigogne-Chat
+title: Vigogne Chat
 emoji: 🦙
-colorFrom: purple
-colorTo: yellow
+colorFrom: yellow
+colorTo: purple
 sdk: gradio
 sdk_version: 3.27.0
 app_file: app.py
```
app.py CHANGED
```diff
@@ -13,6 +13,8 @@ python vigogne/demo/demo_chat.py \
     --lora_model_name_or_path bofenghuang/vigogne-chat-7b
 """
 
+import json
+
 # import datetime
 import logging
 import os
```
```diff
@@ -20,10 +22,6 @@ import re
 from threading import Event, Thread
 from typing import List, Optional
 
-
-# from uuid import uuid4
-
-import json
 import gradio as gr
 
 # import requests
```
```diff
@@ -33,13 +31,16 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
+    LlamaTokenizer,
     StoppingCriteriaList,
     TextIteratorStreamer,
 )
-
 from vigogne.constants import ASSISTANT, USER
-from vigogne.preprocess import generate_inference_chat_prompt
 from vigogne.inference.inference_utils import StopWordsCriteria
+from vigogne.preprocess import generate_inference_chat_prompt
+
+# from uuid import uuid4
+
 
 logging.basicConfig(
     format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
```
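The imports this hunk reorders (TextIteratorStreamer, StoppingCriteriaList with the repo's StopWordsCriteria, plus the Thread and Event imports above) are the usual ingredients for streamed generation. The app's generate loop is not part of this diff, so the following is only a minimal sketch of how these pieces typically fit together, assuming a transformers version that ships TextIteratorStreamer; `stream_reply` and its arguments are illustrative names, not the app's own.

```python
# Sketch only: model.generate() runs in a background Thread and pushes
# decoded text into the streamer; the caller iterates chunks as they
# arrive and yields the growing reply, which is what lets a Gradio
# chatbot update while the model is still generating.
from threading import Thread

from transformers import TextIteratorStreamer


def stream_reply(model, tokenizer, prompt, **gen_kwargs):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, **gen_kwargs})
    thread.start()
    partial = ""
    for chunk in streamer:  # blocks until the next decoded chunk is ready
        partial += chunk
        yield partial
    thread.join()
```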
```diff
@@ -97,7 +98,9 @@ def main(
     server_port: Optional[str] = None,
     share: bool = False,
 ):
-    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)
+    # tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)
+    tokenizer_class = LlamaTokenizer if "llama" in base_model_name_or_path else AutoTokenizer
+    tokenizer = tokenizer_class.from_pretrained(base_model_name_or_path, padding_side="right", use_fast=False)
 
     if device == "cuda":
         model = AutoModelForCausalLM.from_pretrained(
```
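A likely reason for swapping the plain AutoTokenizer call: some community LLaMA conversions, decapoda-research/llama-7b-hf among them, shipped a tokenizer_config.json whose tokenizer_class spelling no longer matched the class registered in transformers, so the Auto lookup could raise before loading anything. That rationale is an inference from the hunk, not stated in the commit; the helper below just isolates the same selection logic under a hypothetical name.

```python
# Hypothetical helper mirroring the hunk's workaround: pick LlamaTokenizer
# explicitly for LLaMA checkpoints instead of trusting the checkpoint's
# (possibly stale) tokenizer_class entry, and fall back to AutoTokenizer
# for everything else.
from transformers import AutoTokenizer, LlamaTokenizer


def load_chat_tokenizer(base_model_name_or_path: str):
    tokenizer_class = LlamaTokenizer if "llama" in base_model_name_or_path else AutoTokenizer
    return tokenizer_class.from_pretrained(
        base_model_name_or_path,
        padding_side="right",  # same settings as the diff
        use_fast=False,
    )
```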
```diff
@@ -302,12 +305,7 @@ def main(
             elem_classes=["disclaimer"],
         )
 
-        submit_event = msg.submit(
-            fn=user,
-            inputs=[msg, chatbot],
-            outputs=[msg, chatbot],
-            queue=False,
-        ).then(
+        submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False,).then(
             fn=bot,
             inputs=[
                 chatbot,
```
```diff
@@ -321,12 +319,7 @@ def main(
             outputs=chatbot,
             queue=True,
         )
-        submit_click_event = submit.click(
-            fn=user,
-            inputs=[msg, chatbot],
-            outputs=[msg, chatbot],
-            queue=False,
-        ).then(
+        submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False,).then(
             fn=bot,
             inputs=[
                 chatbot,
```
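Both event hunks collapse a multi-line .submit(...)/.click(...) call onto one line without changing behavior: a fast, unqueued `user` step echoes the message into the chat history, then a queued `bot` step fills in the reply. Here is a self-contained miniature of that two-step Gradio 3.x chain; the `user` and `bot` bodies are simplified stand-ins for the app's own functions, which also wire sampling parameters into `bot`.

```python
# Runnable miniature of the msg.submit(...).then(...) pattern, assuming Gradio 3.x.
import gradio as gr


def user(message, history):
    # Clear the textbox and append the turn with an empty assistant slot.
    return "", history + [[message, None]]


def bot(history):
    # Stand-in reply; the real app streams model output into this slot.
    history[-1][1] = f"You said: {history[-1][0]}"
    return history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot, queue=True
    )

demo.queue(max_size=128, concurrency_count=2)
demo.launch()
```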
```diff
@@ -352,4 +345,5 @@ def main(
     demo.queue(max_size=128, concurrency_count=2)
     demo.launch(enable_queue=True, share=share, server_name=server_name, server_port=server_port)
 
-    main()
+
+main(base_model_name_or_path="decapoda-research/llama-7b-hf", lora_model_name_or_path="bofenghuang/vigogne-chat-7b")
```