File size: 3,160 Bytes
74b889b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import argparse
import markdown2
import sys
import uvicorn
from pathlib import Path
from typing import Union, Optional
from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.responses import HTMLResponse
from tclogger import logger, OSEnver
from transforms.embed import JinaAIEmbedder
from configs.constants import AVAILABLE_MODELS
# Runtime settings file, expected at <this file's directory>/configs/info.json.
# This module reads the keys "app_name", "version", "server" and "port" from
# ENVER; OSEnver is assumed to load them from this JSON file (TODO confirm).
info_path = Path(__file__).parent.joinpath("configs", "info.json")
ENVER = OSEnver(info_path)
class EmbeddingApp:
    """FastAPI application that serves text embeddings.

    Routes registered by ``setup_routes``:

    - ``GET /models``  -- names of the embedding models clients may request.
    - ``POST /encode`` -- embed a single text or a list of texts.
    - ``GET /readme``  -- the project README rendered to HTML (hidden from
      the OpenAPI schema).

    The Swagger UI is served at ``/``.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title=ENVER["app_name"],
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=ENVER["version"],
        )
        # One embedder instance is shared by every request; ``encode``
        # switches its model in place when a request asks for another one.
        self.embedder = JinaAIEmbedder()
        self.setup_routes()

    def get_available_models(self):
        """Return the list of model names accepted by ``POST /encode``."""
        return AVAILABLE_MODELS

    def get_readme(self):
        """Render README.md from the parent of this file's directory as HTML.

        Raises:
            OSError: if the README file cannot be opened.
        """
        readme_path = Path(__file__).parents[1] / "README.md"
        with open(readme_path, "r", encoding="utf-8") as rf:
            readme_str = rf.read()
        # markdown2 names its GFM-style table extra "tables" (plural).
        readme_html = markdown2.markdown(
            readme_str, extras=["tables", "fenced-code-blocks", "highlightjs-lang"]
        )
        return readme_html

    class EncodePostItem(BaseModel):
        """Request body for ``POST /encode``."""

        # One string or a list of strings. Required: a missing value cannot
        # be embedded, so it is rejected at validation time.
        text: Union[str, list[str]] = Field(
            ...,
            description="Input text(s) to embed",
        )
        # Omitted (or explicit null) selects the first entry of AVAILABLE_MODELS.
        model: Optional[str] = Field(
            default=AVAILABLE_MODELS[0],
            description="Embedding model name",
        )

    def encode(self, item: EncodePostItem):
        """Embed ``item.text`` with the requested model.

        Switches the shared embedder to ``item.model`` first if it differs
        from the currently loaded model.

        Returns:
            ``[]`` for an empty ``text`` list. Otherwise, the embedder output
            converted with ``.tolist()``. If that result has exactly one row,
            the single row is returned unwrapped; otherwise the full list of
            rows is returned. (Assumes the embedder returns one row per
            input text -- TODO confirm.)
        """
        logger.note(f"> Encoding text: [{item.text}]", end=" ")
        # An empty batch has nothing to embed; indexing embeddings[0] below
        # would raise IndexError.
        if isinstance(item.text, list) and not item.text:
            logger.success("[0]")
            return []
        # Treat an explicit ``"model": null`` like an omitted field rather
        # than handing None to the embedder.
        model = item.model if item.model is not None else AVAILABLE_MODELS[0]
        if model != self.embedder.model:
            self.embedder.switch_model(model)
        embeddings = self.embedder.encode(item.text).tolist()
        logger.success(f"[{len(embeddings[0])}]")
        # A single-row result is unwrapped to one vector; kept as-is for
        # compatibility with existing clients.
        if len(embeddings) == 1:
            return embeddings[0]
        return embeddings

    def setup_routes(self):
        """Register this instance's bound methods as FastAPI routes."""
        self.app.get(
            "/models",
            summary="Get available models",
        )(self.get_available_models)

        self.app.post(
            "/encode",
            summary="Encode embedding for input text",
        )(self.encode)

        self.app.get(
            "/readme",
            summary="README of Embedding API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the server's bind address and port.

    The command line is parsed as soon as the parser is constructed, and the
    resulting namespace is stored on ``self.args``. Both options take their
    defaults from the ``server`` and ``port`` entries of ``ENVER``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        server_default = ENVER["server"]
        self.add_argument(
            "-s", "--server",
            type=str,
            default=server_default,
            help=f"Server IP ({server_default}) for Embedding API",
        )

        port_default = ENVER["port"]
        self.add_argument(
            "-p", "--port",
            type=int,
            default=port_default,
            help=f"Server Port ({port_default}) for Embedding API",
        )

        # Parse eagerly so callers can simply use ``ArgParser().args``.
        self.args = self.parse_args(sys.argv[1:])
# ASGI application, built at import time. uvicorn.run below locates it
# through the "__main__:app" import string.
app = EmbeddingApp().app

if __name__ == "__main__":
    # Usage: python -m app [-s SERVER] [-p PORT]
    cli = ArgParser().args
    uvicorn.run("__main__:app", host=cli.server, port=cli.port)
|