# Author: freddyaboulton
# Last commit: "Update app.py" (0b14f95, verified)
import base64
import os
import re

import gradio as gr
import numpy as np
import torch
from gradio_webrtc import WebRTC, ReplyOnPause, AdditionalOutputs
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from twilio.rest import Client
# Twilio credentials are optional; both are read from the environment.
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")

# Default: no explicit RTC configuration is passed to the WebRTC component
# (presumably it then uses its own defaults — confirm with gradio_webrtc docs).
rtc_configuration = None
if account_sid and auth_token:
    # Request an ICE-server token from Twilio and restrict the connection to
    # relay candidates only ("iceTransportPolicy": "relay").
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
# Instruction-tuned chat model that writes the HTML applications.
checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
# Prefer the GPU, but fall back to CPU so the app can still start (slowly)
# on machines without CUDA instead of failing while loading the models.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
# Speech-to-text pipeline used to transcribe the user's spoken request.
whisper = pipeline(model="openai/whisper-large-v3-turbo", device=device)
# System message that seeds every new conversation.
system_prompt = (
    "You are an AI coding assistant. "
    "Your task is to write single-file HTML applications based on a user's request. "
    "You may also be asked to edit your original response. "
    "Only return the code needed to fulfill the request."
)
# Per-turn user message template: {user_message} receives the transcribed
# speech and {code} the code currently shown in the editor.
user_prompt = (
    "Please write a single-file HTML application to fulfill the following request. "
    "Only return the necessary code. Include all necessary imports and styles.\n"
    "The message:{user_message}\n"
    "Current code you have written:{code}"
)
# Opening tag with optional attributes, e.g. <html> or <html lang="en">.
_HTML_OPEN_TAG = re.compile(r"<html\b[^>]*>", re.IGNORECASE)
# Closing tag, tolerating whitespace before the '>'.
_HTML_CLOSE_TAG = re.compile(r"</html\s*>", re.IGNORECASE)


def extract_html_content(text):
    """
    Return the ``<html ...>`` ... ``</html>`` span of *text*, tags included.

    The reply may contain text before or after the document, so only the
    element itself is returned. The opening tag may carry attributes and is
    matched case-insensitively. The closing tag is searched only *after* the
    opening tag, so a stray ``</html>`` earlier in the text does not hide a
    valid document.

    Args:
        text: Raw model output. Non-string input (e.g. ``None``) is tolerated.

    Returns:
        The extracted HTML including both tags, or ``None`` when no complete
        ``<html>`` element is present.
    """
    if not isinstance(text, str):
        return None

    opening = _HTML_OPEN_TAG.search(text)
    if opening is None:
        return None

    closing = _HTML_CLOSE_TAG.search(text, opening.end())
    if closing is None:
        return None

    return text[opening.start():closing.end()]
def display_in_sandbox(code):
    """
    Wrap an HTML document in an ``<iframe>`` for the preview pane.

    The document is base64-encoded into a ``data:`` URI, so the generated
    markup never has to be escaped to fit inside the ``src`` attribute.

    Args:
        code: HTML source to render. ``None`` or ``""`` (for example, when no
            HTML could be extracted from the model reply) renders nothing.

    Returns:
        An ``<iframe>`` HTML string, or ``""`` when there is nothing to show.
    """
    if not code:
        # Guard: ``None.encode`` would raise AttributeError.
        return ""
    encoded_html = base64.b64encode(code.encode("utf-8")).decode("ascii")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
def _transcribe(user_message: tuple[int, np.ndarray]) -> str:
    """Return the Whisper transcription of one ``(sample_rate, samples)`` utterance."""
    sample_rate, samples = user_message
    # The 32768 divisor maps int16-range samples into roughly [-1, 1); this
    # assumes the incoming audio is int16 PCM — confirm with the WebRTC component.
    audio_float32 = samples.astype(np.float32) / 32768.0
    result = whisper({"array": audio_float32.squeeze(), "sampling_rate": sample_rate})
    return result["text"]


def _generate_reply(history: list[dict]) -> str:
    """Run the chat model on *history* and return only the newly generated text."""
    # add_generation_prompt appends the assistant header, so the model answers
    # directly instead of having to emit the header itself.
    prompt = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains its special tokens.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=2048, temperature=0.2, top_p=0.9, do_sample=True
    )
    # Decode only the tokens generated after the prompt. Dropping special
    # tokens keeps markers such as the end-of-turn token out of the history.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def generate(user_message: tuple[int, np.ndarray],
             history: list[dict],
             code: str):
    """
    Handle one spoken turn: transcribe it, ask the LLM, and extract HTML.

    Args:
        user_message: ``(sample_rate, samples)`` audio captured by the WebRTC
            component.
        history: Chat messages (``{"role", "content"}`` dicts). The user turn
            and the assistant reply are appended to this list in place.
        code: HTML currently shown in the editor (may be ``None`` or empty).

    Yields:
        A single ``AdditionalOutputs(history, html_code)``. ``html_code`` falls
        back to the current *code* when the reply contains no complete
        ``<html>`` element, so the editor is not cleared.
    """
    msg_text = _transcribe(user_message)
    print("msg_text", msg_text)

    # Avoid formatting a missing editor value as the literal text "None".
    user_msg_formatted = user_prompt.format(user_message=msg_text, code=code or "")
    print("user_msg_formatted", user_msg_formatted)
    history.append({"role": "user", "content": user_msg_formatted})

    output = _generate_reply(history)
    print("response", output)
    history.append({"role": "assistant", "content": output})

    html_code = extract_html_content(output) or code or ""
    yield AdditionalOutputs(history, html_code)
with gr.Blocks() as demo:
    # Conversation state, seeded with the system prompt.
    history = gr.State([{"role": "system", "content": system_prompt}])

    # Top row: read-only generated code next to its rendered preview.
    with gr.Row():
        code = gr.Code(language="html", interactive=False)
        sandbox = gr.HTML("")

    # Bottom row: audio-only, send-only WebRTC input.
    with gr.Row():
        webrtc = WebRTC(
            rtc_configuration=rtc_configuration,
            mode="send",
            modality="audio",
        )

    # Audio is handed to `generate` through ReplyOnPause; time_limit=90
    # (presumably seconds — confirm with gradio_webrtc docs).
    webrtc.stream(
        ReplyOnPause(generate),
        inputs=[webrtc, history, code],
        outputs=[webrtc],
        time_limit=90,
    )
    # `generate` yields AdditionalOutputs(history, html); forward both values
    # unchanged into the state and the code editor.
    webrtc.on_additional_outputs(
        lambda new_history, new_code: (new_history, new_code),
        outputs=[history, code],
    )
    # Re-render the preview whenever the code changes.
    code.change(display_in_sandbox, code, sandbox, queue=False)


if __name__ == "__main__":
    demo.launch()