"""Gradio chat front-end for a bag-of-words intent classifier.

At import time this loads ``./intents.json`` and the trained ``NeuralNet``
checkpoint from ``./data.pth``, then exposes ``predict`` as the Gradio chat
callback.
"""
import json
import random

import gradio as gr
import torch

from model import NeuralNet
from nltk_utils import bag_of_words, tokenize

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('./intents.json', 'r', encoding='utf-8') as json_data:
    intents = json.load(json_data)

FILE = "./data.pth"
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only host.
data = torch.load(FILE, map_location=device)

input_size = data["input_size"]
hidden_size = data["hidden_size"]
output_size = data["output_size"]
all_words = data['all_words']
tags = data['tags']
model_state = data["model_state"]

model = NeuralNet(input_size, hidden_size, output_size).to(device)
model.load_state_dict(model_state)
model.eval()

# Minimum softmax probability for the predicted tag to be trusted.
CONFIDENCE_THRESHOLD = 0.75
# Reply used when the classifier is unsure or the tag has no usable responses.
FALLBACK_REPLY = "I do not understand..."

# Index responses by tag once instead of scanning every intent per message.
# If a tag appears more than once, the last entry wins.
responses_by_tag = {
    intent["tag"]: intent["responses"] for intent in intents['intents']
}


def predict(sentence, history=None):
    """Classify one user message and append the bot's reply to the chat.

    Args:
        sentence: The raw user message from the Gradio text input.
        history: The chat so far as a list of ``(user_message, bot_reply)``
            pairs, carried between calls via Gradio ``state``. ``None``
            (Gradio's initial state) starts a new conversation.

    Returns:
        ``(history, history)``: the updated pair list, once for the
        ``chatbot`` output and once for the ``state`` output.
    """
    # Copy so neither the caller's list nor a shared default is mutated.
    history = [] if history is None else list(history)

    tokens = tokenize(sentence)
    X = bag_of_words(tokens, all_words)
    X = X.reshape(1, X.shape[0])
    X = torch.from_numpy(X).to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = model(X)

    _, predicted = torch.max(output, dim=1)
    idx = predicted.item()
    tag = tags[idx]
    prob = torch.softmax(output, dim=1)[0][idx].item()

    # Always produce a reply; previously a low-confidence or unknown tag left
    # `reply` unbound and raised UnboundLocalError.
    responses = responses_by_tag.get(tag)
    if prob > CONFIDENCE_THRESHOLD and responses:
        reply = random.choice(responses)
    else:
        reply = FALLBACK_REPLY

    history.append((sentence, reply))
    return history, history
# Build the chat UI around `predict`: a text box plus hidden state in, and the
# chatbot transcript plus updated state out. The CSS hides Gradio's footer.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    theme="default",
    css=".footer {display:none !important}",
)
demo.launch()