nicolaus625 committed on
Commit
9df164c
1 Parent(s): 6d48ca2

update readme.md with one sample inference

Browse files
Files changed (1) hide show
  1. README.md +35 -39
README.md CHANGED
@@ -42,27 +42,37 @@ class StoppingCriteriaSub(StoppingCriteria):
42
  return True
43
  return False
44
 
45
- def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
46
- repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
47
- audio = samples["audio"].cuda()
48
- audio_embeds, atts_audio = self.encode_audio(audio)
49
- if 'instruction_input' in samples: # instruction dataset
50
- #print('Instruction Batch')
51
- instruction_prompt = []
52
- for instruction in samples['instruction_input']:
53
- prompt = '<Audio><AudioHere></Audio> ' + instruction
54
- instruction_prompt.append(self.prompt_template.format(prompt))
55
- audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
56
- self.llama_tokenizer.padding_side = "right"
 
 
 
 
 
 
 
 
 
 
57
  batch_size = audio_embeds.shape[0]
58
  bos = torch.ones([batch_size, 1],
59
  dtype=torch.long,
60
- device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
61
- bos_embeds = self.llama_model.model.embed_tokens(bos)
62
- atts_bos = atts_audio[:, :1]
63
  inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
64
- attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
65
- outputs = self.llama_model.generate(
66
  inputs_embeds=inputs_embeds,
67
  max_new_tokens=max_new_tokens,
68
  stopping_criteria=stopping,
@@ -79,34 +89,20 @@ def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=
79
  output_token = output_token[1:]
80
  if output_token[0] == 1: # if there is a start token <s> at the beginning. remove it
81
  output_token = output_token[1:]
82
- output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
83
  output_text = output_text.split('###')[0] # remove the stop sign '###'
84
  output_text = output_text.split('Assistant:')[-1].strip()
85
  return output_text
86
 
87
- processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
88
- ds = CMIDataset(processor, 'path/to/MI_dataset', 'test', question_type='short')
89
- dl = DataLoader(
90
- ds,
91
- batch_size=1,
92
- num_workers=0,
93
- pin_memory=True,
94
- shuffle=False,
95
- drop_last=True,
96
- collate_fn=ds.collater
97
- )
98
 
 
 
99
  stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
100
- torch.tensor([2277, 29937]).cuda()])])
101
-
102
- from transformers import AutoModel
103
- model_short = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1")
104
-
105
- for idx, sample in tqdm(enumerate(dl)):
106
- ans = answer(Musilingo_short.model, sample, stopping, length_penalty=100, temperature=0.1)
107
- txt = sample['text_input'][0]
108
- print(txt)
109
- print(and)
110
  ```
111
 
112
  # Citing This Work
 
42
  return True
43
  return False
44
 
45
+ def get_musilingo_pred(model, text, audio_path, stopping, length_penalty=1, temperature=0.1,
46
+ max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5, repetition_penalty=1.0):
47
+
48
+ # see https://huggingface.co/m-a-p/MusiLingo-musicqa-v1 for load_audio function definition
49
+ audio = load_audio(audio_path, target_sr=24000,
50
+ is_mono=True,
51
+ is_normalize=False,
52
+ crop_to_length_in_sample_points=int(30*16000)+1,
53
+ crop_randomly=True,
54
+ pad=False).cuda()
55
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
56
+ audio = processor(audio,
57
+ sampling_rate=24000,
58
+ return_tensors="pt")['input_values'][0].cuda()
59
+
60
+ audio_embeds, atts_audio = model.encode_audio(audio)
61
+
62
+ prompt = '<Audio><AudioHere></Audio> ' + text
63
+ instruction_prompt = [model.prompt_template.format(prompt)]
64
+ audio_embeds, atts_audio = model.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
65
+
66
+ model.llama_tokenizer.padding_side = "right"
67
  batch_size = audio_embeds.shape[0]
68
  bos = torch.ones([batch_size, 1],
69
  dtype=torch.long,
70
+ device=torch.device('cuda')) * model.llama_tokenizer.bos_token_id
71
+ bos_embeds = model.llama_model.model.embed_tokens(bos)
72
+ # atts_bos = atts_audio[:, :1]
73
  inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
74
+ # attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
75
+ outputs = model.llama_model.generate(
76
  inputs_embeds=inputs_embeds,
77
  max_new_tokens=max_new_tokens,
78
  stopping_criteria=stopping,
 
89
  output_token = output_token[1:]
90
  if output_token[0] == 1: # if there is a start token <s> at the beginning. remove it
91
  output_token = output_token[1:]
92
+ output_text = model.llama_tokenizer.decode(output_token, add_special_tokens=False)
93
  output_text = output_text.split('###')[0] # remove the stop sign '###'
94
  output_text = output_text.split('Assistant:')[-1].strip()
95
  return output_text
96
 
97
+ musilingo = AutoModel.from_pretrained("m-a-p/MusiLingo-short-v1", trust_remote_code=True)
98
+ musilingo.to("cuda")
99
+ musilingo.eval()
 
 
 
 
 
 
 
 
100
 
101
+ prompt = "this is the task instruction and input question for MusiLingo model"
102
+ audio = "/path/to/the/audio"
103
  stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
104
+ torch.tensor([2277, 29937]).cuda()])])
105
+ response = get_musilingo_pred(musilingo.model, prompt, audio, stopping, length_penalty=100, temperature=0.1)
 
 
 
 
 
 
 
 
106
  ```
107
 
108
  # Citing This Work