farzadab committed on
Commit
5383247
·
verified ·
1 Parent(s): 27da03b

Delete ultravox_pipeline.py

Browse files
Files changed (1) hide show
  1. ultravox_pipeline.py +0 -110
ultravox_pipeline.py DELETED
@@ -1,110 +0,0 @@
1
- import logging
2
- from typing import Any, Dict, List, Optional
3
-
4
- import transformers
5
-
6
- # We must use relative import in this directory to allow uploading to HF Hub
7
- from . import ultravox_model
8
- from . import ultravox_processing
9
-
10
-
11
class UltravoxPipeline(transformers.Pipeline):
    """transformers Pipeline for Ultravox: audio (+ optional chat turns) in, generated text out.

    Wraps an UltravoxModel together with a text tokenizer and an audio
    feature extractor, combined into a single UltravoxProcessor.
    """

    def __init__(
        self,
        model: ultravox_model.UltravoxModel,
        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None,
        audio_processor: Optional[transformers.ProcessorMixin] = None,
        **kwargs
    ):
        # When not supplied, resolve the tokenizer from the model's own
        # name/path and the audio processor from the configured audio model.
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model.config._name_or_path
            )
        if audio_processor is None:
            audio_processor = transformers.Wav2Vec2Processor.from_pretrained(
                model.config.audio_model_id
            )

        # Combined text+audio processor used by preprocess().
        self.processor = ultravox_processing.UltravoxProcessor(
            audio_processor, tokenizer=tokenizer, stack_factor=model.config.stack_factor
        )

        super().__init__(model=model, tokenizer=tokenizer, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        # Route the supported sampling knobs to _forward; anything else is
        # silently ignored, per the transformers Pipeline convention.
        # Returns (preprocess_kwargs, forward_kwargs, postprocess_kwargs).
        supported = ("temperature", "max_new_tokens", "repetition_penalty")
        forward_kwargs = {name: kwargs[name] for name in supported if name in kwargs}
        return {}, forward_kwargs, {}

    def preprocess(self, inputs: Dict[str, Any]):
        # Build a chat transcript: either the caller's explicit turns, or a
        # single user turn synthesized from the prompt (defaulting to just
        # the audio placeholder).
        if "turns" not in inputs:
            prompt = inputs.get("prompt", "<|audio|>")
            if "<|audio|>" not in prompt:
                logging.warning(
                    "Prompt does not contain '<|audio|>', appending '<|audio|>' to the end of the prompt."
                )
                prompt += " <|audio|>"
            turns = [{"role": "user", "content": prompt}]
        else:
            turns = inputs["turns"]

        text = self.processor.tokenizer.apply_chat_template(turns, tokenize=False)

        # TODO: allow text-only mode?
        assert "audio" in inputs, "Audio input is required"

        if "sampling_rate" not in inputs:
            logging.warning(
                "No sampling rate provided, using default of 16kHz. We highly recommend providing the correct sampling rate."
            )

        return self.processor(
            text=text,
            audio=inputs["audio"],
            sampling_rate=inputs.get("sampling_rate", 16000),
        )

    def _forward(
        self,
        model_inputs: Dict[str, Any],
        temperature: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: float = 1.1,
    ) -> List[int]:
        # A falsy temperature (0 / 0.0 / None) means greedy decoding.
        if not temperature:
            temperature = None
        do_sample = temperature is not None

        # Stop on EOS, and additionally on <|eot_id|> for Llama-3-style
        # tokenizers that define it.
        stop_ids = [self.tokenizer.eos_token_id]
        if "<|eot_id|>" in self.tokenizer.added_tokens_encoder:
            stop_ids.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>"))

        prompt_len = model_inputs["input_ids"].shape[1]

        generated = self.model.generate(
            **model_inputs,
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            eos_token_id=stop_ids
        )
        # Strip the prompt tokens; return only the newly generated ids.
        return generated[0][prompt_len:]

    def postprocess(self, model_outputs) -> str:
        # Decode generated token ids to plain text, dropping special tokens.
        return self.tokenizer.decode(model_outputs, skip_special_tokens=True)
102
-
103
-
104
# Register under the "ultravox-pipeline" task name so that
# transformers.pipeline("ultravox-pipeline", ...) resolves to UltravoxPipeline,
# with fixie-ai/ultravox-v0.2 (branch "main") as the default PyTorch checkpoint.
transformers.pipelines.PIPELINE_REGISTRY.register_pipeline(
    "ultravox-pipeline",
    pipeline_class=UltravoxPipeline,
    pt_model=ultravox_model.UltravoxModel,
    default={"pt": ("fixie-ai/ultravox-v0.2", "main")},
    type="multimodal",
)