Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""Untitled3.ipynb | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/drive/1zwLQmMKCQKLMkJ_5Un4C6V4ajs4LYUOR | |
""" | |
""" | |
!pip install --upgrade typing-extensions -q | |
!pip install -q gradio --upgrade -q | |
!pip install keras_nlp -q | |
""" | |
import os

# Keras reads KERAS_BACKEND once, when it is first imported. The variable must
# therefore be set *before* the tensorflow / keras_nlp imports below; setting
# it afterwards has no effect. Valid values are "tensorflow", "jax" and "torch".
# This script imports Keras through TensorFlow, so "tensorflow" is the
# consistent choice.
os.environ["KERAS_BACKEND"] = "tensorflow"

from tensorflow import keras
import keras_nlp
import gradio as gr
import random
import time

# The checkpoint lives on Google Drive, which is only mountable from Colab.
# `google.colab` does not exist elsewhere (e.g. a hosted Space), and importing
# it unconditionally aborts the script at startup. So mount only when available.
try:
    from google.colab import drive
except ImportError:
    drive = None
else:
    drive.mount('/content/drive')

# Location of the checkpoint weights. The default is the original Colab path.
# It is relative, so it assumes the working directory is /content (Colab's
# usual default); set GPT2_CHECKPOINT_PATH to load the weights from elsewhere.
CHECKPOINT_PATH = os.environ.get(
    "GPT2_CHECKPOINT_PATH", "./drive/MyDrive/checkpoints/my_checkpoint"
)

# Mixed precision: float16 computation with float32 variables.
keras.mixed_precision.set_global_policy("mixed_float16")

# Tokenizes prompts to a fixed sequence length of 256 tokens.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_large_en",
    sequence_length=256,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_large_en", preprocessor=preprocessor
)
# Replace the preset GPT-2 Large weights with the saved checkpoint
# (presumably a fine-tuned model; confirm against the training notebook).
gpt2_lm.load_weights(CHECKPOINT_PATH)
# Custom stylesheet passed to gr.Blocks(css=...). The "chat" and "user"
# classes are attached to the Chatbot and Textbox through elem_classes.
# The logo must be served from raw.githubusercontent.com: a github.com
# ".../blob/..." URL returns an HTML page, not the PNG, so the background
# never rendered.
# The last two rules target elements carrying *all* of the listed classes,
# hence the compound ".a.b.c" form. The space-separated form previously used
# means "descendant elements named b and c" and matched nothing.
# NOTE(review): svelte-* class hashes are produced by Gradio's build and
# change between Gradio versions; re-check them after upgrading Gradio.
css = """
.gradio-container {
background-color: transparent;
color: #f5f5dc;
border-color: #d5aa5e;
}
/* Styling for the chatbot */
.chat{
border-color: #d5aa5e;
background-color:#22201f;
background-image: url('https://raw.githubusercontent.com/BastienHot/SAE-GPT2/70fb88500a2cc168d71e8ed635fc54492beb6241/image/logo.png');
background-size: cover;
background-position: center;
}
/* Styling for the user */
.user{
background-color: #957d52;
}
/* Styling for the text inside the chatbot */
.gradio-chatbox .message-container .message-right {
color: #f5f5dc; /* Antique white text color */
border-color: #d5aa5e;
background-color: red;
}
.md.svelte-1syupzx.chatbot{
border-color: #d5aa5e;
background-color: #3e3836;
}
.message.user.svelte-1lcyrx4.message-bubble-border {
border-color: #3e3836;
}
"""
def predict(text, max_length=None):
    """Run the loaded GPT-2 model on *text* and return the generated text.

    This is a real model call on the module-level ``gpt2_lm``, not a stub.

    Args:
        text: Prompt string typed by the user.
        max_length: Optional cap on the total generated sequence length in
            tokens. ``None`` (the default) keeps keras_nlp's own default,
            which is the preprocessor's configured ``sequence_length``.

    Returns:
        The result of ``gpt2_lm.generate`` for the prompt. For keras_nlp
        causal LMs this is the decoded text, which includes the prompt
        followed by the continuation.
    """
    return gpt2_lm.generate(text, max_length=max_length)
def _as_html(css_class, text):
    """Return *text* HTML-escaped and wrapped in a ``<div class="css_class">``.

    Both the user's message and the model's output are untrusted. Escaping
    them keeps either from injecting markup into the rendered transcript.
    """
    import html

    return f'<div class="{css_class}">{html.escape(str(text))}</div>'


with gr.Blocks(css=css) as demo:
    chatbot = gr.Chatbot(elem_classes="chat")
    msg = gr.Textbox(elem_classes="user")
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Generate a reply to *message* and append the turn to the transcript.

        Args:
            message: Text submitted from the textbox.
            chat_history: Current Chatbot value, a list of (user, bot) pairs;
                treated as empty if Gradio passes a falsy value.

        Returns:
            ``("", chat_history)``. The empty string clears the textbox, and
            the (possibly extended) history is written back to the Chatbot.
        """
        chat_history = chat_history or []
        if not message or not message.strip():
            # Nothing to send: skip the expensive model call, keep transcript.
            return "", chat_history

        bot_message = predict(message)
        # NOTE(review): no rule in `css` targets "user-message" or
        # "bot-message" yet; these classes are hooks for future styling.
        chat_history.append(
            (
                _as_html("user-message", message),
                _as_html("bot-message", bot_message),
            )
        )
        # The former time.sleep(2) (an artificial delay copied from a Gradio
        # demo) was removed: the model call is already the slow part.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    # debug=True blocks the main thread and surfaces errors in the console;
    # share=True additionally requests a public share link.
    launch_options = dict(debug=True, share=True)
    demo.launch(**launch_options)