from fastapi import FastAPI, Query
from transformers import pipeline

app = FastAPI()


def initialize_pipeline():
    # "FacebookAI/roberta-large-mnli" is an NLI classification model and cannot generate
    # text, so a generation-capable model ("gpt2", an assumption) is loaded here instead
    # to match the text-generation endpoint below.
    return pipeline("text-generation", model="gpt2")


pipe = initialize_pipeline()
# @app.get("/docs")
# def home():
#     return {"message": "Hello Siddhant"}
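# Note: "/docs" is FastAPI's default path for the auto-generated Swagger UI, so a
# custom route at that path would conflict with the built-in docs.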
# Route path "/generate" is an assumption; the original snippet had no route decorator,
# without which FastAPI never registers the handler as an endpoint.
@app.get("/generate")
def generate_text(
    text: str | None = Query(None, description="Input text to generate from"),
    prompt: str | None = Query(None, description="Optional prompt appended to the input text before generation"),
):
    # Require at least one of the two query parameters.
    if not text and not prompt:
        return {"error": "Please provide either 'text' or 'prompt' parameter."}
    # Combine text and prompt when both are given; otherwise use whichever is present.
    if prompt:
        input_text = f"{text} {prompt}" if text else prompt
    else:
        input_text = text
    # Sample up to 100 tokens from the generation pipeline.
    output = pipe(input_text, max_length=100, do_sample=True, top_k=50)
    return {"input_text": input_text, "output": output[0]["generated_text"]}