cutechicken commited on
Commit
50ef49c
1 Parent(s): a908cb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -38
app.py CHANGED
@@ -6,8 +6,8 @@ import os
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
9
- from sklearn.metrics.pairwise import cosine_similarity
10
  import numpy as np
 
11
 
12
  # GPU 메모리 관리
13
  torch.cuda.empty_cache()
@@ -29,40 +29,19 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
29
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
30
  print("Wikipedia dataset loaded:", wiki_dataset)
31
 
32
- def get_embeddings(text, model, tokenizer):
33
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
34
- with torch.no_grad():
35
- outputs = model(**inputs)
36
- # hidden states의 평균을 임베딩으로 사용
37
- hidden_states = outputs[0] # 모델의 마지막 레이어 출력
38
- embeddings = hidden_states.mean(dim=1)
39
- return embeddings
40
-
41
- # 데이터셋의 질문들을 임베딩
42
- print("임베딩 생성 시작...")
43
- questions = wiki_dataset['train']['question'][:1000] # 처음 1000개만 사용 (테스트용)
44
- question_embeddings = []
45
- batch_size = 8 # 배치 사이즈 줄임
46
-
47
- for i in range(0, len(questions), batch_size):
48
- batch = questions[i:i+batch_size]
49
- batch_embeddings = get_embeddings(batch, model, tokenizer)
50
- question_embeddings.append(batch_embeddings.cpu())
51
- if i % 100 == 0:
52
- print(f"Processed {i}/{len(questions)} questions")
53
-
54
- question_embeddings = torch.cat(question_embeddings, dim=0)
55
- print("임베딩 생성 완료")
56
 
57
  def find_relevant_context(query, top_k=3):
58
- # 쿼리 임베딩
59
- query_embedding = get_embeddings(query, model, tokenizer)
60
 
61
  # 코사인 유사도 계산
62
- similarities = cosine_similarity(
63
- query_embedding.cpu().numpy(),
64
- question_embeddings.numpy()
65
- )[0]
66
 
67
  # 가장 유사한 질문들의 인덱스
68
  top_indices = np.argsort(similarities)[-top_k:][::-1]
@@ -70,11 +49,12 @@ def find_relevant_context(query, top_k=3):
70
  # 관련 컨텍스트 추출
71
  relevant_contexts = []
72
  for idx in top_indices:
73
- relevant_contexts.append({
74
- 'question': questions[idx],
75
- 'answer': wiki_dataset['train']['answer'][idx],
76
- 'similarity': similarities[idx]
77
- })
 
78
 
79
  return relevant_contexts
80
 
@@ -83,11 +63,11 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
83
  print(f'message is - {message}')
84
  print(f'history is - {history}')
85
 
86
- # RAG: 관련 컨텍스트 찾기
87
  relevant_contexts = find_relevant_context(message)
88
  context_prompt = "\n\n관련 참고 정보:\n"
89
  for ctx in relevant_contexts:
90
- context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n\n"
91
 
92
  # 대화 히스토리 구성
93
  conversation = []
@@ -97,6 +77,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
97
  {"role": "assistant", "content": answer}
98
  ])
99
 
 
100
  # 컨텍스트를 포함한 최종 프롬프트 구성
101
  final_message = context_prompt + "\n현재 질문: " + message
102
  conversation.append({"role": "user", "content": final_message})
 
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
 
9
  import numpy as np
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
 
12
  # GPU 메모리 관리
13
  torch.cuda.empty_cache()
 
29
  wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
30
  print("Wikipedia dataset loaded:", wiki_dataset)
31
 
32
+ # TF-IDF 벡터라이저 초기화 및 학습
33
+ print("TF-IDF 벡터화 시작...")
34
+ questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
35
+ vectorizer = TfidfVectorizer(max_features=1000)
36
+ question_vectors = vectorizer.fit_transform(questions)
37
+ print("TF-IDF 벡터화 완료")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def find_relevant_context(query, top_k=3):
40
+ # 쿼리 벡터화
41
+ query_vector = vectorizer.transform([query])
42
 
43
  # 코사인 유사도 계산
44
+ similarities = (query_vector * question_vectors.T).toarray()[0]
 
 
 
45
 
46
  # 가장 유사한 질문들의 인덱스
47
  top_indices = np.argsort(similarities)[-top_k:][::-1]
 
49
  # 관련 컨텍스트 추출
50
  relevant_contexts = []
51
  for idx in top_indices:
52
+ if similarities[idx] > 0: # 유사도가 0보다 큰 경우만 포함
53
+ relevant_contexts.append({
54
+ 'question': questions[idx],
55
+ 'answer': wiki_dataset['train']['answer'][idx],
56
+ 'similarity': similarities[idx]
57
+ })
58
 
59
  return relevant_contexts
60
 
 
63
  print(f'message is - {message}')
64
  print(f'history is - {history}')
65
 
66
+ # 관련 컨텍스트 찾기
67
  relevant_contexts = find_relevant_context(message)
68
  context_prompt = "\n\n관련 참고 정보:\n"
69
  for ctx in relevant_contexts:
70
+ context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}\n\n"
71
 
72
  # 대화 히스토리 구성
73
  conversation = []
 
77
  {"role": "assistant", "content": answer}
78
  ])
79
 
80
+
81
  # 컨텍스트를 포함한 최종 프롬프트 구성
82
  final_message = context_prompt + "\n현재 질문: " + message
83
  conversation.append({"role": "user", "content": final_message})