Spaces:
Running
on
Zero
Running
on
Zero
cutechicken
commited on
Commit
•
50ef49c
1
Parent(s):
a908cb3
Update app.py
Browse files
app.py
CHANGED
@@ -6,8 +6,8 @@ import os
|
|
6 |
from threading import Thread
|
7 |
import random
|
8 |
from datasets import load_dataset
|
9 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
10 |
import numpy as np
|
|
|
11 |
|
12 |
# GPU 메모리 관리
|
13 |
torch.cuda.empty_cache()
|
@@ -29,40 +29,19 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
29 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
30 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
31 |
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
embeddings = hidden_states.mean(dim=1)
|
39 |
-
return embeddings
|
40 |
-
|
41 |
-
# 데이터셋의 질문들을 임베딩
|
42 |
-
print("임베딩 생성 시작...")
|
43 |
-
questions = wiki_dataset['train']['question'][:1000] # 처음 1000개만 사용 (테스트용)
|
44 |
-
question_embeddings = []
|
45 |
-
batch_size = 8 # 배치 사이즈 줄임
|
46 |
-
|
47 |
-
for i in range(0, len(questions), batch_size):
|
48 |
-
batch = questions[i:i+batch_size]
|
49 |
-
batch_embeddings = get_embeddings(batch, model, tokenizer)
|
50 |
-
question_embeddings.append(batch_embeddings.cpu())
|
51 |
-
if i % 100 == 0:
|
52 |
-
print(f"Processed {i}/{len(questions)} questions")
|
53 |
-
|
54 |
-
question_embeddings = torch.cat(question_embeddings, dim=0)
|
55 |
-
print("임베딩 생성 완료")
|
56 |
|
57 |
def find_relevant_context(query, top_k=3):
|
58 |
-
# 쿼리
|
59 |
-
|
60 |
|
61 |
# 코사인 유사도 계산
|
62 |
-
similarities =
|
63 |
-
query_embedding.cpu().numpy(),
|
64 |
-
question_embeddings.numpy()
|
65 |
-
)[0]
|
66 |
|
67 |
# 가장 유사한 질문들의 인덱스
|
68 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
@@ -70,11 +49,12 @@ def find_relevant_context(query, top_k=3):
|
|
70 |
# 관련 컨텍스트 추출
|
71 |
relevant_contexts = []
|
72 |
for idx in top_indices:
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
|
|
78 |
|
79 |
return relevant_contexts
|
80 |
|
@@ -83,11 +63,11 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
83 |
print(f'message is - {message}')
|
84 |
print(f'history is - {history}')
|
85 |
|
86 |
-
#
|
87 |
relevant_contexts = find_relevant_context(message)
|
88 |
context_prompt = "\n\n관련 참고 정보:\n"
|
89 |
for ctx in relevant_contexts:
|
90 |
-
context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n\n"
|
91 |
|
92 |
# 대화 히스토리 구성
|
93 |
conversation = []
|
@@ -97,6 +77,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
97 |
{"role": "assistant", "content": answer}
|
98 |
])
|
99 |
|
|
|
100 |
# 컨텍스트를 포함한 최종 프롬프트 구성
|
101 |
final_message = context_prompt + "\n현재 질문: " + message
|
102 |
conversation.append({"role": "user", "content": final_message})
|
|
|
6 |
from threading import Thread
|
7 |
import random
|
8 |
from datasets import load_dataset
|
|
|
9 |
import numpy as np
|
10 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
11 |
|
12 |
# GPU 메모리 관리
|
13 |
torch.cuda.empty_cache()
|
|
|
29 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
30 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
31 |
|
32 |
+
# TF-IDF 벡터라이저 초기화 및 학습
|
33 |
+
print("TF-IDF 벡터화 시작...")
|
34 |
+
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
35 |
+
vectorizer = TfidfVectorizer(max_features=1000)
|
36 |
+
question_vectors = vectorizer.fit_transform(questions)
|
37 |
+
print("TF-IDF 벡터화 완료")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def find_relevant_context(query, top_k=3):
|
40 |
+
# 쿼리 벡터화
|
41 |
+
query_vector = vectorizer.transform([query])
|
42 |
|
43 |
# 코사인 유사도 계산
|
44 |
+
similarities = (query_vector * question_vectors.T).toarray()[0]
|
|
|
|
|
|
|
45 |
|
46 |
# 가장 유사한 질문들의 인덱스
|
47 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
|
|
49 |
# 관련 컨텍스트 추출
|
50 |
relevant_contexts = []
|
51 |
for idx in top_indices:
|
52 |
+
if similarities[idx] > 0: # 유사도가 0보다 큰 경우만 포함
|
53 |
+
relevant_contexts.append({
|
54 |
+
'question': questions[idx],
|
55 |
+
'answer': wiki_dataset['train']['answer'][idx],
|
56 |
+
'similarity': similarities[idx]
|
57 |
+
})
|
58 |
|
59 |
return relevant_contexts
|
60 |
|
|
|
63 |
print(f'message is - {message}')
|
64 |
print(f'history is - {history}')
|
65 |
|
66 |
+
# 관련 컨텍스트 찾기
|
67 |
relevant_contexts = find_relevant_context(message)
|
68 |
context_prompt = "\n\n관련 참고 정보:\n"
|
69 |
for ctx in relevant_contexts:
|
70 |
+
context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}\n\n"
|
71 |
|
72 |
# 대화 히스토리 구성
|
73 |
conversation = []
|
|
|
77 |
{"role": "assistant", "content": answer}
|
78 |
])
|
79 |
|
80 |
+
|
81 |
# 컨텍스트를 포함한 최종 프롬프트 구성
|
82 |
final_message = context_prompt + "\n현재 질문: " + message
|
83 |
conversation.append({"role": "user", "content": final_message})
|