Avimanyu committed
Commit ec7b21b
1 parent: 8814dc8

Delete app.py

Files changed (1)
  app.py +0 -59
app.py DELETED
@@ -1,59 +0,0 @@
-import gradio as gr
-import torch
-
-from transformers import T5Tokenizer, T5ForConditionalGeneration
-
-# Define hyperparameters
-max_seq_length = 512
-max_output_length = 1024
-num_beams = 16
-length_penalty = 1.4
-no_repeat_ngram_size = 2
-temperature = 0.7
-top_k = 150
-top_p = 0.92
-repetition_penalty = 2.1
-early_stopping = True
-
-# Load the pre-trained model and tokenizer
-model_name = "google/flan-t5-large"
-tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=512)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-if torch.cuda.device_count() > 1:
-    device_ids = [i for i in range(torch.cuda.device_count())]
-    model = torch.nn.DataParallel(T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True), device_ids=device_ids)
-else:
-    model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True)
-
-model.to(device)
-
-# Define a function to generate a response to user input
-def chatbot(text):
-    with torch.no_grad():
-        # Tokenize the input text and convert to a PyTorch tensor
-        input_ids = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_seq_length).input_ids.to(device)
-
-        # Generate a response using the model
-        if torch.cuda.device_count() > 1:
-            outputs = model.module.generate(input_ids, min_length=max_seq_length, max_new_tokens=max_output_length, num_beams=num_beams, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, early_stopping=early_stopping)
-        else:
-            outputs = model.generate(input_ids, min_length=max_seq_length, max_new_tokens=max_output_length, num_beams=num_beams, length_penalty=length_penalty, no_repeat_ngram_size=no_repeat_ngram_size, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, early_stopping=early_stopping)
-
-        # Decode the response and return it
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response
-
-chat_help_text = "Welcome! This ChatBot is designed to answer questions about a wide range of topics. " \
-                 "Please note that the ChatBot may not always provide accurate or complete answers, and may not " \
-                 "understand certain questions. To use the ChatBot, simply type in your question in the text box " \
-                 "below and hit Enter or click the button. Please keep in mind that the ChatBot is not perfect " \
-                 "and may provide inaccurate or incomplete answers. It is best suited for simple factual " \
-                 "questions rather than complex or nuanced inquiries."
-
-# Create a Gradio interface
-iface = gr.Interface(fn=chatbot, inputs="text", outputs="text", title="NuNet Inferencing Demo",
-                     description=chat_help_text)
-
-iface.launch(share=True)
-
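A note on the deleted script, for anyone restoring it: in transformers, generate() only honors temperature, top_k, and top_p when do_sample=True is passed, so under the pure beam-search configuration above those three values are ignored, and min_length=max_seq_length forces every reply to be at least 512 tokens long. A minimal sketch of the same generation call with sampling actually enabled (hyperparameter values carried over from the file above; generate_reply is a hypothetical helper, not code from this commit) could look like:

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_name = "google/flan-t5-large"
tokenizer = T5Tokenizer.from_pretrained(model_name, model_max_length=512)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

def generate_reply(text):
    # Tokenize the prompt and move it to the model's device
    input_ids = tokenizer(text, return_tensors="pt", truncation=True,
                          max_length=512).input_ids.to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=1024,
            do_sample=True,        # required for the sampling knobs below to apply
            temperature=0.7,
            top_k=150,
            top_p=0.92,
            repetition_penalty=2.1,
            no_repeat_ngram_size=2,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)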