import spaces
import gradio as gr
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3.1-8B"

# Build the pipeline once at import time. On ZeroGPU Spaces the model is
# loaded here on CPU; @spaces.GPU belongs on the function that actually
# runs inference, not on a loader called at startup.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

# `seed` is not a valid generate() kwarg and would raise a ValueError when
# passed to the pipeline call, so seed globally instead.
transformers.set_seed(1337)


@spaces.GPU(duration=60)
def generate_response(chat, kwargs):
    output = pipeline(chat, **kwargs)[0]["generated_text"]
    # Strip a trailing end-of-sequence token if the model emitted one
    # (the original check compared against an empty string, which is
    # always true; a 4-character token such as "</s>" was clearly meant).
    if output.endswith("</s>"):
        output = output[:-4]
    return output


def function(prompt, history=None):
    # Note: the [INST] ... [/INST] template is the Llama-2/Mistral chat
    # convention; Meta-Llama-3.1-8B is a base model, so here it acts only
    # as a plain-text turn separator.
    chat = ""
    for user_prompt, bot_response in history or []:
        chat += f"[INST] {user_prompt} [/INST] {bot_response} "
    chat += f"[INST] {prompt} [/INST]"
    kwargs = dict(
        max_new_tokens=4096,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        repetition_penalty=1.0,
        return_full_text=False,  # return only the completion, not the prompt
    )
    try:
        return generate_response(chat, kwargs)
    except Exception:
        return ""


# Gradio interface
interface = gr.ChatInterface(
    fn=function,
    chatbot=gr.Chatbot(
        avatar_images=None,
        container=False,
        show_copy_button=True,
        layout="bubble",
        render_markdown=True,
        line_breaks=True,
    ),
    css="h1 {font-size:22px;} h2 {font-size:20px;} h3 {font-size:18px;} h4 {font-size:16px;}",
    autofocus=True,
    fill_height=True,
    analytics_enabled=False,
    submit_btn="Chat",
    stop_btn=None,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)


# API endpoint: a plain text-in/text-out wrapper around the chat function
def api_predict(prompt):
    return function(prompt)


api_interface = gr.Interface(fn=api_predict, inputs="text", outputs="text")

# Serve both interfaces from a single launch. launch() blocks, so the
# second launch() call in the original script would never have run.
gr.TabbedInterface([interface, api_interface], ["Chat", "API"]).launch(
    show_api=True, share=True
)
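
# Usage sketch (assumption: the `gradio_client` package is installed and the
# app is reachable at the URL printed by launch()). Exact route names depend
# on how the interfaces are mounted; client.view_api() lists them:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")  # or the share URL
#   client.view_api()                          # list available endpoints
#   print(client.predict("Hello!", api_name="/chat"))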