File size: 2,844 Bytes
6c0ac6b
142b484
 
6c0ac6b
 
 
 
015696a
 
 
 
 
6c0ac6b
 
142b484
 
 
 
 
 
6c0ac6b
 
 
 
 
 
 
 
 
015696a
 
 
 
 
 
6c0ac6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142b484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Query
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import (
    ChatGoogleGenerativeAI,
    HarmBlockThreshold,
    HarmCategory,
)
from TextGen import app

# Hugging Face token for the XTTS Space. Read eagerly, so a missing HF_TOKEN
# environment variable raises KeyError when this module is imported.
# NOTE(review): neither `os` nor `Client` is imported in the visible import
# block; confirm they are imported elsewhere (e.g. `from gradio_client import Client`).
my_hf_token = os.environ["HF_TOKEN"]

# Gradio client for the hosted XTTS text-to-speech Space; /generate_wav calls it.
# NOTE(review): the URL pins a specific "--replicas/..." path, which may not stay
# valid if the Space is restarted; consider the Space id instead - verify.
tts_client = Client(
    "https://jofthomas-xtts.hf.space/--replicas/sxv98/",
    hf_token=my_hf_token,
)



class Generate(BaseModel):
    """Response body of the text-generation endpoint."""

    # The text produced by the language model.
    text: str

def generate_text(prompt: str):
    """Send ``prompt`` to Gemini Pro and return its completion.

    Args:
        prompt: Raw user prompt. It is substituted as a *value* into a fixed
            ``"{Prompt}"`` template, so braces inside the user text are passed
            through literally instead of being parsed as template placeholders.

    Returns:
        Generate: model wrapping the text returned by the LLM chain.

    Raises:
        HTTPException: 400 when ``prompt`` is empty or whitespace-only.
    """
    # Reject blank input as a client error. Returning a {"detail": ...} dict
    # here did not satisfy the route's ``response_model=Generate`` and ended up
    # as a 500 response-validation error instead.
    if not prompt or not prompt.strip():
        raise HTTPException(status_code=400, detail="Please provide a prompt.")

    # The user's text fills a fixed template rather than *being* the template.
    # Previously the PromptTemplate object itself was passed as the "Prompt"
    # value, and any "{...}" in the user's text broke template formatting.
    template = PromptTemplate(template="{Prompt}", input_variables=["Prompt"])

    # Initialize the LLM; only the dangerous-content filter is relaxed here.
    llm = ChatGoogleGenerativeAI(
        model="gemini-pro",
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        },
    )

    llmchain = LLMChain(prompt=template, llm=llm)

    llm_response = llmchain.run({"Prompt": prompt})
    return Generate(text=llm_response)

        

# Register CORS handling on the shared app: every origin, method and header is
# allowed, and credentials are permitted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website make credentialed cross-origin calls; consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

@app.get("/", tags=["Home"])
def api_home():
    """Root endpoint: return a fixed welcome payload."""
    welcome = {'detail': 'Welcome to FastAPI TextGen Tutorial!'}
    return welcome

@app.post(
    "/api/generate",
    summary="Generate text from prompt",
    tags=["Generate"],
    response_model=Generate,
)
def inference(input_prompt: str):
    """Delegate ``input_prompt`` (a query parameter) to ``generate_text``."""
    generated = generate_text(prompt=input_prompt)
    return generated

@app.get("/generate_wav")
async def generate_wav(text: str, language: str = "en"):
    """Synthesize ``text`` with the XTTS Gradio Space and return a WAV file.

    Args:
        text: Text to synthesize (query parameter).
        language: Value forwarded to the Space's 'Language' dropdown.

    Returns:
        FileResponse serving the generated audio as ``output.wav``.

    Raises:
        HTTPException: 400 when ``text`` is empty or whitespace-only; 500
            wrapping any error raised while calling the Space or preparing
            the response.
    """
    from fastapi.concurrency import run_in_threadpool

    # Validate before the try block so this 400 is not re-wrapped as a 500.
    if not text.strip():
        raise HTTPException(status_code=400, detail="Please provide text to synthesize.")

    try:
        # tts_client.predict is a synchronous call; run it in a worker thread so
        # this async endpoint does not block the event loop while it waits.
        result = await run_in_threadpool(
            tts_client.predict,
            text,  # str in 'Text Prompt' Textbox component
            language,  # str in 'Language' Dropdown component
            "./narator_out.wav",  # str (filepath on your computer (or URL) of file) in 'Reference Audio' Audio component
            "./narator_out.wav",  # str (filepath on your computer (or URL) of file) in 'Use Microphone for Reference' Audio component
            False,  # bool in 'Use Microphone' Checkbox component
            False,  # bool in 'Cleanup Reference Voice' Checkbox component
            False,  # bool in 'Do not use language auto-detect' Checkbox component
            True,  # bool in 'Agree' Checkbox component
            fn_index=1,
        )

        # Path of the generated wav file; assumes the Space returns it as the
        # second output - verify against the Space's API if it changes.
        wav_file_path = result[1]

        # Return the generated wav file as a response
        return FileResponse(wav_file_path, media_type="audio/wav", filename="output.wav")

    except Exception as e:
        # Keep the original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e