# Spaces: Sleeping (page-status residue captured along with this file; not part of the program)
import json | |
import os | |
import torch | |
import argparse | |
from PIL import Image | |
from chameleon.inference.chameleon import ChameleonInferenceModel, Options | |
from constants import ( | |
MODEL_7B_PATH, | |
TOKENIZER_TEXT_PATH, | |
TOKENIZER_IMAGE_CFG_PATH, | |
TOKENIZER_IMAGE_PATH, | |
) | |
from typing import List, Tuple | |
import logging | |
# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def split_token_sequence(
    tokens: torch.LongTensor,
    boi: int,
    eoi: int
) -> List[Tuple[str, torch.LongTensor]]:
    """
    Split a sequence of tokens into text and image segments.

    The sequence is scanned left to right. A ``boi`` token opens an image
    segment and an ``eoi`` token seen while inside an image segment closes it;
    neither marker is included in the returned segments. Every other token is
    accumulated into the segment currently being built.

    Edge cases (all preserved from the original behaviour):
      * An ``eoi`` token seen outside an image segment is kept as an ordinary
        token of the surrounding text segment.
      * A ``boi`` token always flushes whatever tokens are pending as a
        ``"text_seg"``, even if an image segment was already open.
      * A ``boi`` immediately followed by ``eoi`` yields an ``"image_seg"``
        whose tensor has shape ``(1, 0)``.
      * Tokens left over at the end are emitted as ``"image_seg"`` if an image
        segment is still open, otherwise as ``"text_seg"``.

    Args:
        tokens (torch.LongTensor): The token sequence, a 2-D tensor whose
            first (batch) dimension must be 1.
        boi (int): Begin of image token.
        eoi (int): End of image token.

    Returns:
        List[Tuple[str, torch.LongTensor]]: List of ``(segment_type, tokens)``
        tuples in sequence order, where ``segment_type`` is ``"text_seg"`` or
        ``"image_seg"`` and ``tokens`` is a new ``(1, n)`` tensor with the
        dtype and device of the input.

    Raises:
        AssertionError: If the batch dimension is not 1.
        ValueError: If ``tokens`` is not 2-D (shape unpacking fails).
    """
    batch_size, _ = tokens.shape
    assert batch_size == 1, "Batch size must be 1"
    device = tokens.device
    dtype = tokens.dtype

    def make_segment(seg_type, token_ids):
        # Rebuild a (1, n) tensor matching the input's dtype and device.
        return (
            seg_type,
            torch.tensor(token_ids, dtype=dtype, device=device).reshape(1, -1),
        )

    segments = []
    current_segment = []
    in_image_seg = False

    # Convert to plain Python ints once. Iterating the tensor directly yields
    # 0-d tensors, which makes every comparison a tensor op (and a device
    # sync on GPU) and forces torch.tensor() to unpack a list of tensors.
    for token in tokens[0].tolist():  # [0] removes the batch dimension
        if token == boi:
            # if entering an image segment, save the current text segment (if any)
            if current_segment:
                segments.append(make_segment("text_seg", current_segment))
                current_segment = []
            in_image_seg = True
        elif token == eoi and in_image_seg:
            # if exiting an image segment, save the current image segment
            segments.append(make_segment("image_seg", current_segment))
            current_segment = []
            in_image_seg = False
        else:
            current_segment.append(token)

    # save any remaining tokens
    if current_segment:
        trailing_type = "image_seg" if in_image_seg else "text_seg"
        segments.append(make_segment(trailing_type, current_segment))
    return segments
def main(args: argparse.Namespace):
    """Generate interleaved output for one instruction and persist it.

    Loads the Chameleon model, generates tokens for ``args.instruction``,
    splits them into text and image segments, decodes each segment, saves
    decoded images as ``<segment index>.png`` under ``args.save_dir``, and
    writes one JSON record per segment to ``./segments.jsonl``.

    Args:
        args: Parsed command line arguments; ``instruction`` and ``save_dir``
            are read.
    """
    # Load Chameleon model
    model = ChameleonInferenceModel(
        MODEL_7B_PATH.as_posix(),
        TOKENIZER_TEXT_PATH.as_posix(),
        TOKENIZER_IMAGE_CFG_PATH.as_posix(),
        TOKENIZER_IMAGE_PATH.as_posix(),
    )

    # Print model configuration
    configured_paths = (
        ("Model path", MODEL_7B_PATH),
        ("Text tokenizer path", TOKENIZER_TEXT_PATH),
        ("Image tokenizer config path", TOKENIZER_IMAGE_CFG_PATH),
        ("Image tokenizer path", TOKENIZER_IMAGE_PATH),
    )
    for label, path in configured_paths:
        logging.info(f"{label}: {path}")

    # Generate options
    options = Options()

    # Prepare prompt: a plain string becomes a text-only prompt, while an
    # (instruction, image_path) tuple becomes an image entry followed by text.
    batch_prompt_ui = []
    for instruction in [args.instruction]:
        if not isinstance(instruction, tuple):
            batch_prompt_ui.append([{"type": "text", "value": instruction}])
            continue
        inst, image_path = instruction
        batch_prompt_ui.append(
            [
                {"type": "image", "value": f"file:{image_path}"},
                {"type": "text", "value": inst},
            ]
        )

    # generate
    tokens: torch.LongTensor = model.generate(
        batch_prompt_ui=batch_prompt_ui,
        options=options,
    )

    # split
    boi = model.vocab.begin_image  # 8197(boi)
    eoi = model.vocab.end_image  # 8196(eoi)
    segments = split_token_sequence(tokens, boi, eoi)

    # decode
    os.makedirs(args.save_dir, exist_ok=True)
    segments_data = []
    for seg_id, (seg_type, seg_tokens) in enumerate(segments):
        if seg_type != "image_seg":
            assert seg_type == "text_seg"
            record = {"type": "text", "content": model.decode_text(seg_tokens)[0]}
        else:
            # An image segment is only decoded when it holds exactly 1024 tokens.
            assert seg_tokens.shape[1] == 1024
            img = model.decode_image(seg_tokens)[0]
            image_path = os.path.join(args.save_dir, f"{seg_id}.png")
            img.save(image_path)
            record = {"type": "image", "content": image_path}
        segments_data.append(record)

    # Persist one JSON object per line, in segment order.
    jsonl_path = "./segments.jsonl"
    with open(jsonl_path, 'w') as jsonl_file:
        jsonl_file.writelines(json.dumps(segment) + '\n' for segment in segments_data)
def parse_arguments() -> argparse.Namespace:
    """Build the command line parser and return the parsed arguments.

    Returns:
        argparse.Namespace: Carries ``instruction`` (required) and
        ``save_dir`` (defaults to ``./outputs/interleaved/``).
    """
    cli = argparse.ArgumentParser(
        description="Generate interleaved image-text content based on text instructions.",
    )
    cli.add_argument(
        "-i",
        "--instruction",
        type=str,
        required=True,
        help="The instruction for interleaved image-text generation.",
    )
    cli.add_argument(
        "-s",
        "--save_dir",
        type=str,
        default="./outputs/interleaved/",
        help="The directory to save the generated images.",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the command line, then run generation.
    main(parse_arguments())