import torch

from config import Config
from networks import peft_model  # note: immediately shadowed by Config.peft_model below

tokenizer = Config.tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_tokens('<question-answer>')

# Shared model components, loaded once from the project-wide Config.
peft_model, audio_model = Config.peft_model, Config.audio_model
clip_model, projection = Config.clip_model, Config.projection
def prepare_inputs(text_input=None, image_input=None, audio_input=None):
    text_audio, text_embed, image_embed = None, None, None
    if audio_input is not None:
        # Transcribe the audio and concatenate the segment texts into one string.
        audio_transcribed = audio_model.transcribe(audio_input)
        processed_audio = ''
        for audio_segment in audio_transcribed['segments']:
            processed_audio += audio_segment['text']
        processed_audio = processed_audio.strip()
    if image_input is not None:
        # Encode the image with CLIP, drop the CLS token, and project the patch
        # embeddings into the language model's embedding space.
        image_processed = Config.processor(images=image_input, return_tensors="pt")
        with torch.no_grad():
            outputs = clip_model(**image_processed.to(Config.device))
            last_hidden_state = outputs.last_hidden_state[:, 1:, :]
            image_embed = projection(last_hidden_state.to(Config.device)).to(torch.float16)
    # Merge the typed text and the audio transcript into a single prompt string.
    if audio_input is not None and text_input is not None:
        text_audio = f"{text_input} {processed_audio}"
    elif audio_input is not None:
        text_audio = processed_audio
    elif text_input is not None:
        text_audio = text_input
    if text_audio is not None:
        tokenized_text_audio = tokenizer.encode(text_audio)
        tokenized_text_audio = Config.IMAGE_SEPARATOR_TOKENS + tokenized_text_audio + [Config.QUESTION_ANSWER_SEPARATOR_ID]
        with torch.no_grad():
            tokenized_text_audio = torch.tensor(tokenized_text_audio)
            text_embed = peft_model.model.model.embed_tokens(tokenized_text_audio.to(Config.device)).unsqueeze(0)
    # Concatenate whichever modality embeddings are present along the sequence axis.
    if text_embed is not None and image_embed is not None:
        combined_embed = torch.cat([image_embed, text_embed], dim=1)
    elif text_embed is not None:
        combined_embed = text_embed
    else:
        combined_embed = image_embed
    return combined_embed
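
# Shape note (illustrative, not from the original file): with only text,
# prepare_inputs returns the token embeddings of IMAGE_SEPARATOR_TOKENS +
# prompt + the question-answer separator id, shaped [1, seq_len, hidden_size].
# With an image as well, the projected CLIP patch embeddings are prepended
# along dim=1, e.g. (assuming a PIL image `pil_image`):
#     embeds = prepare_inputs(text_input="What is in this picture?",
#                             image_input=pil_image)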
def chatbot_response(text_input, image_input, audio_input):
    # Treat an empty textbox as no text input.
    if text_input == '':
        text_input = None
    if text_input is None and image_input is None and audio_input is None:
        return "Please enter text, upload an image, or record audio."
    combined_embeds = prepare_inputs(text_input, image_input, audio_input)
    generated_tokens = generate_tokens(combined_embeds, max_tokens=60)
    return tokenizer.decode(generated_tokens)
def generate_tokens(combined_embeds, max_tokens=100):
    """Greedy decoding: feed the running embedding sequence back into the model."""
    pred_tokens = []
    combined_embed = combined_embeds
    # no_grad must cover the forward pass too, or the loop accumulates
    # autograd graphs across iterations during inference.
    with torch.no_grad():
        for _ in range(max_tokens):
            logits = peft_model(inputs_embeds=combined_embed).logits[:, -1, :]
            next_token_id = logits.argmax(dim=-1)
            if next_token_id.item() == tokenizer.eos_token_id:  # 50256 for the GPT-2 tokenizer
                break
            pred_tokens.append(next_token_id.item())
            # Embed the new token and append it to the input sequence.
            next_token_embed = peft_model.model.model.embed_tokens(next_token_id.unsqueeze(0))
            combined_embed = torch.cat((combined_embed, next_token_embed), dim=1)
    return pred_tokens
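
# Usage sketch (an assumption, not part of the original file): a three-input
# handler like chatbot_response is typically wired into a Gradio demo in a
# Space. The widget types below are hypothetical but match what the code
# expects: a PIL image for the CLIP processor and an audio file path for the
# transcription model.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Interface(
        fn=chatbot_response,
        inputs=[
            gr.Textbox(label="Text"),
            gr.Image(type="pil", label="Image"),
            gr.Audio(type="filepath", label="Audio"),
        ],
        outputs=gr.Textbox(label="Response"),
    )
    demo.launch()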