import streamlit as st
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration

# Load the RAG tokenizer, retriever, and generator once at startup.
# use_dummy_dataset=True loads a small dummy passage index instead of the
# full Wikipedia index, which keeps this demo lightweight.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
# The rag-token-nq checkpoint pairs with RagTokenForGeneration; attaching the
# retriever lets generate() fetch supporting passages automatically.
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
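
# Tip: wrapping the loading above in a function decorated with @st.cache_resource
# keeps Streamlit from reloading these weights on every rerun, and the retriever
# can be pointed at an indexed medical corpus instead of the dummy dataset.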

st.title("AI Health Assistant (RAG-based)")

def get_answer_rag(question):
    # Tokenize the question; the retriever attached to the model fetches the
    # top supporting passages inside generate() (n_docs controls how many).
    inputs = tokenizer(question, return_tensors="pt")
    outputs = model.generate(input_ids=inputs["input_ids"], n_docs=3)
    # Decode the generated token ids into the final answer string.
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    return answer
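
# Simple UI loop: read a question, run the RAG pipeline, and show the answer.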
question = st.text_input("Ask a health-related question:")

if question:
    answer = get_answer_rag(question)
    st.write(f"Answer: {answer}")
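
# Run the app with: streamlit run app.py (assuming this script is saved as app.py).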