hyx21 commited on
Commit
0576766
·
verified ·
1 Parent(s): d09e0b9

Upload modeling_minicpm.py

Browse files
Files changed (1) hide show
  1. modeling_minicpm.py +28 -1
modeling_minicpm.py CHANGED
@@ -20,7 +20,7 @@
20
  """ PyTorch MiniCPM model."""
21
  import math
22
  import warnings
23
- from typing import List, Optional, Tuple, Union
24
 
25
  import torch
26
  import torch.nn.functional as F
@@ -49,6 +49,7 @@ from transformers.utils import (
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_minicpm import MiniCPMConfig
 
52
 
53
 
54
  if is_flash_attn_2_available():
@@ -1302,6 +1303,32 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
1302
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1303
  )
1304
  return reordered_past
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1305
 
1306
 
1307
  @add_start_docstrings(
 
20
  """ PyTorch MiniCPM model."""
21
  import math
22
  import warnings
23
+ from typing import List, Optional, Tuple, Union, Dict
24
 
25
  import torch
26
  import torch.nn.functional as F
 
49
  )
50
  from transformers.utils.import_utils import is_torch_fx_available
51
  from .configuration_minicpm import MiniCPMConfig
52
+ import re
53
 
54
 
55
  if is_flash_attn_2_available():
 
1303
  tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1304
  )
1305
  return reordered_past
1306
+
1307
+ @torch.inference_mode()
1308
+ def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1309
+ max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None,
1310
+ **kwargs):
1311
+ if history is None:
1312
+ history = []
1313
+ if logits_processor:
1314
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1315
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1316
+ else:
1317
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1318
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1319
+
1320
+ history.append({"role": role, "content": query})
1321
+ history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False)
1322
+ inputs = tokenizer(history_str, return_tensors='pt').to(self.device)
1323
+ outputs = self.generate(**inputs, **gen_kwargs)
1324
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1325
+ response = tokenizer.decode(outputs)
1326
+ pattern = re.compile(r".*?(?=<AI>|<用户>)", re.DOTALL)
1327
+ matches = pattern.findall(response)
1328
+ if len(matches) > 0:
1329
+ response = matches[0]
1330
+ history.append({"role": "assistant", "content": response})
1331
+ return response, history
1332
 
1333
 
1334
  @add_start_docstrings(