SAMBOOM committed on
Commit 581b122
1 Parent(s): 7e751ec

Update app.py

Files changed (1)
  app.py +59 -56
app.py CHANGED
@@ -1,59 +1,62 @@
- import os
- import sys
- import torch
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- from PIL import Image
-
- sys.path.append(".")
- os.environ["HUGGINGFACE_HTTPS_PROXY"] = ""  # Disable HTTPS proxy when not required
-
- tokenizer = AutoTokenizer.from_pretrained("./model_dir")
- model = AutoModelForSeq2SeqLM.from_pretrained("./model_dir", device_map="auto").half()
- device = next(iter(model.parameters())).device
-
- def generate_response(input_text):
-     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
-     outputs = model.generate(
-         input_ids,
-         max_length=512,
-         num_beams=5,
-         early_stopping=True,
-         pad_token_id=tokenizer.pad_token_id,
-         eos_token_id=tokenizer.eos_token_id,
-         length_penalty=1.0,
-         no_repeat_ngram_size=2,
-         min_length=10,
-         temperature=0.9,
      )
-     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return generated_text
-
- import streamlit as st
-
- st.set_page_config(layout="wide")
- col1, col2 = st.beta_columns((3, 1))
-
- with open('style.css') as f:
-     st.markdown(f'<style>{f.read()}<\style>', unsafe_allow_html=True)
-
- with col1:
-     user_prompt = st.text_area("You:", "", height=50)
-
- if 'generated' not in st.session_state:
-     st.session_state['generated'] = ''
-
- if len(user_prompt) > 0 and st.button('Send'):
-     response = generate_response(user_prompt)
-     st.write('<span style="font-weight:bold;">Assistant:</span>
- ' + response, unsafe_allow_html=True)
-     st.session_state['generated'] += '\n\n<span style="font-weight:bold;">User:</span>' + '\n' + user_prompt + '\n'
-     st.session_state['generated'] += '<span style="font-weight:bold;">Assistant:</span>\n' + response
-
- if 'generated' in st.session_state:
-     message = st.session_state['generated'].replace('\n', '
- ').replace('<span style="font-weight: bold;">User:</span>', '&uarr;').replace('<span style="font-weight: bold;">Assistant:</span>', '')
-     st.markdown(message, unsafe_allow_html=True)
-
- # Upload logo
- logo = Image.open("your_logo.png")
- st.sidebar.image(logo, width=160)
+ import streamlit as st
+ from huggingface_hub import InferenceClient
+
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+
+ def format_prompt(message, history):
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
+     temperature = float(temperature)
+     if temperature < 1e-2:
+         temperature = 1e-2
+     top_p = float(top_p)
+
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
      )
+
+     formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+
+     for response in stream:
+         output += response.token.text
+         yield output
+     return output
+
+ # Create text input for user message
+ message_input = st.text_input("You:", "")
+
+ # Create text input for system prompt
+ system_prompt_input = st.text_input("System Prompt:", "You are a helpful assistant.")
+
+ # Create sliders for temperature, max new tokens, top-p, and repetition penalty
+ temperature_slider = st.slider("Temperature", 0.0, 1.0, 0.9)
+ max_new_tokens_slider = st.slider("Max new tokens", 0, 1048, 256)
+ top_p_slider = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95)
+ repetition_penalty_slider = st.slider("Repetition penalty", 1.0, 2.0, 1.0)
+
+ # Create button to generate response
+ if st.button("Generate"):
+     # Create empty list to store conversation history
+     history = []
+
+     # Call generate function with user message, system prompt, and slider values
+     output = generate(message_input, history, system_prompt_input, temperature=temperature_slider, max_new_tokens=max_new_tokens_slider, top_p=top_p_slider, repetition_penalty=repetition_penalty_slider)
+
+     # Display generated response
+     st.write("Assistant:", output)
+
+     # Add user message and generated response to conversation history
+     history.append((message_input, output))
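
Two notes on the new app.py. First, format_prompt builds the Mixtral instruct template: with an empty history it returns <s>[INST] {system_prompt}, {prompt} [/INST], and each prior turn is appended as [INST] user [/INST] response</s>. Second, because generate contains a yield, calling it returns a generator, so depending on the Streamlit version st.write("Assistant:", output) may display the generator object rather than the streamed text. A minimal sketch of consuming the stream explicitly, assuming the generate function and widgets defined above are in scope (st.empty is a standard Streamlit placeholder that can be rewritten in place):

    placeholder = st.empty()   # reserve an output slot to overwrite as tokens arrive
    final_text = ""
    # generate() yields the accumulated output after each streamed token
    for partial in generate(message_input, [], system_prompt_input,
                            temperature=temperature_slider,
                            max_new_tokens=max_new_tokens_slider,
                            top_p=top_p_slider,
                            repetition_penalty=repetition_penalty_slider):
        final_text = partial
        placeholder.markdown("Assistant: " + final_text)

Note also that history is recreated as an empty list on every button press, so the final history.append is discarded when Streamlit reruns the script; persisting turns across reruns would require st.session_state, as the previous version did.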