aquibmoin commited on
Commit
a40023d
1 Parent(s): 575c5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -17,11 +17,12 @@ client = OpenAI(api_key=api_key)
17
  def encode_text(text):
18
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
19
  outputs = bi_model(**inputs)
20
- return outputs.last_hidden_state.mean(dim=1).detach().numpy()
21
 
22
  def retrieve_relevant_context(user_input, context_texts):
23
- user_embedding = encode_text(user_input)
24
  context_embeddings = np.array([encode_text(text) for text in context_texts])
 
25
  similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
26
  most_relevant_idx = np.argmax(similarities)
27
  return context_texts[most_relevant_idx]
@@ -40,7 +41,7 @@ def generate_response(user_input, relevant_context):
40
  frequency_penalty=0.5,
41
  presence_penalty=0.0
42
  )
43
- return response.choices[0].message['content']
44
 
45
  def chatbot(user_input, context=""):
46
  context_texts = context.split("\n")
@@ -61,7 +62,8 @@ iface = gr.Interface(
61
  )
62
 
63
  # Launch the interface
64
- iface.launch()
 
65
 
66
 
67
 
 
17
  def encode_text(text):
18
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
19
  outputs = bi_model(**inputs)
20
+ return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten() # Ensure the output is 2D
21
 
22
  def retrieve_relevant_context(user_input, context_texts):
23
+ user_embedding = encode_text(user_input).reshape(1, -1)
24
  context_embeddings = np.array([encode_text(text) for text in context_texts])
25
+ context_embeddings = context_embeddings.reshape(len(context_embeddings), -1) # Flatten each embedding
26
  similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
27
  most_relevant_idx = np.argmax(similarities)
28
  return context_texts[most_relevant_idx]
 
41
  frequency_penalty=0.5,
42
  presence_penalty=0.0
43
  )
44
+ return response.choices[0].message.content.strip()
45
 
46
  def chatbot(user_input, context=""):
47
  context_texts = context.split("\n")
 
62
  )
63
 
64
  # Launch the interface
65
+ iface.launch(share=True)
66
+
67
 
68
 
69