BeveledCube committed on
Commit
bf5c1c9
·
verified ·
1 Parent(s): 6e0a07a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +32 -27
main.py CHANGED
@@ -1,46 +1,51 @@
 
 
 
 
 
1
  import os
2
- from flask import Flask, request, jsonify, render_template
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
 
6
- app = Flask("Response API")
7
- name = "microsoft/DialoGPT-medium"
8
- # microsoft/DialoGPT-small
9
- # microsoft/DialoGPT-medium
10
- # microsoft/DialoGPT-large
11
 
12
- # Load the Hugging Face GPT-2 model and tokenizer
13
- model = GPT2LMHeadModel.from_pretrained(name)
14
- tokenizer = GPT2Tokenizer.from_pretrained(name)
15
 
16
- # Using CUDA for an optimal experience
17
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
- model = model.to(device)
19
 
20
- # Open a thing for the API
21
  @app.post("/api")
22
- def receive_data():
23
- data = request.get_json()
 
 
 
24
 
25
- print("Prompt:", data["prompt"])
26
- print("Length:", data["length"])
 
 
 
 
 
 
 
 
27
 
28
- input_text = data["prompt"]
29
 
30
  # Tokenize the input text
31
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
32
 
33
  # Generate output using the model
34
- output_ids = model.generate(input_ids, max_length=data["length"], num_beams=5, no_repeat_ngram_size=2)
35
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
36
 
37
  answer_data = { "answer": generated_text }
38
  print("Answered with:", answer_data)
39
- return jsonify(answer_data)
40
-
41
- # In case a normal browser opens the page
42
- @app.get("/")
43
- def not_api():
44
- return render_template("index.html")
45
-
46
- app.run(debug=False, port=7860, load_dotenv=True)
 
1
+ from fastapi.staticfiles import StaticFiles
2
+ from fastapi.responses import FileResponse
3
+ from pydantic import BaseModel
4
+ from fastapi import FastAPI
5
+
6
  import os
7
+
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
  import torch
10
 
11
# Application instance that the route decorators register their handlers on.
app = FastAPI()


# Request body accepted by POST /api.
class req(BaseModel):
    # Text that is tokenized and handed to the model.
    prompt: str
    # Requested generation length limit (intended as the `max_length` value).
    length: int
16
 
17
@app.get("/")
def read_root():
    # In case a normal browser opens the page: answer with the static HTML
    # page instead of a JSON payload.
    index_page = "templates/index.html"
    return FileResponse(path=index_page, media_type="text/html")
20
 
 
21
# Checkpoint served by the API. Other published sizes:
#   microsoft/DialoGPT-small
#   microsoft/DialoGPT-large
_MODEL_NAME = "microsoft/DialoGPT-medium"

# Lazily populated (model, tokenizer, device) triple shared by all requests.
_generator = None


def _get_generator():
    """Return the cached ``(model, tokenizer, device)`` triple, loading it on first use."""
    global _generator
    if _generator is None:
        # Using CUDA when it is available, CPU otherwise.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the Hugging Face GPT-2 model and tokenizer once; reloading the
        # weights on every request is what made each call slow.
        # Two simultaneous first requests may both load the weights; the
        # result is the same either way and the last assignment wins.
        model = GPT2LMHeadModel.from_pretrained(_MODEL_NAME).to(device)
        tokenizer = GPT2Tokenizer.from_pretrained(_MODEL_NAME)
        _generator = (model, tokenizer, device)
    return _generator


@app.post("/api")
def read_root(data: req):
    """Generate a model continuation for ``data.prompt``.

    Args:
        data: Request body carrying ``prompt`` (the input text) and
            ``length`` (passed to generation as ``max_length``).

    Returns:
        A dict ``{"answer": <decoded text>}``.

    NOTE(review): this handler shares the name ``read_root`` with the GET "/"
    handler, so the module-level name ends up bound to this function. The
    name is kept unchanged here for compatibility; consider renaming.
    """
    model, tokenizer, device = _get_generator()

    print("Prompt:", data.prompt)
    print("Length:", data.length)

    # Tokenize the input text and place it on the same device as the model;
    # leaving it on the CPU mismatches a CUDA-resident model.
    input_ids = tokenizer.encode(data.prompt, return_tensors="pt").to(device)

    # Generate output using the model. The limit comes from the request body
    # (previously an undefined bare name `length`, which raised NameError).
    output_ids = model.generate(
        input_ids,
        max_length=data.length,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    answer_data = {"answer": generated_text}
    print("Answered with:", answer_data)

    return answer_data