import gradio as gr | |
from transformers import pipeline | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Checkpoint served by the tokenizer-based chat function.
# Other checkpoints previously listed here as drop-in alternatives:
#   "google/gemma-2b", "mistralai/Mistral-7B-v0.1"
model_name = "AdaptLLM/law-LLM"

# The slow (non-fast) tokenizer implementation is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=False,
)

# Causal language model weights for the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name)
# Chat function backed by the module-level `model` and `tokenizer`.
def chat_interface(input_text):
    """Generate a completion for *input_text* and return only the new text.

    Args:
        input_text: Prompt string entered by the user.

    Returns:
        The decoded continuation with the prompt tokens and any special
        tokens removed, or an empty string when the prompt is empty or
        contains only whitespace.
    """
    # No special tokens are added during tokenization below, so an empty
    # prompt would hand generate() zero input ids. Answer with an empty
    # string instead of attempting generation.
    if not input_text or not input_text.strip():
        return ""

    input_ids = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(model.device)

    prompt_length = int(input_ids.shape[-1])
    if prompt_length == 0:
        return ""

    # max_length bounds the prompt and the generated tokens together.
    output_ids = model.generate(input_ids=input_ids, max_length=2048)[0]

    # The returned sequence starts with the prompt ids; decode only what
    # follows them.
    return tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True)
# Pipelines used by the Gradio chat functions.
#
# Both checkpoints are handled as causal (decoder-only) language models
# elsewhere in this file, so they are served with the "text-generation" task;
# "text2text-generation" is the task for encoder-decoder models.
# return_full_text=False limits "generated_text" to the newly generated text
# rather than prompt + completion.
gemma_2b_chatbot = pipeline(
    "text-generation",
    model="google/gemma-2b",
    return_full_text=False,
)

# NOTE(review): "AdaptLLM/law-LLM" is also loaded separately into the
# module-level `model`; consider sharing one instance to avoid holding the
# weights twice.
law_llm_chatbot = pipeline(
    "text-generation",
    model="AdaptLLM/law-LLM",
    return_full_text=False,
)
# Chat function for the Gemma 2B pipeline.
def gemma_2b_chat(input_text):
    """Run *input_text* through `gemma_2b_chatbot` and return its text.

    Returns the "generated_text" field of the first result produced by the
    pipeline.
    """
    results = gemma_2b_chatbot(input_text)
    first_result = results[0]
    return first_result["generated_text"]
# Chat function for the law-LLM pipeline.
def law_llm_chat(input_text):
    """Run *input_text* through `law_llm_chatbot` and return its text.

    Returns the "generated_text" field of the first result produced by the
    pipeline.
    """
    results = law_llm_chatbot(input_text)
    first_result = results[0]
    return first_result["generated_text"]
# Gradio UI for the law-LLM chatbot.
#
# Components are created as gr.Textbox: the gr.inputs / gr.outputs namespaces
# are deprecated and no longer exist in Gradio 4.
law_llm_inputs = gr.Textbox(lines=2, label="User Input")
law_llm_outputs = gr.Textbox(label="Chatbot Response")
law_llm_interface = gr.Interface(
    fn=law_llm_chat,
    inputs=law_llm_inputs,
    outputs=law_llm_outputs,
)

# To serve a different backend, pass `gemma_2b_chat` or `chat_interface` as
# `fn` above; both take a single text input and return a single string.

# share=True asks Gradio to also create a public share link for the app.
law_llm_interface.launch(share=True)