SAMBOOM committed
Commit ff9d849
1 Parent(s): 921e88c

Update app.py

Files changed (1)
app.py +57 -44
app.py CHANGED
@@ -1,46 +1,59 @@
- import transformers
- import streamlit as st
- from transformers import AutoTokenizer, AutoModelWithLMHead
-
- tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
- @st.cache_data
- def load_model(model_name):
-     model = AutoModelWithLMHead.from_pretrained(model_name)
-     return model
-
- model = load_model("gpt2-large")
-
- def infer(sent, max_length, temperature, top_k, top_p):
-     input_ids = tokenizer.encode(sent, return_tensors="pt")
-     output_sequences = model.generate(
-         input_ids=input_ids,
-         max_length=max_length,
-         temperature=temperature,
-         top_k=top_k,
-         top_p=top_p,
-         do_sample=True,
-         num_return_sequences=1
-     )
-
-     return output_sequences
-
- default_value = "You: Ask me anything!"
-
- # prompts
- st.title("Chat with GPT-2 💬")
- st.write("GPT-2 is a large transformer-based language model with 1.5 billion parameters. It is trained to predict the next word in a sentence, given all of the previous words. This makes it great for text generation and for answering questions about the text it's given.")
-
- messages = [{"role": "system", "content": "You are a helpful assistant."}]
-
- user_input = st.text_input("You:", default_value)
- if user_input:
-     messages.append({"role": "user", "content": user_input})
-
-     output_sequences = infer(user_input, max_length=100, temperature=0.7, top_k=40, top_p=0.9)
-     generated_sequence = output_sequences[0].tolist()
-     generated_text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-
-     messages.append({"role": "assistant", "content": generated_text})
-
- for message in messages:
-     st.write(f"{message['role']}: {message['content']}")
+ import os
+ import sys
+ import torch
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from PIL import Image
+
+ sys.path.append(".")
+ os.environ["HUGGINGFACE_HTTPS_PROXY"] = ""  # Disable HTTPS proxy when not required
+
+ tokenizer = AutoTokenizer.from_pretrained("./model_dir")
+ model = AutoModelForSeq2SeqLM.from_pretrained("./model_dir", device_map="auto").half()
+ device = next(model.parameters()).device
+
+ def generate_response(input_text):
+     input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+     outputs = model.generate(
+         input_ids,
+         max_length=512,
+         num_beams=5,
+         early_stopping=True,
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         length_penalty=1.0,
+         no_repeat_ngram_size=2,
+         min_length=10,
+         temperature=0.9,  # has no effect without do_sample=True; kept from the original
+     )
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return generated_text
+
+ import streamlit as st
+
+ st.set_page_config(layout="wide")
+ col1, col2 = st.columns((3, 1))  # st.beta_columns is deprecated in current Streamlit
+
+ with open('style.css') as f:
+     st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
+
+ with col1:
+     user_prompt = st.text_area("You:", "", height=50)
+
+ if 'generated' not in st.session_state:
+     st.session_state['generated'] = ''
+
+ if len(user_prompt) > 0 and st.button('Send'):
+     response = generate_response(user_prompt)
+     st.write('<span style="font-weight:bold;">Assistant:</span>\n' + response, unsafe_allow_html=True)
+     st.session_state['generated'] += '\n\n<span style="font-weight:bold;">User:</span>' + '\n' + user_prompt + '\n'
+     st.session_state['generated'] += '<span style="font-weight:bold;">Assistant:</span>\n' + response
+
+ if 'generated' in st.session_state:
+     message = st.session_state['generated'].replace('\n', '<br>').replace('<span style="font-weight:bold;">User:</span>', '&uarr;').replace('<span style="font-weight:bold;">Assistant:</span>', '')
+     st.markdown(message, unsafe_allow_html=True)
+
+ # Upload logo
+ logo = Image.open("your_logo.png")
+ st.sidebar.image(logo, width=160)
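
For a quick sanity check outside Streamlit, the generation path introduced in this commit can be exercised on its own. The following is a minimal sketch, assuming ./model_dir contains the seq2seq checkpoint the app expects (the commit does not name the underlying model) and using an illustrative prompt; it mirrors the beam-search settings of generate_response but stays in full precision on CPU:

    # Standalone sketch of the generation call from app.py.
    # Assumes ./model_dir holds a seq2seq checkpoint (unnamed in the commit).
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("./model_dir")
    model = AutoModelForSeq2SeqLM.from_pretrained("./model_dir")

    # Encode an illustrative prompt and decode the top beam.
    input_ids = tokenizer.encode("Hello, how are you?", return_tensors="pt")
    outputs = model.generate(
        input_ids,
        max_length=512,
        num_beams=5,
        early_stopping=True,
        no_repeat_ngram_size=2,
        min_length=10,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))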