# Captured from a "Spaces" page; the status shown at capture time was "Runtime error".
# Optimize prompts by training on a dataset of prompt-rating pairs,
# taken from https://github.com/JD-P/simulacra-aesthetic-captions
import os
import sqlite3
from contextlib import closing
from urllib.request import urlretrieve

import trlx
from accelerate import Accelerator
from trlx.data.default_configs import default_ilql_config
url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"


def download_db(src_url, dest_path):
    """Fetch ``src_url`` into ``dest_path`` unless ``dest_path`` already exists.

    Only the existence of ``dest_path`` is checked before downloading, so a
    truncated file left there by an interrupted download would be mistaken for
    a complete one on every later run. To avoid that, the payload is written to
    ``<dest_path>.part`` and renamed into place only after ``urlretrieve``
    returns successfully.

    Args:
        src_url: URL to download (any scheme ``urllib.request`` supports).
        dest_path: Local filesystem path for the downloaded file.
    """
    if os.path.exists(dest_path):
        return
    print(f"fetching {dest_path}")
    part_path = f"{dest_path}.part"
    urlretrieve(src_url, part_path)
    # Per the os.replace docs, the rename is atomic on POSIX, so readers never
    # observe a half-written file at dest_path.
    os.replace(part_path, dest_path)


def load_samples(db_path):
    """Read every rated prompt from the Simulacra Aesthetic Captions database.

    Args:
        db_path: Path to the SQLite file.

    Returns:
        ``(prompts, ratings)``: two parallel lists with one entry per joined
        row whose ``rating`` is not NULL. The prompt comes from the
        ``generations`` row reached via ``ratings.iid -> images.gid``.

    Raises:
        FileNotFoundError: ``db_path`` does not exist. This is checked up front
            because ``sqlite3.connect`` would otherwise create an empty file at
            that path, which would then look like an already-downloaded DB.
        ValueError: the query returned no rows, so there is nothing to train on.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"database not found: {db_path}")
    # A sqlite3.Connection used as a context manager only commits/rolls back;
    # closing() is what actually releases the connection.
    with closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(
            "SELECT prompt, rating FROM ratings "
            "JOIN images ON images.id=ratings.iid "
            "JOIN generations ON images.gid=generations.id "
            "WHERE rating IS NOT NULL;"
        ).fetchall()
    if not rows:
        raise ValueError(f"no rated prompts found in {db_path}")
    prompts = [prompt for prompt, _ in rows]
    ratings = [rating for _, rating in rows]
    return prompts, ratings


if __name__ == "__main__":
    accelerator = Accelerator()
    # One process per machine downloads the database. Every process then waits
    # at the barrier so that none opens the file before the download finishes.
    # Accelerate's own notion of the local main process is used instead of
    # reading LOCAL_RANK directly, because a missing LOCAL_RANK would otherwise
    # default to "0" in every process.
    if accelerator.is_local_main_process:
        download_db(url, dbpath)
    accelerator.wait_for_everyone()

    prompts, ratings = load_samples(dbpath)
    trlx.train(
        config=default_ilql_config(),
        samples=prompts,
        rewards=ratings,
        # Evaluate on 64 copies of a single fixed prompt.
        eval_prompts=["An astronaut riding a horse"] * 64,
    )