nicolaus625 committed on
Commit
1cb4787
·
verified ·
1 Parent(s): 3dabb20

update readme.md with one sample inference

Browse files
Files changed (1) hide show
  1. README.md +47 -39
README.md CHANGED
@@ -43,27 +43,48 @@ class StoppingCriteriaSub(StoppingCriteria):
43
  return True
44
  return False
45
 
46
- def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
47
- repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
48
- audio = samples["audio"].cuda()
49
- audio_embeds, atts_audio = self.encode_audio(audio)
50
- if 'instruction_input' in samples: # instruction dataset
51
- #print('Instruction Batch')
52
- instruction_prompt = []
53
- for instruction in samples['instruction_input']:
54
- prompt = '<Audio><AudioHere></Audio> ' + instruction
55
- instruction_prompt.append(self.prompt_template.format(prompt))
56
- audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
57
- self.llama_tokenizer.padding_side = "right"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  batch_size = audio_embeds.shape[0]
59
  bos = torch.ones([batch_size, 1],
60
  dtype=torch.long,
61
- device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
62
- bos_embeds = self.llama_model.model.embed_tokens(bos)
63
- atts_bos = atts_audio[:, :1]
64
  inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
65
- attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
66
- outputs = self.llama_model.generate(
67
  inputs_embeds=inputs_embeds,
68
  max_new_tokens=max_new_tokens,
69
  stopping_criteria=stopping,
@@ -80,34 +101,21 @@ def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=
80
  output_token = output_token[1:]
81
  if output_token[0] == 1: # if there is a start token <s> at the beginning. remove it
82
  output_token = output_token[1:]
83
- output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
84
  output_text = output_text.split('###')[0] # remove the stop sign '###'
85
  output_text = output_text.split('Assistant:')[-1].strip()
86
  return output_text
87
 
88
- processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
89
- ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='long')
90
- dl = DataLoader(
91
- ds,
92
- batch_size=1,
93
- num_workers=0,
94
- pin_memory=True,
95
- shuffle=False,
96
- drop_last=True,
97
- collate_fn=ds.collater
98
- )
99
 
 
 
100
  stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
101
- torch.tensor([2277, 29937]).cuda()])])
102
-
103
- from transformers import AutoModel
104
- model_long = AutoModel.from_pretrained("m-a-p/MusiLingo-long-v1")
105
-
106
- for idx, sample in tqdm(enumerate(dl)):
107
- ans = answer(Musilingo_long.model, sample, stopping, length_penalty=100, temperature=0.1)
108
- txt = sample['text_input'][0]
109
- print(txt)
110
- print(and)
111
  ```
112
 
113
  # Citing This Work
 
43
  return True
44
  return False
45
 
46
+
47
+ class StoppingCriteriaSub(StoppingCriteria):
48
+ def __init__(self, stops=[], encounters=1):
49
+ super().__init__()
50
+ self.stops = stops
51
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
52
+ for stop in self.stops:
53
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
54
+ return True
55
+ return False
56
+
57
+ def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
58
+ max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
59
+
60
+ # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
61
+ audio = load_audio(audio_path, target_sr=24000,
62
+ is_mono=True,
63
+ is_normalize=False,
64
+ crop_to_length_in_sample_points=int(30*16000)+1,
65
+ crop_randomly=True,
66
+ pad=False).cuda()
67
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
68
+ audio = processor(audio,
69
+ sampling_rate=24000,
70
+ return_tensors="pt")['input_values'][0].cuda()
71
+
72
+ audio_embeds, atts_audio = model.encode_audio(audio)
73
+
74
+ prompt = '<Audio><AudioHere></Audio> ' + text
75
+ instruction_prompt = [model.prompt_template.format(prompt)]
76
+ audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
77
+
78
+ model.llama_tokenizer.padding_side = "right"
79
  batch_size = audio_embeds.shape[0]
80
  bos = torch.ones([batch_size, 1],
81
  dtype=torch.long,
82
+ device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
83
+ bos_embeds = model.llama_model.model.embed_tokens(bos)
84
+ # atts_bos = atts_audio[:, :1]
85
  inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
86
+ # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
87
+ outputs = model.llama_model.generate(
88
  inputs_embeds=inputs_embeds,
89
  max_new_tokens=max_new_tokens,
90
  stopping_criteria=stopping,
 
101
  output_token = output_token[1:]
102
  if output_token[0] == 1: # if there is a start token <s> at the beginning. remove it
103
  output_token = output_token[1:]
104
+ output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
105
  output_text = output_text.split('###')[0] # remove the stop sign '###'
106
  output_text = output_text.split('Assistant:')[-1].strip()
107
  return output_text
108
 
109
+ musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-long-v1", trust_remote_code=True)
110
+ musilingo.to("cuda")
111
+ musilingo.eval()
 
 
 
 
 
 
 
 
112
 
113
+ prompt = "this is the task instruction and input question for MusiLingo model"
114
+ audio_path = "/path/to/the/audio"
115
  stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
116
+ torch.tensor([2277, 29937]).cuda()])])
117
+ response = get_musilingo_pred(musilingo.model, prompt, audio_path, stopping, length_penalty=100, temperature=0.1)
118
+
 
 
 
 
 
 
 
119
  ```
120
 
121
  # Citing This Work