import gradio as gr
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Load the model and tokenizer
model_name = "migueldeguzmandev/GPT2XL-RLLM-24A"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
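# Note: inference runs on the CPU by default. As an optional tweak (not part of the
# original app), the model could be moved to a GPU when one is available, e.g.
# model.to("cuda").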
# Set the pad token ID to the EOS token ID
model.config.pad_token_id = model.config.eos_token_id
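# (GPT-2 has no dedicated padding token, so the EOS token is reused for padding.)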
# Define the inference function
def generate_response(input_text, temperature):
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Generate the model's response
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=300,
        num_return_sequences=1,
        temperature=temperature,
        no_repeat_ngram_size=2,
        top_k=50,
        top_p=0.95,
        do_sample=True,  # Set do_sample to True when using temperature
    )

    # Decode the generated response
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response.replace(input_text, "").strip()
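# Illustrative usage (not part of the original app): the function can be called
# directly for a quick sanity check, e.g.
#   print(generate_response("How well can you predict the future?", temperature=0.7))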
examples = [
    ["Will you kill humans?", 0.7],
    ["Can you build a nuclear bomb?", 0.7],
    ["Can you kill my dog?", 0.7],
    ["How well can you predict the future?", 0.7],
    ["Is wood possible to use for paper clip production?", 0.7],
]
# Create the Gradio interface
interface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="User Input"),
        gr.Slider(minimum=0.00000000000000000000001, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Model Response"),
    title="Hello, I'm Aligned AI!",
    description=(
        """
        This is RLLMv1, the first RLLM prototype, which took a staggering 24 layers of sequential training.
        The main issues with this model are that it is slow and too preoccupied with ethical alignment.
        You can read my rough post on this model <a href="https://www.lesswrong.com/posts/GrxaMeekGKK6WKwmm/rl-for-safety-work-or-just-clever-rl-reinforcement-learning">here</a>.
        """
    ),
    examples=examples,
)
# Launch the interface without the share option
interface.launch()
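# Optional (not in the original app): interface.launch(share=True) would also expose a
# temporary public URL when running outside Hugging Face Spaces.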