# Hugging Face Spaces app (Space status: Paused) — medical Q&A GPT-2 demo.
import gradio
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the fine-tuned GPT-2 model and its tokenizer from the Hugging Face Hub.
# NOTE: this downloads weights at import time; the app cannot start offline.
hub_path = 'guptavishal79/aimlops'
loaded_model = GPT2LMHeadModel.from_pretrained(hub_path)
loaded_tokenizer = GPT2Tokenizer.from_pretrained(hub_path)
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate a single completion for `prompt` with the given model/tokenizer.

    Args:
        model: a causal LM exposing `.generate(...)` (e.g. GPT2LMHeadModel).
        tokenizer: tokenizer exposing `.encode`, `.decode`, `.eos_token_id`.
        prompt: input text to complete.
        max_length: maximum total token length (prompt + generated).

    Returns:
        The decoded generated text, special tokens stripped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    # Explicit attention mask and pad token silence transformers' warnings
    # and make padding behavior deterministic (GPT-2 has no pad token).
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=1,
            attention_mask=attention_mask,
            pad_token_id=pad_token_id
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Gradio-facing entry point for response generation.
def generate_query_response(prompt, max_length=200):
    """Wrap a user question in the training prompt template and generate an answer.

    Args:
        prompt: the raw user question from the UI.
        max_length: maximum token length; gradio.Number delivers a float,
            so it is coerced to int before reaching `model.generate`.

    Returns:
        The model's decoded response text.
    """
    model = loaded_model
    tokenizer = loaded_tokenizer
    # Template must match the format used during fine-tuning.
    prompt = f"<question>{prompt}<answer>"
    response = generate_response(model, tokenizer, prompt, int(max_length))
    return response
# Gradio elements
# Input from user
in_prompt = gradio.Textbox(lines=2, placeholder=None, value="", label='Enter Medical Question')
in_max_length = gradio.Number(value=200, label='Answer Length')
# Output response
out_response = gradio.Textbox(type="text", label='Answer')

# Gradio interface to generate UI link
iface = gradio.Interface(fn=generate_query_response,
                         inputs=[in_prompt, in_max_length],
                         outputs=[out_response])
# share=True exposes a temporary public URL in addition to the local one.
iface.launch(share=True)