bstraehle committed on
Commit
caeaee0
1 Parent(s): fbf1f2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -48
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- import logging, os, sys, time
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
 
@@ -7,6 +7,8 @@ from rag_langchain import LangChainRAG
7
  from rag_llamaindex import LlamaIndexRAG
8
  from trace import trace_wandb
9
 
 
 
10
  _ = load_dotenv(find_dotenv())
11
 
12
  RAG_INGESTION = False # load, split, embed, and store documents
@@ -34,55 +36,58 @@ def invoke(openai_api_key, prompt, rag_option):
34
  if (rag_option is None):
35
  raise gr.Error("Retrieval-Augmented Generation is required.")
36
 
37
- os.environ["OPENAI_API_KEY"] = openai_api_key
38
-
39
- if (RAG_INGESTION):
40
- if (rag_option == RAG_LANGCHAIN):
41
- rag = LangChainRAG()
42
- rag.ingestion(config)
43
- elif (rag_option == RAG_LLAMAINDEX):
44
- rag = LlamaIndexRAG()
45
- rag.ingestion(config)
46
-
47
- completion = ""
48
- result = ""
49
- callback = ""
50
- err_msg = ""
51
 
52
- try:
53
- start_time_ms = round(time.time() * 1000)
54
-
55
- if (rag_option == RAG_LANGCHAIN):
56
- rag = LangChainRAG()
57
- completion, callback = rag.rag_chain(config, prompt)
58
- result = completion["result"]
59
- elif (rag_option == RAG_LLAMAINDEX):
60
- rag = LlamaIndexRAG()
61
- result, callback = rag.retrieval(config, prompt)
62
- else:
63
- rag = LangChainRAG()
64
- completion, callback = rag.llm_chain(config, prompt)
65
- result = completion.generations[0][0].text
66
- except Exception as e:
67
- err_msg = e
68
-
69
- raise gr.Error(e)
70
- finally:
71
- end_time_ms = round(time.time() * 1000)
72
 
73
- trace_wandb(
74
- config,
75
- rag_option,
76
- prompt,
77
- completion,
78
- result,
79
- callback,
80
- err_msg,
81
- start_time_ms,
82
- end_time_ms
83
- )
84
-
85
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  gr.close_all()
88
 
 
1
  import gradio as gr
2
+ import logging, os, sys, threading, time
3
 
4
  from dotenv import load_dotenv, find_dotenv
5
 
 
7
  from rag_llamaindex import LlamaIndexRAG
8
  from trace import trace_wandb
9
 
10
+ lock = threading.Lock()
11
+
12
  _ = load_dotenv(find_dotenv())
13
 
14
  RAG_INGESTION = False # load, split, embed, and store documents
 
36
  if (rag_option is None):
37
  raise gr.Error("Retrieval-Augmented Generation is required.")
38
 
39
+ with lock:
40
+ os.environ["OPENAI_API_KEY"] = openai_api_key
41
+
42
+ if (RAG_INGESTION):
43
+ if (rag_option == RAG_LANGCHAIN):
44
+ rag = LangChainRAG()
45
+ rag.ingestion(config)
46
+ elif (rag_option == RAG_LLAMAINDEX):
47
+ rag = LlamaIndexRAG()
48
+ rag.ingestion(config)
 
 
 
 
49
 
50
+ completion = ""
51
+ result = ""
52
+ callback = ""
53
+ err_msg = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ try:
56
+ start_time_ms = round(time.time() * 1000)
57
+
58
+ if (rag_option == RAG_LANGCHAIN):
59
+ rag = LangChainRAG()
60
+ completion, callback = rag.rag_chain(config, prompt)
61
+ result = completion["result"]
62
+ elif (rag_option == RAG_LLAMAINDEX):
63
+ rag = LlamaIndexRAG()
64
+ result, callback = rag.retrieval(config, prompt)
65
+ else:
66
+ rag = LangChainRAG()
67
+ completion, callback = rag.llm_chain(config, prompt)
68
+ result = completion.generations[0][0].text
69
+ except Exception as e:
70
+ err_msg = e
71
+
72
+ raise gr.Error(e)
73
+ finally:
74
+ end_time_ms = round(time.time() * 1000)
75
+
76
+ trace_wandb(
77
+ config,
78
+ rag_option,
79
+ prompt,
80
+ completion,
81
+ result,
82
+ callback,
83
+ err_msg,
84
+ start_time_ms,
85
+ end_time_ms
86
+ )
87
+
88
+ del os.environ["OPENAI_API_KEY"]
89
+
90
+ return result
91
 
92
  gr.close_all()
93