Do0rMaMu committed on
Commit
f00a373
1 Parent(s): bba0a04

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +3 -1
main.py CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
3
  from typing import List, Optional, Dict, Any
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
5
  import torch
 
6
 
7
  app = FastAPI()
8
 
@@ -15,10 +16,11 @@ class PromptRequest(BaseModel):
15
  @app.on_event("startup")
16
  def load_model():
17
  global model, tokenizer, pipe
 
18
  model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
19
  tokenizer = AutoTokenizer.from_pretrained(model_path)
20
  streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
21
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
22
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)
23
 
24
  @app.post("/generate/")
 
3
  from typing import List, Optional, Dict, Any
4
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
5
  import torch
6
+ import os
7
 
8
  app = FastAPI()
9
 
 
16
@app.on_event("startup")
def load_model():
    """Load the Llama-3.2-3B-Instruct tokenizer, model, and text-generation
    pipeline once at application startup.

    Populates the module-level globals ``model``, ``tokenizer`` and ``pipe``
    that the request handlers use. No return value.
    """
    global model, tokenizer, pipe
    # NOTE(review): some transformers versions read TRANSFORMERS_CACHE at
    # import time, so setting it here may be too late for anything other
    # than this load; the explicit cache_dir= below covers the model fetch.
    os.environ["TRANSFORMERS_CACHE"] = "./cache"
    model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # skip_prompt=True streams only the generated continuation, not the prompt.
    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
    # Fix: this commit dropped device_map="auto" from the previous version,
    # leaving the model on CPU where float16 generation is slow or unsupported.
    # Restore automatic device placement (requires `accelerate`, which the
    # previous working version already relied on).
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        cache_dir="./cache",
        device_map="auto",
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)
25
 
26
  @app.post("/generate/")