Kr08 commited on
Commit
6a968bc
1 Parent(s): b74dff6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -3,32 +3,34 @@ import gradio as gr
3
 
4
  # from airllm import HuggingFaceModelLoader, AutoModelForCausalLM
5
 
6
- from airllm import AutoModel
7
- import mlx.core as mx
8
-
9
- model = AutoModel.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
10
- # model = AutoModel.from_pretrained(model_loader)
11
- MAX_LENGTH = 128
12
 
 
 
 
 
 
 
13
 
14
  @spaces.GPU
15
  def generate_text(input_text):
16
-
17
- input_tokens = model.tokenizer(input_text,
18
- return_tensors="np",
19
- return_attention_mask=False,
20
- truncation=True,
21
- max_length=MAX_LENGTH,
22
- padding=False)
 
 
 
23
 
24
 
25
- output = model.generate(mx.array(input_tokens['input_ids']),
26
- max_new_tokens=20,
27
- use_cache=True,
28
- return_dict_in_generate=True)
29
- # input_ids = model.tokenizer.encode(input_text, return_tensors="np")
30
- # output = model.generate(input_ids, max_length=100)
31
- # return model.tokenizer.decode(output[0])
32
  return output
33
 
34
 
@@ -36,7 +38,7 @@ iface = gr.Interface(
36
  fn=generate_text,
37
  inputs=gr.Textbox(placeholder="Enter prompt..."),
38
  outputs="text",
39
- title="LLaMA 3 70B Text Generation"
40
  )
41
 
42
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
3
 
4
  # from airllm import HuggingFaceModelLoader, AutoModelForCausalLM
5
 
6
+ model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 
 
 
 
 
7
 
8
+ pipeline = transformers.pipeline(
9
+ "text-generation",
10
+ model=model_id,
11
+ model_kwargs={"torch_dtype": torch.bfloat16},
12
+ device_map="auto",
13
+ )
14
 
15
  @spaces.GPU
16
  def generate_text(input_text):
17
+
18
+ output = pipeline(messages,
19
+ max_new_tokens=256,
20
+ )
21
+ # input_tokens = model.tokenizer(input_text,
22
+ # return_tensors="np",
23
+ # return_attention_mask=False,
24
+ # truncation=True,
25
+ # max_length=MAX_LENGTH,
26
+ # padding=False)
27
 
28
 
29
+ # output = model.generate(mx.array(input_tokens['input_ids']),
30
+ # max_new_tokens=20,
31
+ # use_cache=True,
32
+ # return_dict_in_generate=True)
33
+
 
 
34
  return output
35
 
36
 
 
38
  fn=generate_text,
39
  inputs=gr.Textbox(placeholder="Enter prompt..."),
40
  outputs="text",
41
+ title="LLaMA 3 8B Text Generation"
42
  )
43
 
44
  iface.launch(server_name="0.0.0.0", server_port=7860)