EmTpro01 commited on
Commit
2a8e7a8
β€’
1 Parent(s): aacca72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -29
app.py CHANGED
@@ -1,41 +1,56 @@
1
  import streamlit as st
2
- from transformers import T5Tokenizer, T5ForConditionalGeneration
3
  import torch
 
4
 
5
- # Load the model and tokenizer with CPU optimization
 
 
 
 
 
 
 
6
  @st.cache_resource
7
  def load_model():
8
  model_name = "flax-community/t5-recipe-generation"
9
- tokenizer = T5Tokenizer.from_pretrained(model_name)
10
- model = T5ForConditionalGeneration.from_pretrained(model_name)
11
-
12
- # Explicitly set to CPU and use float32 to reduce memory usage
13
- model = model.to('cpu').float()
14
-
15
- return tokenizer, model
 
 
 
 
16
 
17
- # Generate recipe function with CPU-friendly generation
18
  def generate_recipe(ingredients, tokenizer, model, max_length=512):
19
  # Prepare input
20
  input_text = f"Generate recipe with: {ingredients}"
21
 
22
- # Use torch no_grad to reduce memory consumption
23
- with torch.no_grad():
24
- input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Adjust generation parameters for faster CPU inference
27
- output_ids = model.generate(
28
- input_ids,
29
- max_length=max_length,
30
- num_return_sequences=1,
31
- no_repeat_ngram_size=2,
32
- num_beams=4, # Reduced beam search for faster CPU processing
33
- early_stopping=True
34
- )
35
-
36
- # Decode and clean the output
37
- recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
38
- return recipe
39
 
40
  # Streamlit app
41
  def main():
@@ -57,9 +72,12 @@ def main():
57
  with st.spinner("Generating recipe..."):
58
  recipe = generate_recipe(ingredients_input, tokenizer, model)
59
 
60
- # Display recipe sections
61
- st.subheader("πŸ₯˜ Generated Recipe")
62
- st.write(recipe)
 
 
 
63
  else:
64
  st.warning("Please enter some ingredients!")
65
 
 
1
  import streamlit as st
 
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
+ # Ensure SentencePiece is installed
6
+ try:
7
+ import sentencepiece
8
+ except ImportError:
9
+ st.error("SentencePiece is not installed. Please install it using: pip install sentencepiece")
10
+ st.stop()
11
+
12
+ # Load the model and tokenizer with caching
13
  @st.cache_resource
14
  def load_model():
15
  model_name = "flax-community/t5-recipe-generation"
16
+ try:
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
+
20
+ # Explicitly set to CPU and use float32 to reduce memory usage
21
+ model = model.to('cpu').float()
22
+
23
+ return tokenizer, model
24
+ except Exception as e:
25
+ st.error(f"Error loading model: {e}")
26
+ st.stop()
27
 
28
+ # Generate recipe function with error handling
29
  def generate_recipe(ingredients, tokenizer, model, max_length=512):
30
  # Prepare input
31
  input_text = f"Generate recipe with: {ingredients}"
32
 
33
+ try:
34
+ # Use torch no_grad to reduce memory consumption
35
+ with torch.no_grad():
36
+ input_ids = tokenizer.encode(input_text, return_tensors="pt", max_length=max_length, truncation=True)
37
+
38
+ # Adjust generation parameters for faster CPU inference
39
+ output_ids = model.generate(
40
+ input_ids,
41
+ max_length=max_length,
42
+ num_return_sequences=1,
43
+ no_repeat_ngram_size=2,
44
+ num_beams=4, # Reduced beam search for faster CPU processing
45
+ early_stopping=True
46
+ )
47
 
48
+ # Decode and clean the output
49
+ recipe = tokenizer.decode(output_ids[0], skip_special_tokens=True)
50
+ return recipe
51
+ except Exception as e:
52
+ st.error(f"Error generating recipe: {e}")
53
+ return None
 
 
 
 
 
 
 
54
 
55
  # Streamlit app
56
  def main():
 
72
  with st.spinner("Generating recipe..."):
73
  recipe = generate_recipe(ingredients_input, tokenizer, model)
74
 
75
+ if recipe:
76
+ # Display recipe sections
77
+ st.subheader("πŸ₯˜ Generated Recipe")
78
+ st.write(recipe)
79
+ else:
80
+ st.error("Failed to generate recipe. Please try again.")
81
  else:
82
  st.warning("Please enter some ingredients!")
83