Ilvir commited on
Commit
d0e1ddb
·
1 Parent(s): 194bdaf

Upload gpt (1).py

Browse files
Files changed (1) hide show
  1. pages/gpt (1).py +73 -0
pages/gpt (1).py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
+ import streamlit as st
3
+ import torch
4
+ import textwrap
5
+ import plotly.express as px
6
+
7
+ from streamlit_extras.let_it_rain import rain
8
+
9
+ rain(
10
+ emoji="⭐",
11
+ font_size=54,
12
+ falling_speed=5,
13
+ animation_length="infinite",
14
+ )
15
+
16
+ st.header(':green[Text generation by GPT2 model]')
17
+
18
+ tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')
19
+ model = GPT2LMHeadModel.from_pretrained(
20
+ 'sberbank-ai/rugpt3small_based_on_gpt2',
21
+ output_attentions = False,
22
+ output_hidden_states = False,
23
+ )
24
+
25
+ model.load_state_dict(torch.load('models/model.pt', map_location=torch.device('cpu')))
26
+
27
+
28
+ length = st.sidebar.slider('**Generated sequence length:**', 8, 256, 15)
29
+ if length > 100:
30
+ st.warning("This is very hard for me, please have pity on me. Could you lower the value?", icon="🤖")
31
+ num_samples = st.sidebar.slider('**Number of generations:**', 1, 10, 1)
32
+ if num_samples > 4:
33
+ st.warning("OH MY ..., I have to work late again!!! Could you lower the value?", icon="🤖")
34
+ temperature = st.sidebar.slider('**Temperature:**', 0.1, 10.0, 3.0)
35
+ if temperature > 6.0:
36
+ st.info('What? You want to get some kind of bullshit as a result? Turn down the temperature', icon="🤖")
37
+ top_k = st.sidebar.slider('**Number of most likely generation words:**', 10, 200, 50)
38
+ top_p = st.sidebar.slider('**Minimum total probability of top words:**', 0.4, 1.0, 0.9)
39
+
40
+
41
+ prompt = st.text_input('**Enter text 👇:**')
42
+ if st.button('**Generate text**'):
43
+ image_container = st.empty()
44
+ image_container.image("pict/wait.jpeg", caption="that's so long!!!", use_column_width=True)
45
+ with torch.inference_mode():
46
+ prompt = tokenizer.encode(prompt, return_tensors='pt')
47
+ out = model.generate(
48
+ input_ids=prompt,
49
+ max_length=length,
50
+ num_beams=8,
51
+ do_sample=True,
52
+ temperature=temperature,
53
+ top_k=top_k,
54
+ top_p=top_p,
55
+ no_repeat_ngram_size=3,
56
+ num_return_sequences=num_samples,
57
+ ).cpu().numpy()
58
+ image_container.empty()
59
+ st.write('**_Результат_** 👇')
60
+ for i, out_ in enumerate(out):
61
+ # audio_file = open('pict/pole-chudes-priz.mp3', 'rb')
62
+ # audio_bytes = audio_file.read()
63
+ # st.audio(audio_bytes, format='audio/mp3')
64
+
65
+ with st.expander(f'Текст {i+1}:'):
66
+ st.write(textwrap.fill(tokenizer.decode(out_), 100))
67
+ st.image("pict/wow.png")
68
+
69
+
70
+
71
+
72
+
73
+