datawithsuman committed on
Commit
682c36d
1 Parent(s): c3e3949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -28
app.py CHANGED
@@ -40,43 +40,81 @@ if uploaded_files:
40
  documents = reader.load_data()
41
  st.success("File uploaded...")
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Indexing
44
- index = PropertyGraphIndex.from_documents(
45
- documents,
46
- embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
47
- kg_extractors=[
48
- ImplicitPathExtractor(),
49
- SimpleLLMPathExtractor(
50
- llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3),
51
- num_workers=4,
52
- max_paths_per_chunk=10,
53
- ),
54
- ],
55
- show_progress=True,
56
- )
57
-
58
- # Save Knowledge Graph
59
- index.property_graph_store.save_networkx_graph(name="./data/kg.html")
60
-
61
- # Display the graph in Streamlit
62
- st.success("File Processed...")
63
- st.success("Creating Knowledge Graph...")
64
- HtmlFile = open("./data/kg.html", 'r', encoding='utf-8')
65
- source_code = HtmlFile.read()
66
- components.html(source_code, height= 500, width=700)
67
 
68
  # Retrieval
69
- kg_retriever = index.as_retriever(
70
- include_text=True, # include source text, default True
71
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Generation
74
  model = "gpt-3.5-turbo"
75
 
 
 
 
 
 
76
  def get_context(query):
77
- contexts = kg_retriever.retrieve(query)
78
- context_list = [n.text for n in contexts]
79
  return context_list
 
80
 
81
 
82
  def res(prompt):
 
40
  documents = reader.load_data()
41
  st.success("File uploaded...")
42
 
43
+ # # Indexing
44
+ # index = PropertyGraphIndex.from_documents(
45
+ # documents,
46
+ # embed_model=OpenAIEmbedding(model_name="text-embedding-3-small"),
47
+ # kg_extractors=[
48
+ # ImplicitPathExtractor(),
49
+ # SimpleLLMPathExtractor(
50
+ # llm=OpenAI(model="gpt-3.5-turbo", temperature=0.3),
51
+ # num_workers=4,
52
+ # max_paths_per_chunk=10,
53
+ # ),
54
+ # ],
55
+ # show_progress=True,
56
+ # )
57
+
58
+ # # Save Knowledge Graph
59
+ # index.property_graph_store.save_networkx_graph(name="./data/kg.html")
60
+
61
+ # # Display the graph in Streamlit
62
+ # st.success("File Processed...")
63
+ # st.success("Creating Knowledge Graph...")
64
+ # HtmlFile = open("./data/kg.html", 'r', encoding='utf-8')
65
+ # source_code = HtmlFile.read()
66
+ # components.html(source_code, height= 500, width=700)
67
+
68
+ # # Retrieval
69
+ # kg_retriever = index.as_retriever(
70
+ # include_text=True, # include source text, default True
71
+ # )
72
+
73
+
74
# Indexing: chunk the uploaded documents, register the chunks in the
# docstore, and build a vector index over them.
storage_context = StorageContext.from_defaults()
splitter = SentenceSplitter(chunk_size=256)  # 256-token chunks for retrieval granularity
nodes = splitter.get_nodes_from_documents(documents)
storage_context.docstore.add_documents(nodes)
index = VectorStoreIndex(nodes=nodes, storage_context=storage_context)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
# Retrieval: build an embedding-based retriever from the index and a
# keyword (BM25) retriever over the same nodes, both returning top 10.
vector_retriever = index.as_retriever(similarity_top_k=10)
bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)
84
+
85
# Hybrid Retriever class
class HybridRetriever(BaseRetriever):
    """Merge BM25 and vector retrieval results, de-duplicated by node id.

    BM25 hits come first; a node already seen is not added again, so the
    first occurrence of each node id wins.
    """

    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Query both retrievers; BM25 results are considered before vector ones.
        candidates = self.bm25_retriever.retrieve(query, **kwargs)
        candidates = candidates + self.vector_retriever.retrieve(query, **kwargs)
        seen_ids = set()
        merged = []
        for candidate in candidates:
            node_id = candidate.node.node_id
            if node_id in seen_ids:
                continue  # keep only the first occurrence of each node
            seen_ids.add(node_id)
            merged.append(candidate)
        return merged

hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)
104
 
105
  # Generation
106
  model = "gpt-3.5-turbo"
107
 
108
+ # def get_context(query):
109
+ # contexts = kg_retriever.retrieve(query)
110
+ # context_list = [n.text for n in contexts]
111
+ # return context_list
112
+
113
def get_context(query):
    """Retrieve nodes for *query* with the hybrid retriever and return their texts."""
    retrieved_nodes = hybrid_retriever.retrieve(query)
    context_list = [node.get_content() for node in retrieved_nodes]
    return context_list
117
+
118
 
119
 
120
  def res(prompt):