MNGames committed on
Commit
e06d0f8
1 Parent(s): 41efb19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -10,20 +10,22 @@ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
def generate_questions(email):
    """Generate questions based on the input email.

    Args:
        email: Raw email text to condition generation on.

    Returns:
        The generated question string, with special tokens stripped.
    """
    # Tokenize with truncation so the sequence never exceeds the model's
    # maximum input length. The tokenizer already inserts special tokens
    # (add_special_tokens=True), so no manual cls_token_id prepending is
    # needed — the original list concatenation produced an invalid ragged
    # list and desynchronized input_ids from attention_mask.
    inputs = tokenizer(
        email,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        max_length=512,  # assumes the model's max input length is 512 — TODO confirm
    )

    # Generate questions using beam search for better quality (slower).
    generation = model.generate(
        **inputs,  # Unpack input_ids + attention_mask together
        max_length=256,  # Adjust max length as needed
        num_beams=5,
        early_stopping=True,
    )

    # Decode the generated text
    return tokenizer.decode(generation[0], skip_special_tokens=True)
29
 
 
10
 
11
def generate_questions(email):
    """Generate questions based on the input email.

    Args:
        email: Raw email text to condition generation on.

    Returns:
        The generated question string, with special tokens stripped.
    """
    # Encode the email; return_tensors="pt" yields tensors of shape
    # (batch, seq_len) with batch == 1 for a single string.
    inputs = tokenizer(email, return_tensors="pt", add_special_tokens=True)

    # Check the sequence dimension, not len() of the batched tensor:
    # len(inputs["input_ids"]) is the batch size (always 1 here), so a
    # length check written that way could never fire.
    if inputs["input_ids"].shape[1] > 512:  # Adjust maximum sequence length as needed
        print("WARNING: Input sequence exceeds maximum length. Truncating.")
        # Truncate every tokenizer output along the sequence axis so
        # input_ids and attention_mask stay aligned with each other.
        inputs = {k: v[:, :512] for k, v in inputs.items()}

    # Generate questions using model
    generation = model.generate(
        **inputs,  # Unpack the entire inputs dictionary
        max_length=256,  # Adjust max length as needed
    )

    # Decode the generated text
    return tokenizer.decode(generation[0], skip_special_tokens=True)
31