SAE-GPT2-PROD / gradio.py
FFatih's picture
Synchronisation Production
7474d35
raw
history blame
2.58 kB
# -*- coding: utf-8 -*-
"""Untitled3.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1zwLQmMKCQKLMkJ_5Un4C6V4ajs4LYUOR
"""
"""
!pip install --upgrade typing-extensions -q
!pip install -q gradio --upgrade -q
!pip install keras_nlp -q
"""
import os

# Keras reads KERAS_BACKEND once, at import time, so the backend has to be
# chosen BEFORE keras / keras_nlp are imported. Setting it afterwards (as this
# script originally did) is silently ignored.
os.environ["KERAS_BACKEND"] = "tensorflow"  # other Keras backends: "jax", "torch"

from google.colab import drive
from tensorflow import keras
import keras_nlp
import gradio as gr
import random
import time

# Mount Google Drive so the fine-tuning checkpoint below is reachable.
drive.mount('/content/drive')

# Use float16 compute with float32 variables for every layer created after
# this call.
keras.mixed_precision.set_global_policy("mixed_float16")

# Tokenizer/packer for GPT-2 large; inputs are padded or truncated to 256 tokens.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_large_en",
    sequence_length=256,
)

# Pretrained GPT-2 large causal language model wired to the preprocessor above.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_large_en", preprocessor=preprocessor
)

# Overwrite the preset weights with the checkpoint stored on Drive.
# NOTE(review): the path is relative while the mount point is absolute, so this
# presumably relies on the working directory being /content -- confirm.
gpt2_lm.load_weights('./drive/MyDrive/checkpoints/my_checkpoint')
# Custom stylesheet passed to gr.Blocks(css=...).
#
# The chatbot background uses the raw.githubusercontent.com URL of the logo:
# a github.com/.../blob/... URL serves the HTML file-viewer page, not the
# image bytes, so it cannot be used as a CSS background-image.
#
# NOTE(review): the last two selectors (".md svelte-1syupzx chatbot" and
# ".message user svelte-1lcyrx4 message-bubble-border") read as descendant
# selectors on element names rather than class lists, so they likely match
# nothing; they are kept unchanged -- confirm the intended targets.
css = """
.gradio-container {
background-color: transparent;
color: #f5f5dc;
border-color: #d5aa5e;
}
/* Styling for the chatbot */
.chat{
border-color: #d5aa5e;
background-color:#22201f;
background-image: url('https://raw.githubusercontent.com/BastienHot/SAE-GPT2/70fb88500a2cc168d71e8ed635fc54492beb6241/image/logo.png');
background-size: cover;
background-position: center;
}
/* Styling for the user */
.user{
background-color: #957d52;
}
/* Styling for the text inside the chatbot */
.gradio-chatbox .message-container .message-right {
color: #f5f5dc; /* Antique white text color */
border-color: #d5aa5e;
background-color: red;
}
.md svelte-1syupzx chatbot{
border-color: #d5aa5e;
background-color: #3e3836;
}
.message user svelte-1lcyrx4 message-bubble-border {
border-color: #3e3836;
}
"""
def predict(text):
    """Generate text from ``text`` with the module-level ``gpt2_lm`` model.

    The prompt is forwarded unchanged to ``gpt2_lm.generate`` and whatever
    that call returns is handed back to the caller.
    """
    generated = gpt2_lm.generate(text)
    return generated
# Chat UI: a chatbot pane, a text box for the prompt and a button that clears
# both. Submitting the text box runs ``respond``.
with gr.Blocks(css=css) as demo:
    chatbot = gr.Chatbot(elem_classes="chat")
    msg = gr.Textbox(elem_classes="user")
    clear = gr.ClearButton([msg, chatbot])

    def respond(message, chat_history):
        """Append one user/bot exchange to the chat history.

        Args:
            message: Text submitted through the text box.
            chat_history: List of ``(user_html, bot_html)`` tuples currently
                shown in the chatbot; ``None`` is treated as an empty history.

        Returns:
            A ``("", chat_history)`` tuple: the empty string clears the text
            box, and the updated history refreshes the chatbot.
        """
        # Function-scope import so this block does not depend on a change to
        # the file-level import list.
        from html import escape

        if chat_history is None:
            chat_history = []

        bot_message = predict(message)

        # Both strings are wrapped in HTML, so escape them first: the user
        # text is untrusted and the model output may contain markup characters
        # that would otherwise be interpreted instead of displayed.
        bot_message_html = (
            f'<div class="bot-message">{escape(str(bot_message))}</div>'
        )
        user_message_html = (
            f'<div class="user-message">{escape(str(message))}</div>'
        )

        chat_history.append((user_message_html, bot_message_html))
        # No artificial delay here: generation has already completed, so
        # sleeping would only add latency to every reply.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    # debug=True and share=True are forwarded to Gradio's launch() unchanged.
    launch_options = {"debug": True, "share": True}
    demo.launch(**launch_options)