ajaynagotha committed
Commit 3c4e014 · verified · 1 Parent(s): ee87c6a

Update app.py

Files changed (1): app.py +49 -20
app.py CHANGED
@@ -1,23 +1,30 @@
+import logging
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments, Trainer
 from datasets import load_dataset
+import gradio as gr
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
 
 # Load the dataset
 dataset = load_dataset("knowrohit07/gita_dataset")
+logging.info("Dataset loaded successfully.")
 
 # Preprocess the dataset
 def preprocess_function(examples):
     inputs = [f"Question: {q} Answer:" for q in examples["question"]]
     targets = examples["answer"]
-    return {"input_ids": tokenizer(inputs, padding="max_length", truncation=True)["input_ids"],
-            "labels": tokenizer(targets, padding="max_length", truncation=True)["input_ids"]}
+    return tokenizer(inputs, targets, padding="max_length", truncation=True)
 
 # Load the model and tokenizer
 model_name = "t5-base"  # Or any other suitable model
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+logging.info("Model and tokenizer loaded successfully.")
 
 # Tokenize the dataset
 tokenized_dataset = dataset.map(preprocess_function, batched=True)
+logging.info("Dataset tokenized successfully.")
 
 # Fine-tune the model on the dataset
 training_args = TrainingArguments(
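Note on this hunk: the new one-line preprocess returns tokenizer(inputs, targets, ...), which passes the answers as the tokenizer's second positional (text-pair) argument. That encodes them as a paired segment rather than as decoder labels, so the batch carries no "labels" field and Trainer has nothing to compute a loss from. A minimal sketch of the usual seq2seq pattern, assuming a transformers version recent enough to support text_target (>= 4.21):

def preprocess_function(examples):
    inputs = [f"Question: {q} Answer:" for q in examples["question"]]
    model_inputs = tokenizer(inputs, padding="max_length", truncation=True)
    # text_target encodes the answers as decoder labels, not as a paired segment
    labels = tokenizer(text_target=examples["answer"], padding="max_length", truncation=True)
    # Mask padding in the labels so the loss ignores it
    model_inputs["labels"] = [
        [(t if t != tokenizer.pad_token_id else -100) for t in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs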
@@ -29,6 +36,7 @@ training_args = TrainingArguments(
     per_device_eval_batch_size=16,
     num_train_epochs=3,
     weight_decay=0.01,
+    logging_dir="./logs",  # Specify the logging directory
 )
 
 trainer = Trainer(
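Everything above weight_decay sits outside this hunk's context lines. For orientation only, a plausible full call is sketched below; every argument not visible in the diff is an illustrative guess, not taken from the commit:

training_args = TrainingArguments(
    output_dir="./results",           # assumed; not visible in this hunk
    evaluation_strategy="epoch",      # assumed, since an eval_dataset is passed
    learning_rate=2e-5,               # assumed
    per_device_train_batch_size=16,   # assumed
    per_device_eval_batch_size=16,    # from the diff
    num_train_epochs=3,               # from the diff
    weight_decay=0.01,                # from the diff
    logging_dir="./logs",             # added by this commit
)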
@@ -39,30 +47,51 @@ trainer = Trainer(
     eval_dataset=tokenized_dataset["validation"],
 )
 
+logging.info("Starting training...")
 trainer.train()
+logging.info("Training completed.")
 
-# Define the Gradio interface
+# Save the fine-tuned model
+model.save_pretrained("gita_model")
+tokenizer.save_pretrained("gita_tokenizer")
+
+# Define the question-answering function
 def answer_question(question):
     """
-    This function takes a question about the Bhagavad Gita and uses the fine-tuned model to generate an answer.
+    Answers a question about the Bhagavad Gita using a fine-tuned model.
 
     Args:
-        question: A string representing the user's question.
+        question: The question to be answered.
 
     Returns:
-        A string representing the model's generated answer.
+        The answer generated by the model.
     """
-    input_ids = tokenizer(question, return_tensors="pt").input_ids
-    output = model.generate(input_ids, max_length=500, no_repeat_ngram_size=2)
-    answer = tokenizer.decode(output[0], skip_special_tokens=True)
-    return answer.strip()
+
+    try:
+        # Load the fine-tuned model and tokenizer
+        model = AutoModelForSeq2SeqLM.from_pretrained("gita_model")
+        tokenizer = AutoTokenizer.from_pretrained("gita_tokenizer")
+
+        # Preprocess the input
+        input_ids = tokenizer(question, return_tensors="pt").input_ids
+
+        # Generate the answer
+        output = model.generate(input_ids, max_length=500, no_repeat_ngram_size=2)
+        answer = tokenizer.decode(output[0], skip_special_tokens=True)
+
+        return answer.strip()
+
+    except Exception as e:
+        logging.error(f"An error occurred: {e}")
+        return "I couldn't find an answer to your question. Please try rephrasing it or asking something different."
 
+# Create the Gradio interface
 interface = gr.Interface(
     fn=answer_question,
     inputs="text",
     outputs="text",
     title="Bhagavad Gita Q&A",
     description="Ask your questions about the Bhagavad Gita and receive insights from the model."
 )
 
 interface.launch()
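One thing a reviewer might flag in this hunk: the new answer_question reloads the model and tokenizer from disk on every call, so each Gradio request pays the full from_pretrained cost (and the local names shadow the module-level model/tokenizer). A sketch of the usual alternative, loading once at module level and reusing the "gita_model"/"gita_tokenizer" paths saved above:

qa_model = AutoModelForSeq2SeqLM.from_pretrained("gita_model")
qa_tokenizer = AutoTokenizer.from_pretrained("gita_tokenizer")

def answer_question(question):
    input_ids = qa_tokenizer(question, return_tensors="pt").input_ids
    # max_length here bounds the generated sequence, not the input
    output = qa_model.generate(input_ids, max_length=500, no_repeat_ngram_size=2)
    return qa_tokenizer.decode(output[0], skip_special_tokens=True).strip()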
 
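The script also assumes knowrohit07/gita_dataset exposes "question" and "answer" columns and a ready-made "validation" split. If the Hub dataset ships only a "train" split (not verified here), one way to make eval_dataset resolve is to carve a split out before tokenizing:

if "validation" not in dataset:
    # train_test_split returns a DatasetDict with "train" and "test" keys
    split = dataset["train"].train_test_split(test_size=0.1, seed=42)
    dataset["train"], dataset["validation"] = split["train"], split["test"]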