Spaces:
Runtime error
Runtime error
File size: 1,484 Bytes
fa6856c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Toy example of optimizing textual interior designs so that the output contains the fewest rooms
# Also see https://architext.design/
import trlx
from trlx.data.default_configs import default_ppo_config
def reward_fn(samples, **kwargs):
    """Return the negative room count for each sample.

    Every ":" found in a sample is counted as one room, so a sample with
    fewer rooms gets a higher (less negative) reward.

    Args:
        samples: Iterable of generated layout strings.
        **kwargs: Extra keyword arguments; accepted and ignored.

    Returns:
        A list holding one non-positive integer per sample, in input order.
    """
    return [-sample.count(":") for sample in samples]
# Training prompts: each adjacency constraint is wrapped in the
# "[prompt] ... [layout]" frame.
# NOTE(review): "the bedroom is adjacent to the kitchen" and
# "the bedroom is not adjacent to the kitchen" each appear twice; the
# duplicates are kept so the prompt list is unchanged.
prompts = [
    f"[prompt] {constraint} [layout]"
    for constraint in (
        # Positive adjacency constraints.
        "the bedroom is adjacent to the living room",
        "a bedroom is adjacent to the living room",
        "the bedroom is adjacent to the kitchen",
        "a bedroom is adjacent to the kitchen",
        "the bedroom is adjacent to the kitchen",
        "the kitchen is adjacent to the bathroom",
        "a bathroom is adjacent to the living room",
        "the bathroom is adjacent to the living room",
        # Negated adjacency constraints.
        "the bedroom is not adjacent to the living room",
        "a bedroom is not adjacent to the living room",
        "the bedroom is not adjacent to the kitchen",
        "a bedroom is not adjacent to the kitchen",
        "the bedroom is not adjacent to the kitchen",
        "the kitchen is not adjacent to the bathroom",
    )
]
def main():
    """Train with trlx using the default PPO configuration.

    Builds the config via ``default_ppo_config()`` and passes it, together
    with the module-level ``reward_fn`` and ``prompts``, to ``trlx.train``
    for the ``architext/gptj-162M`` model.
    """
    config = default_ppo_config()
    trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|