Spaces:
Sleeping
Sleeping
Florian
commited on
Commit
·
af6c532
1
Parent(s):
5b2e6a5
remove penalty alpha and just put the model
Browse files
app.py
CHANGED
@@ -25,23 +25,8 @@ def reset():
|
|
25 |
_, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
|
26 |
|
27 |
@st.cache_resource
|
28 |
-
def load_model(
|
29 |
-
|
30 |
-
0.5:"model_20240118-192548.bin",
|
31 |
-
2:"model_20240118-211943.bin",
|
32 |
-
5:"model_20240118-231333.bin",
|
33 |
-
10:"model_20240119-010725.bin",
|
34 |
-
20:"model_20240119-030115.bin",
|
35 |
-
0:"model_20240119-135506.bin",
|
36 |
-
1:"model_20240119-154900.bin",
|
37 |
-
-20: "model_20240208-072350.bin",
|
38 |
-
-10: "model_20240208-052958.bin",
|
39 |
-
-5: "model_20240208-033606.bin",
|
40 |
-
-2: "model_20240208-014211.bin",
|
41 |
-
-1: "model_20240207-234817.bin",
|
42 |
-
-0.5: "model_20240207-215423.bin",
|
43 |
-
-0.1: "model_20240207-200020.bin"}
|
44 |
-
|
45 |
model_str = "susnato/phi-1_5_dev"
|
46 |
model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
|
47 |
tokenizer = AutoTokenizer.from_pretrained(model_str)
|
@@ -49,19 +34,15 @@ def load_model(penalty_alpha):
|
|
49 |
branch_locations = list(range(0, 23, 5))
|
50 |
model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
|
51 |
|
52 |
-
# Load the specific model
|
53 |
-
model_path =
|
54 |
-
if model_path:
|
55 |
-
model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
|
56 |
-
else:
|
57 |
-
print("Invalid penalty_alpha. Using default model weights.")
|
58 |
|
59 |
return model, tokenizer
|
60 |
|
61 |
|
62 |
if "model" not in st.session_state or "tokenizer" not in st.session_state:
|
63 |
print("Loading model...")
|
64 |
-
st.session_state.model, st.session_state.tokenizer = load_model(
|
65 |
st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
|
66 |
print(f"Head number: {st.session_state['head_number']}")
|
67 |
# Session state to store the current sentence
|
|
|
25 |
_, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
|
26 |
|
27 |
@st.cache_resource
|
28 |
+
def load_model(model_path):
|
29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
model_str = "susnato/phi-1_5_dev"
|
31 |
model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
|
32 |
tokenizer = AutoTokenizer.from_pretrained(model_str)
|
|
|
34 |
branch_locations = list(range(0, 23, 5))
|
35 |
model = BranchyModel(branch_locations= branch_locations, model= model).to("cuda:1")
|
36 |
|
37 |
+
# Load the specific model
|
38 |
+
model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
|
|
|
|
|
|
|
|
|
39 |
|
40 |
return model, tokenizer
|
41 |
|
42 |
|
43 |
if "model" not in st.session_state or "tokenizer" not in st.session_state:
|
44 |
print("Loading model...")
|
45 |
+
st.session_state.model, st.session_state.tokenizer = load_model("model/model.bin")
|
46 |
st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
|
47 |
print(f"Head number: {st.session_state['head_number']}")
|
48 |
# Session state to store the current sentence
|