import os

import gradio as gr
from transformers import pipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the model and tokenizer | |
# The checkpoint can be swapped without editing this file by setting the
# MODEL_NAME environment variable (for example "google/gemma-2b" or
# "mistralai/Mistral-7B-v0.1"); it defaults to "AdaptLLM/law-LLM".
model_name = os.environ.get("MODEL_NAME", "AdaptLLM/law-LLM")
# use_fast=False requests the slow (pure-Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Load the llama2 LLM model | |
# model = pipeline("text-generation", model="llamalanguage/llama2", tokenizer="llamalanguage/llama2") | |
# model = pipeline("text-generation", model="mistralai/Mistral-7B-v0.1", tokenizer="meta-llama/Llama-2-7b-chat-hf") | |
# Define the chat function that uses the LLM model | |
# def chat_interface(input_text): | |
# response = model(input_text, max_length=100, return_full_text=True)[0]["generated_text"] | |
# response_words = response.split() | |
# return response_words | |
# Define the chat function that uses the model and tokenizer loaded above
def chat_interface(input_text):
    """Generate a text continuation for ``input_text``.

    The prompt is tokenized without special tokens and passed to
    ``model.generate``; only the tokens produced after the prompt are
    decoded and returned.

    Args:
        input_text: Prompt entered by the user. ``None``, empty and
            whitespace-only prompts yield an empty string without the
            model being invoked.

    Returns:
        The decoded continuation with special tokens stripped.
    """
    # With add_special_tokens=False an empty prompt would give generate()
    # nothing to condition on, so short-circuit instead of surfacing an
    # error in the UI.
    if not input_text or not input_text.strip():
        return ""

    encoded = tokenizer(input_text, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded.input_ids.to(model.device)
    # Pass the mask explicitly so generate() does not have to infer it.
    attention_mask = encoded.attention_mask.to(model.device)

    # max_length bounds prompt tokens plus generated tokens together.
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=2048,
    )[0]

    # The returned sequence starts with the prompt; keep only what follows.
    prompt_length = int(input_ids.shape[-1])
    return tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)
# Create the Gradio interface | |
iface = gr.Interface(
    fn=chat_interface,
    # The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
    # removed in Gradio 4; components are exposed at the package top level.
    inputs=gr.Textbox(lines=2, label="Input Text"),
    outputs=gr.Textbox(label="Output Text"),
    title="Chat Interface",
    description="Enter text and get a response using the LLM model",
)

# Launch the interface; share=True additionally requests a public share link.
iface.launch(share=True)