Update app.py
Browse files
app.py
CHANGED
@@ -57,8 +57,6 @@ def generate_interactive(
|
|
57 |
):
|
58 |
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
|
59 |
input_length = len(inputs['input_ids'][0])
|
60 |
-
for k, v in inputs.items():
|
61 |
-
inputs[k] = v.cuda()
|
62 |
input_ids = inputs['input_ids']
|
63 |
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
64 |
if generation_config is None:
|
@@ -184,7 +182,7 @@ def load_model():
|
|
184 |
model_dir = 'internlm/internlm2-chat-1_8b'
|
185 |
model = (AutoModelForCausalLM.from_pretrained(
|
186 |
model_dir,
|
187 |
-
trust_remote_code=True).to(torch.bfloat16)
|
188 |
tokenizer = AutoTokenizer.from_pretrained(
|
189 |
model_dir,
|
190 |
trust_remote_code=True)
|
@@ -232,7 +230,6 @@ def combine_history(prompt):
|
|
232 |
|
233 |
|
234 |
def main():
|
235 |
-
# torch.cuda.empty_cache()
|
236 |
print('load model begin.')
|
237 |
model, tokenizer = load_model()
|
238 |
print('load model end.')
|
@@ -278,7 +275,7 @@ def main():
|
|
278 |
'role': 'robot',
|
279 |
'content': cur_response, # pylint: disable=undefined-loop-variable
|
280 |
})
|
281 |
-
|
282 |
|
283 |
|
284 |
if __name__ == '__main__':
|
|
|
57 |
):
|
58 |
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
|
59 |
input_length = len(inputs['input_ids'][0])
|
|
|
|
|
60 |
input_ids = inputs['input_ids']
|
61 |
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
62 |
if generation_config is None:
|
|
|
182 |
model_dir = 'internlm/internlm2-chat-1_8b'
|
183 |
model = (AutoModelForCausalLM.from_pretrained(
|
184 |
model_dir,
|
185 |
+
trust_remote_code=True).to(torch.bfloat16))
|
186 |
tokenizer = AutoTokenizer.from_pretrained(
|
187 |
model_dir,
|
188 |
trust_remote_code=True)
|
|
|
230 |
|
231 |
|
232 |
def main():
|
|
|
233 |
print('load model begin.')
|
234 |
model, tokenizer = load_model()
|
235 |
print('load model end.')
|
|
|
275 |
'role': 'robot',
|
276 |
'content': cur_response, # pylint: disable=undefined-loop-variable
|
277 |
})
|
278 |
+
|
279 |
|
280 |
|
281 |
if __name__ == '__main__':
|