emeses committed
Commit 52e2a53 · 1 Parent(s): 62e9f3c

Update space

Files changed (2)
  1. .gitignore +4 -0
  2. app.py +48 -32
.gitignore ADDED
@@ -0,0 +1,4 @@
+__pycache__/
+*.pyc
+.env
+venv/
app.py CHANGED
@@ -1,56 +1,72 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("emeses/lab2_model")
-
-
+# Load model and tokenizer
+base_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
+model = PeftModel.from_pretrained(base_model, "emeses/lab2_model")
+tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
 
 def respond(
     message,
     history: list[tuple[str, str]],
     system_message,
-    max_tokens,
-    temperature,
-    top_p,
+    max_tokens=512,
+    temperature=0.7,
+    top_p=0.9,
 ):
-    # Simpler prompt format
-    prompt = message
-
-    response = ""
     try:
-        # Basic text generation without streaming first
-        response = client.text_generation(
-            prompt,
+        # Format the prompt
+        prompt = f"{system_message}\n\nUser: {message}\nAssistant:"
+
+        # Tokenize input
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+
+        # Generate response
+        outputs = model.generate(
+            inputs.input_ids,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
+            do_sample=True,
+            pad_token_id=tokenizer.pad_token_id,
         )
+
+        # Decode response
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract assistant's response
+        response = response.split("Assistant:")[-1].strip()
+
         return response
     except Exception as e:
         return f"Error: {str(e)}"
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
+# Create Gradio interface
+iface = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
+        gr.Textbox(
+            label="System Message",
+            value="You are a helpful AI assistant.",
+            lines=2  # Better for system prompts
         ),
+        gr.Slider(minimum=1, maximum=1024, value=512, label="Max Tokens"),
+        gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature", step=0.1),
+        gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P", step=0.1),
     ],
+    title="Chat with Fine-tuned LLaMA Model",
+    description="A conversational AI powered by fine-tuned LLaMA 3.2B model",
+    retry_btn="Regenerate",  # Add retry button
+    undo_btn="Delete Last",  # Add undo button
+    clear_btn="Clear Chat"  # Add clear button
 )
 
-
-if __name__ == "__main__":
-    demo.launch()
+# Add examples to help users (optional)
+iface.queue().launch(
+    share=True,
+    server_name="0.0.0.0",
+    server_port=7860,
+    show_error=True  # Better error visibility
+)
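
For reference, the core of what this commit adds (loading the "emeses/lab2_model" LoRA adapter on top of the 4-bit "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" base model and generating one reply) can be tried on its own roughly as below. This is a condensed sketch of the committed code, not part of the commit: it assumes transformers, peft, bitsandbytes and a CUDA-capable GPU are available and that the checkpoints can be downloaded from the Hub; the variable names, the sample prompt, the device move, and the pad-token fallback are illustrative additions over the committed code.

# Condensed sketch of the generation path added in this commit (assumptions noted above).
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-3B-Instruct-bnb-4bit")
model = PeftModel.from_pretrained(base, "emeses/lab2_model")
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-Instruct-bnb-4bit")

# Same prompt layout the Space's respond() builds from the system and user messages.
prompt = "You are a helpful AI assistant.\n\nUser: Hello, what can you do?\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

outputs = model.generate(
    inputs.input_ids.to(model.device),  # move inputs to the model's device (addition over the committed code)
    max_new_tokens=64,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    # Llama tokenizers often define no pad token; fall back to EOS (addition over the committed code).
    pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
)

# Decode and keep only the assistant's part, as respond() does.
print(tokenizer.decode(outputs[0], skip_special_tokens=True).split("Assistant:")[-1].strip())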