# jmdu's picture
# Create app.py
# 56ba8e8 verified
# raw
# history blame
# No virus
# 1.06 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
# Load the model and tokenizer
@st.cache_resource
def load_model():
    """Load the SFR-Embedding-Mistral tokenizer and model, cached by Streamlit.

    ``st.cache_resource`` is the replacement for the deprecated
    ``st.cache(allow_output_mutation=True)``: the cached objects are handed
    back as-is (no copying or hashing of the return value), which is what a
    large, unserialisable model object needs.

    Returns:
        tuple: ``(tokenizer, model)`` as returned by
        ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.
    """
    # Single source of truth for the checkpoint so the tokenizer and the
    # model can never drift apart.
    model_name = "Salesforce/SFR-Embedding-Mistral"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model
# Module-level handles read by embed_text(); load_model() is wrapped in a
# Streamlit cache decorator, so this call is the single place they are obtained.
tokenizer, model = load_model()
def embed_text(text):
    """Embed ``text`` by averaging the model's final hidden states.

    Args:
        text: Input handed straight to the module-level ``tokenizer``
            (a single string in this app). It is truncated to at most
            32768 tokens.

    Returns:
        numpy.ndarray: ``last_hidden_state`` averaged over dimension 1
        (presumably the token axis, giving one vector per input row —
        confirm against the model's output spec).
    """
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=32768)
    # Inference only: disable autograd so no graph is built and intermediate
    # activations are not kept alive during the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    pooled = outputs.last_hidden_state.mean(dim=1)
    # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the model
    # is ever moved to an accelerator.
    return pooled.detach().cpu().numpy()
def main():
    """Render the Streamlit page: a text box, a button, and the embeddings.

    When the "Get Embeddings" button is pressed, the entered text is passed
    to ``embed_text`` and the result is written to the page. Empty or
    whitespace-only input produces a warning instead of a model call.
    """
    st.title("Text Embedding using Salesforce/SFR-Embedding-Mistral")
    # Text input
    text = st.text_area("Enter text here:", height=150)

    # Nothing to do until the user presses the button.
    if not st.button("Get Embeddings"):
        return

    # Treat whitespace-only input like empty input: it carries no content
    # worth running through the model.
    if not text or not text.strip():
        st.warning("Please enter some text to process.")
        return

    with st.spinner('Fetching embeddings...'):
        embeddings = embed_text(text)
    st.write(embeddings)
# Build the page only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()