keerthi-balaji committed · Commit 71790db · verified · 1 Parent(s): 144d69d

Update app.py

Files changed (1):
  1. app.py +18 -61
app.py CHANGED
@@ -1,71 +1,28 @@
-import numpy as np
-from datasets import load_dataset
 import gradio as gr
 from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
-import torch
+import json

-# Load the horoscope-chat dataset
-dataset = load_dataset("chloeliu/horoscope-chat", split="train")
+# Load horoscope data
+with open("horoscope_data.json", "r") as file:
+    horoscope_data = json.load(file)

-# Assuming 'input' and 'response' are the correct keys:
-def prepare_docs(dataset):
-    docs = []
-    for data in dataset:
-        question = data.get('input', '')  # Safely access the 'input' field
-        answer = data.get('response', '')  # Safely access the 'response' field
-        docs.append({
-            "question": question,
-            "answer": answer
-        })
-    return docs
-
-# Prepare the documents
-docs = prepare_docs(dataset)
-
-# Custom Retriever that searches in the dataset
-class HoroscopeRetriever(RagRetriever):
-    def __init__(self, docs, tokenizer):
-        self.docs = docs
-        self.tokenizer = tokenizer
-
-    def retrieve(self, question_hidden_states, n_docs=1):
-        # Convert the question_hidden_states to a text string
-        question = question_hidden_states[0]
-
-        if isinstance(question, np.ndarray):
-            if question.size == 1:
-                question = question.item()  # Convert single-element array to scalar
-            else:
-                question = str(question[0])  # Take the first element of the array
-        else:
-            question = str(question)
-
-        question = question.lower()
-
-        # Simple retrieval logic: find the most relevant document based on the question
-        best_match = None
-        for doc in self.docs:
-            if question in doc["question"].lower():
-                best_match = doc
-                break
-
-        if best_match:
-            # Fake embedding as RAG expects this (In a real case, compute embeddings)
-            retrieved_doc_embeds = torch.zeros((1, 1, 768))  # Example tensor
-            doc_ids = ["0"]  # Example document ID
-            docs = [best_match["answer"]]
+# Custom Retriever that looks up horoscopes
+class CustomHoroscopeRetriever(RagRetriever):
+    def __init__(self, horoscope_data):
+        self.horoscope_data = horoscope_data
+
+    def retrieve(self, question_texts, n_docs=1):
+        zodiac_sign = question_texts[0].capitalize()
+        if zodiac_sign in self.horoscope_data:
+            return [self.horoscope_data[zodiac_sign]]
         else:
-            retrieved_doc_embeds = torch.zeros((1, 1, 768))  # Example tensor
-            doc_ids = ["0"]  # Example document ID
-            docs = ["Sorry, I couldn't find a relevant horoscope."]
-
-        return retrieved_doc_embeds, doc_ids, docs
-
-# Initialize the custom retriever with the dataset
-tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
-retriever = HoroscopeRetriever(docs, tokenizer)
+            return ["I couldn't find your zodiac sign. Please try again with a valid one."]
+
+# Initialize the custom retriever with the horoscope data
+retriever = CustomHoroscopeRetriever(horoscope_data)

 # Initialize RAG components
+tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
 model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=retriever)

 # Define the chatbot function
@@ -78,5 +35,5 @@ def horoscope_chatbot(input_text):
 # Set up Gradio interface
 iface = gr.Interface(fn=horoscope_chatbot, inputs="text", outputs="text", title="Horoscope RAG Chatbot")

-# Launch the interface with sharing enabled
-iface.launch(share=True)
+# Launch the interface
+iface.launch()
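
The new app.py reads its horoscopes from horoscope_data.json and looks them up by the capitalized zodiac-sign name (question_texts[0].capitalize()). The data file itself is not part of this commit, so its exact schema is an assumption; the sketch below shows the minimal shape the lookup logic implies, a flat JSON object mapping sign names to horoscope strings, with placeholder entries.

# Hypothetical sketch: write a horoscope_data.json in the shape the new
# retriever appears to expect (sign name -> horoscope text). The signs and
# wording here are placeholders, not data from the repository.
import json

sample_horoscopes = {
    "Aries": "A burst of energy helps you finish what you started.",
    "Taurus": "Patience pays off in a conversation you have been avoiding.",
    "Gemini": "An unexpected message opens a new door.",
}

with open("horoscope_data.json", "w") as file:
    json.dump(sample_horoscopes, file, indent=2)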
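Given such a file, retrieval in the new version reduces to a dictionary lookup keyed on the capitalized query. The snippet below mirrors the committed CustomHoroscopeRetriever.retrieve logic as a standalone function (rather than importing app.py, which would also launch the Gradio interface), so the round trip can be checked against the sample data above; the function name and test queries are illustrative only.

# Standalone mirror of the lookup behaviour in CustomHoroscopeRetriever.retrieve.
import json

with open("horoscope_data.json", "r") as file:
    horoscope_data = json.load(file)

def lookup_horoscope(question_text, n_docs=1):
    zodiac_sign = question_text.capitalize()
    if zodiac_sign in horoscope_data:
        return [horoscope_data[zodiac_sign]]
    else:
        return ["I couldn't find your zodiac sign. Please try again with a valid one."]

print(lookup_horoscope("aries"))   # ['A burst of energy helps you finish what you started.']
print(lookup_horoscope("dragon"))  # fallback message for an unknown sign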