dongyh20 commited on
Commit
745d2d6
·
1 Parent(s): ae6b5e2

update space

Browse files
Files changed (2) hide show
  1. app.py +151 -60
  2. requirements.txt +27 -1
app.py CHANGED
@@ -1,63 +1,154 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import re
4
+ from decord import VideoReader, cpu
5
+ from PIL import Image
6
+ import numpy as np
7
+ import transformers
8
+ from typing import Dict, Optional, Sequence, List
9
+
10
+ import sys
11
+ from oryx.conversation import conv_templates, SeparatorStyle
12
+ from oryx.model.builder import load_pretrained_model
13
+ from oryx.utils import disable_torch_init
14
+ from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
15
+ from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
16
+
17
+
18
+ model_path = "THUdyh/Oryx-7B"
19
+ model_name = get_model_name_from_path(model_path)
20
+ overwrite_config = {}
21
+ overwrite_config["mm_resampler_type"] = "dynamic_compressor"
22
+ overwrite_config["patchify_video_feature"] = False
23
+ overwrite_config["attn_implementation"] = "sdpa" if torch.__version__ >= "2.1.2" else "eager"
24
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, device_map="cuda:0", overwrite_config=overwrite_config)
25
+ model.to('cuda').eval()
26
+
27
+ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.") -> Dict:
28
+ roles = {"human": "<|im_start|>user", "gpt": "<|im_start|>assistant"}
29
+
30
+ im_start, im_end = tokenizer.additional_special_tokens_ids
31
+ nl_tokens = tokenizer("\n").input_ids
32
+ _system = tokenizer("system").input_ids + nl_tokens
33
+ _user = tokenizer("user").input_ids + nl_tokens
34
+ _assistant = tokenizer("assistant").input_ids + nl_tokens
35
+
36
+ # Apply prompt templates
37
+ input_ids, targets = [], []
38
+
39
+ source = sources
40
+ if roles[source[0]["from"]] != roles["human"]:
41
+ source = source[1:]
42
+
43
+ input_id, target = [], []
44
+ system = [im_start] + _system + tokenizer(system_message).input_ids + [im_end] + nl_tokens
45
+ input_id += system
46
+ target += [im_start] + [IGNORE_INDEX] * (len(system) - 3) + [im_end] + nl_tokens
47
+ assert len(input_id) == len(target)
48
+ for j, sentence in enumerate(source):
49
+ role = roles[sentence["from"]]
50
+ if has_image and sentence["value"] is not None and "<image>" in sentence["value"]:
51
+ num_image = len(re.findall(DEFAULT_IMAGE_TOKEN, sentence["value"]))
52
+ texts = sentence["value"].split('<image>')
53
+ _input_id = tokenizer(role).input_ids + nl_tokens
54
+ for i,text in enumerate(texts):
55
+ _input_id += tokenizer(text).input_ids
56
+ if i<len(texts)-1:
57
+ _input_id += [IMAGE_TOKEN_INDEX] + nl_tokens
58
+ _input_id += [im_end] + nl_tokens
59
+ assert sum([i==IMAGE_TOKEN_INDEX for i in _input_id])==num_image
60
+ else:
61
+ if sentence["value"] is None:
62
+ _input_id = tokenizer(role).input_ids + nl_tokens
63
+ else:
64
+ _input_id = tokenizer(role).input_ids + nl_tokens + tokenizer(sentence["value"]).input_ids + [im_end] + nl_tokens
65
+ input_id += _input_id
66
+ if role == "<|im_start|>user":
67
+ _target = [im_start] + [IGNORE_INDEX] * (len(_input_id) - 3) + [im_end] + nl_tokens
68
+ elif role == "<|im_start|>assistant":
69
+ _target = [im_start] + [IGNORE_INDEX] * len(tokenizer(role).input_ids) + _input_id[len(tokenizer(role).input_ids) + 1 : -2] + [im_end] + nl_tokens
70
+ else:
71
+ raise NotImplementedError
72
+ target += _target
73
+
74
+ input_ids.append(input_id)
75
+ targets.append(target)
76
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
77
+ targets = torch.tensor(targets, dtype=torch.long)
78
+ return input_ids
79
+
80
 
