import logging

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)


class CodeGenerator:
    """Generate text from a prompt using a Hugging Face seq2seq checkpoint."""

    def __init__(self, model_name="Salesforce/codet5-base", device=None):
        """Load the tokenizer and model for *model_name*.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            device: Optional torch device (e.g. ``"cuda"`` / ``"cpu"``). When
                falsy, the model is left where ``from_pretrained`` put it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        if device:
            self.model = self.model.to(device)

    def generate_code(self, prompt, max_length=100):
        """Return the model's decoded output for *prompt*.

        Args:
            prompt: Input text for the model.
            max_length: Maximum length passed to ``model.generate``.

        Returns:
            The decoded first generated sequence, or an
            ``"Error generating code: ..."`` string if anything fails. Errors
            are returned rather than raised so the chat UI can show them.
        """
        try:
            # Calling the tokenizer (instead of ``encode``) also yields an
            # attention mask. The tensors must live on the same device as the
            # model, otherwise ``generate`` fails when the model is on CUDA.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            output = self.model.generate(
                **inputs, max_length=max_length, num_return_sequences=1
            )
            return self.tokenizer.decode(output[0], skip_special_tokens=True)
        except Exception as e:  # UI boundary: report instead of crashing the app
            logger.exception("Code generation failed")
            return f"Error generating code: {str(e)}"


class ChatHandler:
    """Keep a conversation history and route messages to a ``CodeGenerator``.

    NOTE(review): ``create_gradio_interface`` builds a single instance, so this
    history is shared by every session connected to the server. Per-user
    history would need ``gr.State``.
    """

    def __init__(self, code_generator):
        self.history = []  # list of (user_message, response) tuples
        self.code_generator = code_generator  # object exposing generate_code(prompt)

    def handle_message(self, message):
        """Generate a reply to *message* and append the exchange to the history.

        Returns:
            ``("", history)`` — the empty string clears the input textbox.
            Empty, whitespace-only, or ``None`` messages leave history unchanged.
        """
        if not message or not message.strip():
            return "", self.history
        response = self.code_generator.generate_code(message)
        self.history.append((message, response))
        return "", self.history

    def clear_history(self):
        """Drop all stored exchanges and return the new (empty) history."""
        self.history = []
        return []


def create_gradio_interface():
    """Build the chat UI around a ``CodeGenerator`` and launch it (blocking)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    code_generator = CodeGenerator(device=device)
    chat_handler = ChatHandler(code_generator)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# S-Dreamer Salesforce/codet5-base Chat Interface")
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=400)
                message_input = gr.Textbox(
                    label="Enter your code-related query",
                    placeholder="Type your message here...",
                )
                submit_button = gr.Button("Submit")
            with gr.Column(scale=1):
                gr.Markdown("## Features")
                features = [
                    "Code generation",
                    "Code completion",
                    "Code explanation",
                    "Error correction",
                ]
                for feature in features:
                    gr.Markdown(f"- {feature}")
                clear_button = gr.Button("Clear Chat")

        # Both the button and pressing Enter in the textbox send the message.
        submit_button.click(
            chat_handler.handle_message,
            inputs=message_input,
            outputs=[message_input, chatbot],
        )
        message_input.submit(
            chat_handler.handle_message,
            inputs=message_input,
            outputs=[message_input, chatbot],
        )
        # Reset the textbox to "" (consistent with handle_message) and empty the chat.
        clear_button.click(
            lambda: ("", chat_handler.clear_history()),
            inputs=[],
            outputs=[message_input, chatbot],
        )

    demo.launch()


if __name__ == "__main__":
    create_gradio_interface()