cutechicken committed
Commit 9a66aa0 • 1 Parent(s): 67209ed

Update app.py

Files changed (1)
  1. app.py +28 -7
app.py CHANGED
@@ -6,7 +6,6 @@ import os
 from threading import Thread
 import random
 from datasets import load_dataset
-from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
 
@@ -18,24 +17,45 @@ MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
 MODELS = os.environ.get("MODELS")
 MODEL_NAME = MODEL_ID.split("/")[-1]
 
-# Load the embedding model
-embedding_model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens')
+# Load the model and tokenizer
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 # Load the Wikipedia dataset
 wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
 print("Wikipedia dataset loaded:", wiki_dataset)
 
+def get_embeddings(text, model, tokenizer):
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    # Use the mean of the last hidden state as the embedding
+    embeddings = outputs.last_hidden_state.mean(dim=1)
+    return embeddings
+
 # Embed the dataset's questions
 questions = wiki_dataset['train']['question'][:10000]  # use only the first 10,000
-question_embeddings = embedding_model.encode(questions, convert_to_tensor=True)
+question_embeddings = []
+batch_size = 32
+
+for i in range(0, len(questions), batch_size):
+    batch = questions[i:i+batch_size]
+    batch_embeddings = get_embeddings(batch, model, tokenizer)
+    question_embeddings.append(batch_embeddings)
+
+question_embeddings = torch.cat(question_embeddings, dim=0)
 
 def find_relevant_context(query, top_k=3):
     # Embed the query
-    query_embedding = embedding_model.encode(query, convert_to_tensor=True)
+    query_embedding = get_embeddings(query, model, tokenizer)
 
     # Compute cosine similarity
     similarities = cosine_similarity(
-        query_embedding.cpu().numpy().reshape(1, -1),
+        query_embedding.cpu().numpy(),
         question_embeddings.cpu().numpy()
     )[0]
 
@@ -47,7 +67,8 @@ def find_relevant_context(query, top_k=3):
     for idx in top_indices:
         relevant_contexts.append({
            'question': questions[idx],
-           'answer': wiki_dataset['train']['answer'][idx]
+           'answer': wiki_dataset['train']['answer'][idx],
+           'similarity': similarities[idx]
        })
 
    return relevant_contexts
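
A note on the new embedding path: a forward pass on an AutoModelForCausalLM returns a CausalLMOutputWithPast, which carries logits but no last_hidden_state attribute, and bfloat16 tensors cannot be converted to NumPy directly, so outputs.last_hidden_state in get_embeddings and the later .cpu().numpy() calls are likely to fail at runtime. Below is a minimal sketch of a get_embeddings variant that works around both, assuming the standard transformers/torch APIs; the masked mean pooling is an addition for illustration, not part of this commit.

import torch

def get_embeddings(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        # A CausalLM forward returns logits, not last_hidden_state;
        # request the hidden states explicitly instead.
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]  # (batch, seq, dim)
    # Average only over real tokens, so padding does not skew the embedding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    # Cast to float32: NumPy has no bfloat16, so .cpu().numpy() would fail otherwise.
    return pooled.float()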
 
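
For reference, a usage sketch of the retrieval path added by this commit; the query string is illustrative, and sklearn's cosine_similarity runs on CPU NumPy arrays, which is why the embeddings are moved off the GPU first.

# Hypothetical query: "What script did King Sejong create?"
contexts = find_relevant_context("세종대왕이 창제한 문자는 무엇인가?", top_k=3)
for ctx in contexts:
    print(f"[{ctx['similarity']:.3f}] Q: {ctx['question']}")
    print(f"        A: {ctx['answer']}")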