Shabbir-Anjum commited on
Commit
7d671a1
·
verified ·
1 Parent(s): 58490f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -1,28 +1,31 @@
1
  import streamlit as st
2
- from transformers import DiffusionPipeline
3
 
4
- # Load the Diffusion pipeline
5
- pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium")
6
 
7
  def generate_prompt(prompt_text):
8
  # Generate response using the Diffusion model
9
- response = pipeline(prompt=prompt_text, top_p=0.9, num_return_sequences=1)[0]['generated_text']
10
  return response
11
 
12
- # Streamlit app UI
13
- st.title('Diffusion Model Prompt Generator')
14
-
15
- # Input prompt from user
16
- prompt_input = st.text_area('Enter your prompt here:', height=100)
17
-
18
- # Generate button
19
- if st.button('Generate'):
20
- if prompt_input:
21
- with st.spinner('Generating...'):
22
- generated_text = generate_prompt(prompt_input)
23
- st.success('Generation complete!')
24
- st.text_area('Generated Text:', value=generated_text, height=200)
25
- else:
26
- st.warning('Please enter a prompt to generate.')
 
 
 
27
 
28
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
 
4
+ # Load the Diffusion pipeline for text generation
5
+ generator = pipeline("text-generation", model="stabilityai/stable-diffusion-3-medium")
6
 
7
  def generate_prompt(prompt_text):
8
  # Generate response using the Diffusion model
9
+ response = generator(prompt_text, top_p=0.9, max_length=100)[0]['generated_text']
10
  return response
11
 
12
+ def main():
13
+ st.title('Diffusion Model Prompt Generator')
14
+
15
+ # Text input for the prompt
16
+ prompt_text = st.text_area("Enter your prompt here:", height=200)
17
+
18
+ # Button to generate prompt
19
+ if st.button("Generate"):
20
+ if prompt_text:
21
+ with st.spinner('Generating...'):
22
+ generated_text = generate_prompt(prompt_text)
23
+ st.success('Generation complete!')
24
+ st.text_area('Generated Text:', value=generated_text, height=400)
25
+ else:
26
+ st.warning('Please enter a prompt.')
27
+
28
+ if __name__ == '__main__':
29
+ main()
30
 
31