barathm111 commited on
Commit
6bfb9ac
·
verified ·
1 Parent(s): 3a487a7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -11
app.py CHANGED
@@ -2,26 +2,36 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
 
5
- # Create a new FastAPI app instance
6
  app = FastAPI()
7
 
8
  # Initialize the text generation pipeline
9
  pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
10
 
 
 
 
11
  @app.get("/")
12
  def home():
13
- return {"message": "Hello World"}
14
 
15
- # Define a function to handle the POST request at '/generate'
16
  @app.post("/generate")
17
- def generate(request: dict):
18
  try:
19
- text = request.get('text')
20
- if not text:
21
- raise HTTPException(status_code=400, detail="Text field is required")
 
 
 
22
 
23
- output = pipe(text, max_new_tokens=50)
24
- # Return the generated text in JSON response
25
- return {"output": output[0]['generated_text']}
 
 
26
  except Exception as e:
27
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
 
 
5
  app = FastAPI()
6
 
7
  # Initialize the text generation pipeline
8
  pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2", pad_token_id=2)
9
 
10
+ class QueryRequest(BaseModel):
11
+ text: str
12
+
13
  @app.get("/")
14
  def home():
15
+ return {"message": "SQL Generation Server is running"}
16
 
 
17
  @app.post("/generate")
18
+ def generate(request: QueryRequest):
19
  try:
20
+ text = request.text
21
+ prompt = f"Generate a valid SQL query for the following request. Only return the SQL query, nothing else:\n\n{text}\n\nSQL query:"
22
+ output = pipe(prompt, max_new_tokens=100)
23
+
24
+ generated_text = output[0]['generated_text']
25
+ sql_query = generated_text.split("SQL query:")[-1].strip()
26
 
27
+ # Basic validation
28
+ if not sql_query.lower().startswith(('select', 'show', 'describe')):
29
+ raise ValueError("Generated text is not a valid SQL query")
30
+
31
+ return {"output": sql_query}
32
  except Exception as e:
33
+ raise HTTPException(status_code=500, detail=str(e))
34
+
35
+ if __name__ == "__main__":
36
+ import uvicorn
37
+ uvicorn.run(app, host="0.0.0.0", port=7860)