Rahul-8853 commited on
Commit
b98bf6d
·
verified ·
1 Parent(s): 4bdfa1f
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -1,24 +1,30 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
 
 
 
 
 
 
5
  # Load the model and tokenizer
6
- model_name = "KevSun/Personality_LM"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
 
10
  # Function to predict personality traits
11
- def predict_personality(text):
12
- inputs = tokenizer(text, return_tensors="pt")
13
- outputs = model(**inputs)
14
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
15
- labels = ["Introverted", "Extroverted", "Open", "Agreeable", "Conscientious", "Neurotic"]
16
- predictions = {label: prob.item() for label, prob in zip(labels, probs[0])}
17
- return predictions
 
 
18
 
19
  # Create the Gradio interface
20
  interface = gr.Interface(
21
- fn=predict_personality,
22
  inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
23
  outputs=gr.Label(),
24
  title="Personality Analyzer",
@@ -26,4 +32,4 @@ interface = gr.Interface(
26
  )
27
 
28
  # Launch the Gradio app on a specific port
29
- interface.launch(server_port=7862) # You can change 7861 to another port if necessary
 
1
  import gradio as gr
2
+ from transformers import BertTokenizer, BertForSequenceClassification
3
  import torch
4
 
5
+ # Function to load model and tokenizer
6
+ def load_model():
7
+ tokenizer = BertTokenizer.from_pretrained("Minej/bert-base-personality")
8
+ model = BertForSequenceClassification.from_pretrained("Minej/bert-base-personality")
9
+ return tokenizer, model
10
+
11
  # Load the model and tokenizer
12
+ tokenizer, model = load_model()
 
 
13
 
14
  # Function to predict personality traits
15
+ def personality_detection(text):
16
+ inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1).squeeze().numpy()
20
+
21
+ label_names = ['Extroversion', 'Neuroticism', 'Agreeableness', 'Conscientiousness', 'Openness']
22
+ result = {label_names[i]: predictions[i] for i in range(len(label_names))}
23
+ return result
24
 
25
  # Create the Gradio interface
26
  interface = gr.Interface(
27
+ fn=personality_detection,
28
  inputs=gr.Textbox(lines=2, placeholder="Enter a sentence here..."),
29
  outputs=gr.Label(),
30
  title="Personality Analyzer",
 
32
  )
33
 
34
  # Launch the Gradio app on a specific port
35
+ interface.launch(server_port=7861) # You can change 7861 to another port if necessary