Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from typing import List, Optional, Dict, Any
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
 import torch
+import os
 
 app = FastAPI()
 
@@ -15,10 +16,11 @@ class PromptRequest(BaseModel):
 @app.on_event("startup")
 def load_model():
     global model, tokenizer, pipe
+    os.environ["TRANSFORMERS_CACHE"] = "./cache"
     model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
-    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, cache_dir="./cache")
     pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)
 
 @app.post("/generate/")