Ethanmaht's picture
Update app.py
5ea3b28 verified
raw
history blame
1.38 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModel
import torch
from flask import Flask, request, jsonify
# 1. 加载模型和分词器
model_name = "jinaai/jina-embeddings-v3" # 替换为您实际使用的模型名
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
app = Flask(__name__)
# 2. 定义生成嵌入的函数
def generate_embeddings(text):
# 使用分词器处理输入文本
inputs = tokenizer(text, return_tensors="pt")
# 禁用梯度计算,以减少资源消耗
with torch.no_grad():
# 获取最后一层隐藏状态并计算平均值作为嵌入
embeddings = model(**inputs).last_hidden_state.mean(dim=1)
# 将嵌入转换为Python列表,方便Gradio输出
return embeddings.numpy().tolist()
@app.route('/api/v1/embeddings', methods=['POST'])
def embedding():
_embedding_data = []
data = request.json # 获取 JSON 数据
headers = request.headers
input_text_list = data.get('embeddings', [])
for _ in input_text_list:
_embedding_data.append(generate_embeddings(_))
return jsonify({
"embeddings": _embedding_data,
"model": model_name // 使用的模型
})
# 4. 启动Gradio应用
if __name__ == "__main__":
app.run(debug=False, port=7860)