chat, stream code

#1 by JUNGU - opened
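
# Dependencies (as implied by the imports below): torch, transformers, gradio.
# A CUDA-capable GPU is required; see the check right after the imports.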
import os
from threading import Thread
from typing import Iterator
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import re

# Check that a CUDA GPU is available
if not torch.cuda.is_available():
    raise RuntimeError("No CUDA GPU found. A GPU is required.")

print(f"Available GPU: {torch.cuda.get_device_name(0)}")
print(f"CUDA version: {torch.version.cuda}")

# Clear the GPU memory cache
torch.cuda.empty_cache()

MAX_INPUT_TOKEN_LENGTH = 8192
DEFAULT_MAX_NEW_TOKENS = 4096

# Load the model and tokenizer
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    "UNIVA-Bllossom/DeepSeek-llama3.1-Bllossom-8B",
    torch_dtype=torch.bfloat16, 
    device_map="auto",
    trust_remote_code=True
)
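
# Optional alternative: on GPUs with less VRAM, the same model could be loaded
# in 4-bit instead of bf16. A sketch, assuming bitsandbytes is installed (kept
# commented out so it does not replace the bf16 load above):
#
# from transformers import BitsAndBytesConfig
# quant_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )
# model = AutoModelForCausalLM.from_pretrained(
#     "UNIVA-Bllossom/DeepSeek-llama3.1-Bllossom-8B",
#     quantization_config=quant_config,
#     device_map="auto",
#     trust_remote_code=True,
# )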

print("토크나이저를 로딩중입니다...")
tokenizer = AutoTokenizer.from_pretrained(
    "UNIVA-Bllossom/DeepSeek-llama3.1-Bllossom-8B",  # 예제랑은 다르게 토크나이저 맞춤 UNIVA-Bllossom/DeepSeek-llama3.3-Bllossom-70B
    trust_remote_code=True
)
tokenizer.use_default_system_prompt = False

system_prompt = '''You are a highly capable assistant. For every user question, follow these instructions exactly:
    1. First, think through the problem step-by-step in English. Enclose all of your internal reasoning between <think> and </think> tags. This chain-of-thought should detail your reasoning process.
    2. After the closing </think> tag, provide your final answer in Korean.
    3. Do not include any additional text or commentary outside of this format.
    4. Your output should strictly follow this structure:

<think>
[Your detailed step-by-step reasoning in English]
</think>
<answer>
[Your final answer in Korean]
</answer>'''
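
# The CSS in the Gradio block below styles <details>/<summary> elements, but
# nothing in the script generates them. A minimal sketch of such a helper
# (hypothetical; not wired into the code below) that could be applied to the
# streamed text in bot() so a completed <think> block renders as a collapsible
# section, assuming the Chatbot lets this HTML through:
def format_thinking(text: str) -> str:
    """Wrap a completed <think>...</think> block in a collapsible element."""
    return re.sub(
        r"<think>(.*?)</think>",
        r"<details><summary>Reasoning</summary>\1</details>",
        text,
        flags=re.DOTALL,
    )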

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 4096,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    try:
        print("\n사용자:", message)
        conversation = []
        conversation.append({"role": "system", "content": system_prompt})
        for user, assistant in chat_history:
            conversation.extend([
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant}
            ])
        conversation.append({"role": "user", "content": message})

        inputs = tokenizer.apply_chat_template(
            conversation,
            return_tensors="pt",
            add_generation_prompt=True
        )
        # Trim the prompt from the left so it stays within the context window.
        if inputs.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            inputs = inputs[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Input trimmed to the last {MAX_INPUT_TOKEN_LENGTH} tokens.")

        print("\nAI 응답:")
        streamer = TextIteratorStreamer(
            tokenizer,
            timeout=120.0,
            skip_prompt=True,
            skip_special_tokens=False
        )
        
        generate_kwargs = dict(
            input_ids=inputs.to(model.device),
            attention_mask=torch.ones_like(inputs).to(model.device),
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

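        # model.generate() blocks until done, so it runs on a background thread
        # while this thread consumes decoded text chunks from the streamer.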
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            if text is None:
                continue
            # Strip end-of-sequence markers ("</s>" plus the Llama-3 end-of-turn
            # tokens, since skip_special_tokens=False keeps them in the stream).
            for stop_token in ("</s>", "<|eot_id|>", "<|end_of_text|>"):
                text = text.replace(stop_token, "")
            outputs.append(text)
            current_response = "".join(outputs)
            print(text, end="", flush=True)  # echo each chunk to the CLI in real time
            yield current_response
        
        print("\n" + "-" * 50)

    except Exception as e:
        error_msg = f"Error in generate: {str(e)}"
        print(error_msg)
        yield "죄송합니다. 응답 생성 중 오류가 발생했습니다. 다시 시도해 주세요."

# Build the Gradio interface
with gr.Blocks(css="""
    .message-wrap {margin-bottom: 10px;}
    details {margin: 10px 0;}
    summary {cursor: pointer; padding: 5px;}
    summary:hover {background-color: #f5f5f5;}
""") as demo:
    gr.Markdown("## DeepSeek Bllossom 챗봇")
    
    chatbot = gr.Chatbot(
        label="DeepSeek Bllossom 챗봇",
        height=600,
        bubble_full_width=False,
        render_markdown=True,
        show_label=False
    )
    
    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="Type a message...",
            lines=2,
            scale=9
        )
        submit = gr.Button("Send", variant="primary", scale=1)
    
    with gr.Accordion("Advanced settings", open=False):
        max_new_tokens = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=8192,
            step=1,
            value=4096
        )
        temperature = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.7
        )
        top_p = gr.Slider(
            label="Top-p",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9
        )
    
    clear = gr.Button("Clear conversation")

    # Event handlers
    def user(message, history):
        return "", history + [[message, None]]

    def bot(history, max_new_tokens, temperature, top_p):
        try:
            message = history[-1][0]
            history[-1][1] = ""
            for content in generate(
                message,
                history[:-1],
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            ):
                history[-1][1] = content
                yield history
        except Exception as e:
            print(f"Error in bot: {str(e)}")
            history[-1][1] = "죄송합니다. 응답 생성 중 오류가 발생했습니다. 다시 시도해 주세요."
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, max_new_tokens, temperature, top_p], chatbot
    )
    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, max_new_tokens, temperature, top_p], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    # Initial greeting
    demo.load(lambda: [[None, "Hello! How can I help you?"]], None, chatbot)

demo.queue(max_size=20).launch(share=True)