from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from gradio import Interface

# Define the model name (change if desired)
model_name = "facebook/bart-base"

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


def generate_questions(email):
    """Generates questions based on the input email."""
    # Encode the email
    inputs = tokenizer(email, return_tensors="pt")

    # Check the sequence length (dimension 1), not the batch size
    if inputs["input_ids"].shape[1] > 512:  # Adjust maximum sequence length as needed
        print("WARNING: Input sequence exceeds maximum length. Truncating.")
        inputs["input_ids"] = inputs["input_ids"][:, :512]
        inputs["attention_mask"] = inputs["attention_mask"][:, :512]

    # Generate questions using the model
    generation = model.generate(
        **inputs,  # Unpack the entire inputs dictionary
        max_length=256,  # Adjust max length as needed
        # ... (other generation parameters)
    )

    # Decode the generated text
    return tokenizer.decode(generation[0], skip_special_tokens=True)


def generate_answers(questions):
    """Generates possible answers to the input questions."""
    # model.generate() has no `prompt` argument, so prepend the prompt to the
    # questions themselves before encoding, separated by newlines
    prompt = "Here are some possible answers to the questions:\n"
    inputs = tokenizer(prompt + "\n".join(questions), return_tensors="pt")

    # Generate answers using the model
    generation = model.generate(
        **inputs,
        max_length=512,  # Adjust max length as needed
        num_beams=3,  # Beam search for better quality (slower)
        early_stopping=True,
    )

    # Decode the generated text and split it into individual answers
    answers = tokenizer.decode(generation[0], skip_special_tokens=True).split("\n")
    return zip(questions, answers[1:])  # Skip the first line in case the model echoes the prompt


def gradio_app(email):
    """Gradio interface function."""
    questions = generate_questions(email)
    answers = generate_answers(questions.split("\n"))
    # Join the answers so the second text output receives a single string
    return questions, "\n".join(answer for _, answer in answers)


# Gradio interface definition
interface = Interface(
    fn=gradio_app,
    inputs="textbox",
    outputs=["text", "text"],
    title="AI Email Assistant",
    description="Enter a long email and get questions and possible answers generated by an AI model.",
)

# Launch the Gradio interface
interface.launch()