BeveledCube commited on
Commit
ac96c11
·
verified ·
1 Parent(s): 01960d6

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -8
main.py CHANGED
@@ -1,36 +1,40 @@
1
  import os
2
- from flask import Flask, request, jsonify
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
 
6
  app = Flask("Response API")
7
- size = "small"
8
  # microsoft/DialoGPT-small
9
  # microsoft/DialoGPT-medium
10
  # microsoft/DialoGPT-large
11
 
12
  # Load the Hugging Face GPT-2 model and tokenizer
13
- model = GPT2LMHeadModel.from_pretrained("microsoft/DialoGPT-medium")
14
- tokenizer = GPT2Tokenizer.from_pretrained("microsoft/DialoGPT-medium")
15
 
16
  @app.route("/", methods=["POST"])
17
  def receive_data():
18
  data = request.get_json()
19
 
20
- print("Prompt:", data['prompt'])
21
- print("Length:", data['length'])
22
 
23
- input_text = data['prompt']
24
 
25
  # Tokenize the input text
26
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
27
 
28
  # Generate output using the model
29
- output_ids = model.generate(input_ids, max_length=data['length'], num_beams=5, no_repeat_ngram_size=2)
30
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
31
 
32
  answer_data = { "answer": generated_text }
33
  print("Answered with:", answer_data)
34
  return jsonify(answer_data)
35
 
 
 
 
 
36
  app.run(debug=False, port=7860)
 
1
  import os
2
+ from flask import Flask, request, jsonify, render_template
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
4
  import torch
5
 
6
  app = Flask("Response API")
7
+ name = "microsoft/DialoGPT-medium"
8
  # microsoft/DialoGPT-small
9
  # microsoft/DialoGPT-medium
10
  # microsoft/DialoGPT-large
11
 
12
  # Load the Hugging Face GPT-2 model and tokenizer
13
+ model = GPT2LMHeadModel.from_pretrained(name)
14
+ tokenizer = GPT2Tokenizer.from_pretrained(name)
15
 
16
  @app.route("/", methods=["POST"])
17
  def receive_data():
18
  data = request.get_json()
19
 
20
+ print("Prompt:", data["prompt"])
21
+ print("Length:", data["length"])
22
 
23
+ input_text = data["prompt"]
24
 
25
  # Tokenize the input text
26
  input_ids = tokenizer.encode(input_text, return_tensors="pt")
27
 
28
  # Generate output using the model
29
+ output_ids = model.generate(input_ids, max_length=data["length"], num_beams=5, no_repeat_ngram_size=2)
30
  generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
31
 
32
  answer_data = { "answer": generated_text }
33
  print("Answered with:", answer_data)
34
  return jsonify(answer_data)
35
 
36
+ @app.route("/", methods=["GET"])
37
+ def receive_data():
38
+ return render_template("index.html")
39
+
40
  app.run(debug=False, port=7860)