81
+ def oryx_inference(video, text):
82
+ vr = VideoReader(video, ctx=cpu(0))
83
+ total_frame_num = len(vr)
84
+ fps = round(vr.get_avg_fps())
85
+ uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
86
+ frame_idx = uniform_sampled_frames.tolist()
87
+ spare_frames = vr.get_batch(frame_idx).asnumpy()
88
+ video = [Image.fromarray(frame) for frame in spare_frames]
89
+
90
+ conv_mode = "qwen_1_5"
91
+
92
+ question = text
93
+ question = "<image>\n" + question
94
+
95
+ conv = conv_templates[conv_mode].copy()
96
+ conv.append_message(conv.roles[0], question)
97
+ conv.append_message(conv.roles[1], None)
98
+ prompt = conv.get_prompt()
99
+
100
+ input_ids = preprocess_qwen([{'from': 'human','value': question},{'from': 'gpt','value': None}], tokenizer, has_image=True).cuda()
101
+
102
+ video_processed = []
103
+ for idx, frame in enumerate(video):
104
+ image_processor.do_resize = False
105
+ image_processor.do_center_crop = False
106
+ frame = process_anyres_video_genli(frame, image_processor)
107
+
108
+ if frame_idx is not None and idx in frame_idx:
109
+ video_processed.append(frame.unsqueeze(0))
110
+ elif frame_idx is None:
111
+ video_processed.append(frame.unsqueeze(0))
112
+
113
+ if frame_idx is None:
114
+ frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
115
+
116
+ video_processed = torch.cat(video_processed, dim=0).bfloat16().cuda()
117
+ video_processed = (video_processed, video_processed)
118
+
119
+ video_data = (video_processed, (384, 384), "video")
120
+
121
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
122
+ keywords = [stop_str]
123
+
124
+ with torch.inference_mode():
125
+ output_ids = model.generate(
126
+ inputs=input_ids,
127
+ images=video_data[0][0],
128
+ images_highres=video_data[0][1],
129
+ modalities=video_data[2],
130
+ do_sample=False,
131
+ temperature=0,
132
+ max_new_tokens=1024,
133
+ use_cache=True,
134
+ )
135
+
136
+
137
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
138
+ outputs = outputs.strip()
139
+ if outputs.endswith(stop_str):
140
+ outputs = outputs[:-len(stop_str)]
141
+ outputs = outputs.strip()
142
+ return outputs
143
+
144
+ # Define input and output for the Gradio interface
145
+ demo = gr.Interface(
146
+ fn=oryx_inference,
147
+ inputs=[gr.Video(label="Input Video"), gr.Textbox(label="Input Text")],
148
+ outputs="text",
149
+ title="Oryx Inference",
150
+ description="This is a demo for Oryx inference."
151
+ )
152
 
153
+ # Launch the Gradio app
154
+ demo.launch(server_name="0.0.0.0",server_port=80)
requirements.txt CHANGED
@@ -1 +1,27 @@
1
- huggingface_hub==0.22.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ huggingface_hub==0.22.2
2
+ torch
3
+ torchvision,
4
+ transformers==4.39.2
5
+ tokenizers==0.15.2
6
+ sentencepiece==0.1.99
7
+ shortuuid
8
+ accelerate==0.27.2
9
+ peft==0.4.0
10
+ bitsandbytes==0.41.0
11
+ pydantic<2,>=1
12
+ markdown2
13
+ numpy
14
+ scikit-learn==1.2.2
15
+ gradio==3.35.2
16
+ gradio_client==0.2.9
17
+ requests
18
+ httpx==0.24.0
19
+ uvicorn
20
+ fastapi
21
+ einops==0.6.1
22
+ einops-exts==0.0.4
23
+ timm==0.9.16
24
+ decord
25
+ ninja
26
+ deepspeed==0.12.2
27
+ protobuf