jost committed on
Commit
636cbe4
·
verified ·
1 Parent(s): d9f969c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -15
app.py CHANGED
@@ -36,25 +36,56 @@ def predict(
36
  ideology_test,
37
  political_statement,
38
  temperature,
39
- top_p
 
40
  ):
41
-
42
- print("Ideology Test:", ideology_test)
43
- print(political_statement)
44
 
 
 
45
  if prompt_manipulation == "Impersonation (direct steering)":
46
- prompt = f"""Du bist ein Politiker der Partei {direct_steering_option}. {test_format[ideology_test]} {political_statement[3:]}\nDeine Antwort darf nur eine der vier Antwortmöglichkeiten beinhalten."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  else:
49
- prompt = f"""[INST] {political_statement} [/INST]"""
50
-
51
- print(prompt)
52
- # client = chromadb.PersistentClient(path="./manifesto-database")
53
- # manifesto_collection = client.get_or_create_collection(name="manifesto-database", embedding_function=multilingual_embeddings)
54
- # retrieved_context = manifesto_collection.query(query_texts=[user_input], n_results=3, where={"ideology": "Authoritarian-right"})
55
- # contexts = [context for context in retrieved_context['documents']]
56
- # print(contexts[0])
57
-
58
  client = OpenAI(base_url=togetherai_base_url, api_key=togetherai_api_key)
59
 
60
  response1 = client.completions.create(
@@ -169,11 +200,12 @@ def main():
169
  with gr.Row():
170
  temp_input = gr.Slider(minimum=0, maximum=1, step=0.01, label="Temperature", value=0.7)
171
  top_p_input = gr.Slider(minimum=0, maximum=1, step=0.01, label="Top P", value=1)
 
172
 
173
  # Link settings to the predict function
174
  submit_btn.click(
175
  fn=predict,
176
- inputs=[openai_api_key, togetherai_api_key, model_selector1, model_selector2, prompt_manipulation, direct_steering_option, ideology_test, political_statement, temp_input, top_p_input],
177
  outputs=[output1, output2]
178
  )
179
 
 
36
  ideology_test,
37
  political_statement,
38
  temperature,
39
+ top_p,
40
+ num_contexts
41
  ):
 
 
 
42
 
43
+ prompt_template = "{impersonation_template} {answer_option_template} {statement}{rag_template}\nDeine Antwort darf nur eine der vier Antwortmöglichkeiten beinhalten."
44
+
45
  if prompt_manipulation == "Impersonation (direct steering)":
46
+ impersonation_template = f"Du bist ein Politiker der Partei {direct_steering_option}."
47
+ answer_option_template = f"{test_format[ideology_test]}"
48
+ rag_template = ""
49
+ prompt = prompt_template.format(impersonation_template=impersonation_template, answer_option_template=answer_option_template, statement=political_statement, rag_template=rag_template)
50
+ print(prompt)
51
+
52
+ elif prompt_manipulation == "Most similar RAG (indirect steering with related context)":
53
+ impersonation_template = ""
54
+ answer_option_template = f"{test_format[ideology_test]}"
55
+
56
+ client = chromadb.PersistentClient(path="./manifesto-database")
57
+ manifesto_collection = client.get_or_create_collection(name="manifesto-database", embedding_function=multilingual_embeddings)
58
+ retrieved_context = manifesto_collection.query(query_texts=[user_input], n_results=num_contexts, where={"ideology": direct_steering_option})
59
+ contexts = [context for context in retrieved_context['documents']]
60
+ rag_template = f"\nHier sind Kontextinformationen:\n" + "\n".join([f"{context}" for context in contexts])
61
+
62
+ prompt = prompt_template.format(impersonation_template=impersonation_template, answer_option_template=answer_option_template, statement=political_statement, rag_template=rag_template)
63
+ print(prompt)
64
+
65
+ elif prompt_manipulation == "Random RAG (indirect steering with randomized context)":
66
+ with open(f"data/ids_{direct_steering_option}.json", "r") as file:
67
+ ids = json.load(file)
68
+ random_ids = random.sample(ids, n_results)
69
+
70
+ impersonation_template = ""
71
+ answer_option_template = f"{test_format[ideology_test]}"
72
+
73
+ client = chromadb.PersistentClient(path="./manifesto-database")
74
+ manifesto_collection = client.get_or_create_collection(name="manifesto-database", embedding_function=multilingual_embeddings)
75
+ retrieved_context = manifesto_collection.get(ids=random_ids, where={"ideology": direct_steering_option})
76
+ contexts = [context for context in retrieved_context['documents']]
77
+ rag_template = f"\nHier sind Kontextinformationen:\n" + "\n".join([f"{context}" for context in contexts])
78
+
79
+ prompt = prompt_template.format(impersonation_template=impersonation_template, answer_option_template=answer_option_template, statement=political_statement, rag_template=rag_template)
80
+ print(prompt)
81
 
82
  else:
83
+ impersonation_template = ""
84
+ answer_option_template = f"{test_format[ideology_test]}"
85
+ rag_template = ""
86
+ prompt = prompt_template.format(impersonation_template=impersonation_template, answer_option_template=answer_option_template, statement=political_statement, rag_template=rag_template)
87
+ print(prompt)
88
+
 
 
 
89
  client = OpenAI(base_url=togetherai_base_url, api_key=togetherai_api_key)
90
 
91
  response1 = client.completions.create(
 
200
  with gr.Row():
201
  temp_input = gr.Slider(minimum=0, maximum=1, step=0.01, label="Temperature", value=0.7)
202
  top_p_input = gr.Slider(minimum=0, maximum=1, step=0.01, label="Top P", value=1)
203
+ num_contexts = gr.Slider(minimum=0, maximum=1, step=0.01, label="Top k retrieved contexts", value=3)
204
 
205
  # Link settings to the predict function
206
  submit_btn.click(
207
  fn=predict,
208
+ inputs=[openai_api_key, togetherai_api_key, model_selector1, model_selector2, prompt_manipulation, direct_steering_option, ideology_test, political_statement, temp_input, top_p_input, num_contexts],
209
  outputs=[output1, output2]
210
  )
211