JRQi commited on
Commit
906bdab
1 Parent(s): 71c9215

Update game1.py

Browse files
Files changed (1) hide show
  1. game1.py +8 -4
game1.py CHANGED
@@ -114,7 +114,8 @@ def func1(lang_selected, num_selected, human_predict, num1, num2, user_important
114
  # Use a pipeline as a high-level helper
115
  from transformers import pipeline
116
 
117
- classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
 
118
  output = classifier([text['text']])
119
 
120
  star2num = {
@@ -322,7 +323,8 @@ def func1_written(text_written, human_predict, lang_written):
322
  # tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
323
  # model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
324
 
325
- classifier = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
 
326
 
327
  output = classifier([text_written])
328
 
@@ -353,10 +355,12 @@ def func1_written(text_written, human_predict, lang_written):
353
  import shap
354
 
355
  # sentiment_classifier = pipeline("text-classification", return_all_scores=True)
 
 
356
  if lang_written == "Dutch":
357
- sentiment_classifier = pipeline("text-classification", model='DTAI-KULeuven/robbert-v2-dutch-sentiment', return_all_scores=True)
358
  else:
359
- sentiment_classifier = pipeline("text-classification", model='distilbert-base-uncased-finetuned-sst-2-english', return_all_scores=True)
360
 
361
  explainer = shap.Explainer(sentiment_classifier)
362
 
 
114
  # Use a pipeline as a high-level helper
115
  from transformers import pipeline
116
 
117
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
118
+ classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=device)
119
  output = classifier([text['text']])
120
 
121
  star2num = {
 
323
  # tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
324
  # model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
325
 
326
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
327
+ classifier = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment", device=device)
328
 
329
  output = classifier([text_written])
330
 
 
355
  import shap
356
 
357
  # sentiment_classifier = pipeline("text-classification", return_all_scores=True)
358
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
359
+
360
  if lang_written == "Dutch":
361
+ sentiment_classifier = pipeline("text-classification", model='DTAI-KULeuven/robbert-v2-dutch-sentiment', return_all_scores=True, device=device)
362
  else:
363
+ sentiment_classifier = pipeline("text-classification", model='distilbert-base-uncased-finetuned-sst-2-english', return_all_scores=True, device=device)
364
 
365
  explainer = shap.Explainer(sentiment_classifier)
366