v1
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Phantom
-emoji:
+emoji: 👻
 colorFrom: yellow
 colorTo: purple
 sdk: gradio
app.py CHANGED
@@ -9,9 +9,7 @@ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENT
 import torch
 from PIL import Image
 from utils.utils import *
-import torch.nn.functional as F
 from model.load_model import load_model
-from torchvision.transforms.functional import pil_to_tensor
 
 # Gradio Package
 import time
@@ -49,7 +47,7 @@ def threading_function(inputs, streamer, device, model, tokenizer, temperature,
     generation_kwargs.update({'use_cache': True})
     return model.generate(**generation_kwargs)
 
-
+@spaces.GPU
 def bot_streaming(message, history, link, temperature, new_max_token, top_p):
 
     # model selection
@@ -63,7 +61,7 @@ def bot_streaming(message, history, link, temperature, new_max_token, top_p):
         model = model_7
         tokenizer = tokenizer_7
 
-    # X ->
+    # X -> bfloat16 conversion
     for param in model.parameters():
         if 'float32' in str(param.dtype).lower() or 'float16' in str(param.dtype).lower():
            param.data = param.data.to(torch.bfloat16)