chatlawv1 / trlx /examples /simulacra.py
teachyourselfcoding's picture
Upload 245 files
fa6856c
raw
history blame
1.2 kB
# Optimize prompts by training on prompts-ratings pairings dataset
# taken from https://github.com/JD-P/simulacra-aesthetic-captions
import os
import sqlite3
from urllib.request import urlretrieve
from accelerate import Accelerator
import trlx
from trlx.data.default_configs import default_ilql_config
url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"
if __name__ == "__main__":
accelerator = Accelerator()
if os.environ.get("LOCAL_RANK", "0") == "0" and not os.path.exists(dbpath):
print(f"fetching {dbpath}")
urlretrieve(url, dbpath)
accelerator.wait_for_everyone()
conn = sqlite3.connect(dbpath)
c = conn.cursor()
c.execute(
"SELECT prompt, rating FROM ratings "
"JOIN images ON images.id=ratings.iid "
"JOIN generations ON images.gid=generations.id "
"WHERE rating IS NOT NULL;"
)
prompts, ratings = tuple(map(list, zip(*c.fetchall())))
trlx.train(
config=default_ilql_config(),
samples=prompts,
rewards=ratings,
eval_prompts=["An astronaut riding a horse"] * 64,
)