"""Gradio chatbot front-end for the ``facebook/rag-token-nq`` RAG model.

The retriever is created with ``use_dummy_dataset=True``, so passages are
retrieved from the checkpoint's dummy index, not from the CSV loaded below.
"""

import gradio as gr
import pandas as pd
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

MODEL_NAME = "facebook/rag-token-nq"
DATASET_PATH = "your_dataset.csv"

# Load the tokenizer and retriever.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
retriever = RagRetriever.from_pretrained(MODEL_NAME, use_dummy_dataset=True)

# Load the model with the retriever attached, so generate() can retrieve.
model = RagTokenForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

# Load your dataset. This must happen before anything reads ``contexts``.
# The previous version tokenized ``contexts`` and ``responses`` ahead of this
# point, and ``responses`` was never defined, so the script died with
# NameError. That tokenization was unused by the chatbot and only matters for
# fine-tuning, so it has been removed.
df = pd.read_csv(DATASET_PATH)

# The CSV must contain an 'Abstract' column; it is read as the contexts.
# NOTE(review): ``contexts`` is not wired into the retriever, which still
# uses the dummy index. Build a custom index from it if retrieval over this
# data is intended.
contexts = df["Abstract"].tolist()


def generate_response(input_text):
    """Answer *input_text* with the RAG model.

    Args:
        input_text: The user's question as submitted from the Gradio textbox.

    Returns:
        The first decoded generated sequence, or ``""`` when the input is
        empty, ``None``, or whitespace only.
    """
    # Gradio can submit an empty box; there is nothing to ask the model.
    if not input_text or not input_text.strip():
        return ""
    input_ids = tokenizer([input_text], return_tensors="pt")["input_ids"]
    outputs = model.generate(input_ids)
    # Batch of one question, so take the single decoded answer.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


# Create the Gradio interface.
iface = gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
    title="RAG Chatbot",
    description="A chatbot powered by Retrieval-Augmented Generation (RAG) model.",
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module has no
    # server side effect.
    iface.launch()