import json
import logging
import time
from pathlib import Path
from typing import Annotated, Iterator

import ijson
import outlines
import regex
import torch
from outlines import generate, models
from outlines.generate.api import SequenceGenerator
from pydantic import BaseModel, StringConstraints, conlist, conset
from transformers import AutoTokenizer

from fsm import replace_fields
from samplers import PenalizedMultinomialSampler
from utils import StringIteratorIO

logger = logging.getLogger(__name__)


logger.warning("Loading model...")
model_id = "google/gemma-2b-it"
# model_id = "Qwen/Qwen1.5-0.5B-Chat"

# Pick the best available accelerator; fall back to CPU instead of crashing
# when neither MPS nor CUDA is present.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    logger.warning("No MPS or CUDA device available, falling back to CPU")
model = models.transformers(model_id, device=device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
sampler = PenalizedMultinomialSampler()
low_temperature_sampler = PenalizedMultinomialSampler(temperature=0.3)

# Decode every vocabulary entry once; both token filters below reuse the result.
_decoded_vocab = [
    tokenizer.decode([token_id], skip_special_tokens=True)
    for token_id in range(tokenizer.vocab_size)
]

# Tokens that decode to nothing but whitespace.
# NOTE(review): set_max_repeats(tokens, 1) presumably caps how often these may
# repeat - confirm against samplers.PenalizedMultinomialSampler.
empty_tokens = [
    token_id for token_id, text in enumerate(_decoded_vocab) if not text.strip()
]
sampler.set_max_repeats(empty_tokens, 1)

disallowed_patterns = [regex.compile(r"\p{Han}")]  # focus on english for now
# Use search() rather than match(): match() only tests the beginning of the
# decoded token, so a token with a leading space/prefix followed by a Han
# character would not be detected.
disallowed_tokens = [
    token_id
    for token_id, text in enumerate(_decoded_vocab)
    if any(pattern.search(text) for pattern in disallowed_patterns)
]
sampler.set_max_repeats(disallowed_tokens, 0)
del _decoded_vocab  # only needed to build the two token lists above


# These Sample & Dataset models are just templates with placeholder fields.
# (Comments are used instead of class docstrings on purpose: pydantic copies a
# class docstring into the generated JSON schema.)

class Sample(BaseModel):
    # We use get_samples_generator() to replace the placeholder with the requested fields
    ABCDabcd12: str
    EFGHefgh34: str
    IJKLijkl56: str
    MNOPmnop78: str
    QRSTqrst90: str
    # PS: don't use StringConstraints with max_length here since it creates a fsm that is too big


class Dataset(BaseModel):
    # We use get_samples_generator() to set the length to infinity
    data: conlist(Sample, min_length=2, max_length=3)  # type: ignore


# Generator built once on the placeholder schema; its FSM is rewritten per
# request by get_samples_generator().
samples_generator_template = generate.json(model, Dataset, sampler=sampler)


class Columns(BaseModel):
    # A set of 2 to (number of Sample placeholders - 1) lowercase snake_case names.
    columns: conset(Annotated[str, StringConstraints(pattern=r'[a-z0-9_]+')], min_length=2, max_length=len(Sample.model_fields) - 1)  # type: ignore


columns_generator = generate.json(model, Columns, sampler=low_temperature_sampler)


def get_samples_generator(new_fields: list[str]) -> SequenceGenerator:
    """Build a samples generator whose JSON keys are ``new_fields``.

    The FSM of ``samples_generator_template`` is passed to ``replace_fields``
    together with the ``Sample`` placeholder model and the requested field
    names; the resulting FSM is wrapped in a new ``SequenceGenerator`` that
    reuses the template's model and sampler.

    Args:
        new_fields: Field names that replace the ``Sample`` placeholders.

    Returns:
        A ``SequenceGenerator`` driven by the rewritten FSM.
    """
    fsm = replace_fields(  # replace the placeholder fields by the real fields
        fsm=samples_generator_template.fsm,
        model=Sample,
        new_fields=new_fields,
        tokenizer=tokenizer,
        make_infinite_loop=True,  # to generate as many samples as we want
    )
    return SequenceGenerator(
        fsm=fsm,
        model=samples_generator_template.model,
        sampler=samples_generator_template.sampler,
        device=device,
    )


# NOTE: the docstrings of the two functions below are the prompt templates
# rendered by @outlines.prompt - they are runtime text, not documentation.

@outlines.prompt
def columns_prompt(filename: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of column names / columns for this dataset to populate a SQL schema.
    Please reply in JSON format and place the columns in a field named "columns".
    """


@outlines.prompt
def samples_prompt(filename: str, prompt: str, columns: str):
    """I would like to create a JSON file named {{ filename }}.json for a dataset of realistic data.
    Give an example of content using a JSON field named "data" with samples with columns {{ columns }}.
    {{ prompt }}
    """


# Backward-compatible alias for the original (misspelt) name.
samples_prommpt = samples_prompt


def stream_jsonl_file(filename: str, prompt: str, columns: list[str], seed: int, size: int) -> Iterator[str]:
    """Stream up to ``size`` generated samples as JSON Lines.

    Steps:
      1. If no columns are given, ask the model for column names.
      2. Truncate each column name to a bounded number of tokens and build a
         samples generator for those columns.
      3. Stream the model output, parse it incrementally and yield one JSON
         document per sample.

    Args:
        filename: Name of the requested file; only its stem is used in prompts.
        prompt: Extra instructions inserted into the samples prompt.
        columns: Requested column names. Empty (or ``None``) means the columns
            are generated by the model. The caller's list is never modified.
        seed: Seed for the torch random generator used for sampling.
        size: Maximum number of samples to yield.

    Yields:
        One JSON-encoded sample per item, each terminated by ``"\\n"``.
    """
    filename = Path(filename).stem
    # Work on a private copy so that generated columns are not appended to the
    # list object owned by the caller.
    columns = list(columns) if columns else []
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=})")
    _start = time.time()
    rng = torch.Generator(device=model.device)
    rng.manual_seed(seed)
    if not columns:
        messages = [
            {"role": "user", "content": columns_prompt(filename=filename)}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns...")
        columns_generator_tokens = columns_generator.stream(text, rng=rng)
        # Parse the streamed JSON incrementally and collect each "columns" item.
        columns.extend(
            ijson.items(StringIteratorIO(columns_generator_tokens), "columns.item", buf_size=16)
        )
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")
    # Each column is cut to at most len(placeholder name) tokens; zip() also
    # drops any column beyond the number of Sample placeholder fields.
    # NOTE(review): presumably replace_fields() needs the new names to fit in
    # the placeholders' slots - confirm against fsm.replace_fields.
    columns = [
        tokenizer.decode(
            tokenizer.encode(column, add_special_tokens=False)[:len(orig_field)],
            skip_special_tokens=True,
        )
        for column, orig_field in zip(columns, Sample.model_fields)
    ]
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
    samples_generator = get_samples_generator(new_fields=columns)
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples...")
    messages = [
        {"role": "user", "content": samples_prompt(filename=filename, prompt=prompt, columns="'" + "', '".join(columns) + "'")}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    samples_generator_tokens = samples_generator.stream(text, rng=rng)
    # range(size) comes first in zip() so that iteration stops as soon as
    # `size` samples were yielded, without pulling an extra sample from the model.
    for _, sample in zip(range(size), ijson.items(StringIteratorIO(samples_generator_tokens), "data.item", buf_size=4)):
        yield json.dumps(sample, ensure_ascii=False) + "\n"
    logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating samples... DONE (total={time.time() - _start:.02f}s)")