Join the conversation

Join the community of Machine Learners and AI enthusiasts.

Sign Up
m-ricย 
posted an update Mar 26, 2024
Post
1703
๐—›๐—ผ๐˜„ ๐—ฑ๐—ผ๐—ฒ๐˜€ ๐—ฏ๐—ฒ๐—ฎ๐—บ ๐˜€๐—ฒ๐—ฎ๐—ฟ๐—ฐ๐—ต ๐—ฑ๐—ฒ๐—ฐ๐—ผ๐—ฑ๐—ถ๐—ป๐—ด ๐˜„๐—ผ๐—ฟ๐—ธ? โžก๏ธ ๐™‰๐™š๐™ฌ ๐™ซ๐™ž๐™จ๐™ช๐™–๐™ก๐™ž๐™ฏ๐™–๐™ฉ๐™ž๐™ค๐™ฃ ๐™ฉ๐™ค๐™ค๐™ก! ๐Ÿ‘€

In Decoder-type LLMs like GPT4 or Mistral-Large, the output is generated one token (=word part) at a time. That's why they're nicknamed "stochastic parrots": the "thinking" process only happens one step at a time, so it can seem really myopic.

๐’๐จ ๐ก๐จ๐ฐ ๐ข๐ฌ ๐ญ๐ก๐ž ๐ง๐ž๐ฑ๐ญ ๐ญ๐จ๐ค๐ž๐ง ๐ฌ๐ž๐ฅ๐ž๐œ๐ญ๐ž๐?

๐Ÿ“Š Given its input sentence like "๐˜ž๐˜ฉ๐˜ข๐˜ต ๐˜ช๐˜ด ๐˜ต๐˜ฉ๐˜ฆ 7๐˜ต๐˜ฉ ๐˜๐˜ช๐˜ฃ๐˜ฐ๐˜ฏ๐˜ข๐˜ค๐˜ค๐˜ช ๐˜ฏ๐˜ถ๐˜ฎ๐˜ฃ๐˜ฆ๐˜ณ? ๐˜›๐˜ฉ๐˜ฆ 7๐˜ต๐˜ฉ ๐˜๐˜ช๐˜ฃ๐˜ฐ๐˜ฏ๐˜ข๐˜ค๐˜ค๐˜ช ๐˜ฏ๐˜ถ๐˜ฎ๐˜ฃ๐˜ฆ๐˜ณ", the Decoder LLM generates, for each token in its vocabulary, a score that represents this token's probability of coming next.
For instance: "๐™ž๐™จ" gets score 0.56, and "๐™˜๐™–๐™ฃ" gets score 0.35.

๐Ÿค‘ ๐†๐ซ๐ž๐ž๐๐ฒ ๐๐ž๐œ๐จ๐๐ข๐ง๐  is the naive option where you simply take the next most probable token at each step. But this creates paths that maximize very short-term rewards, thus may overlook better paths for the long term (like this time when you played FIFA all evening and arrived unprepared to your school exam on the next day).
In our example, the next highest score token might be "๐™ž๐™จ", but this will strongly bias the LLM towards giving an hasty response. On the opposite, starting with "๐™˜๐™–๐™ฃ" could have been completed with "๐˜ฃ๐˜ฆ ๐˜ฐ๐˜ฃ๐˜ต๐˜ข๐˜ช๐˜ฏ๐˜ฆ๐˜ฅ ๐˜ง๐˜ณ๐˜ฐ๐˜ฎ ๐˜ค๐˜ฐ๐˜ฎ๐˜ฑ๐˜ถ๐˜ต๐˜ช๐˜ฏ๐˜จ ๐˜ฑ๐˜ณ๐˜ฆ๐˜ท๐˜ช๐˜ฐ๐˜ถ๐˜ด ๐˜๐˜ช๐˜ฃ๐˜ฐ๐˜ฏ๐˜ข๐˜ค๐˜ค๐˜ช ๐˜ฏ๐˜ถ๐˜ฎ๐˜ฃ๐˜ฆ๐˜ณ๐˜ด ๐˜ง๐˜ช๐˜ณ๐˜ด๐˜ต", which steers the LLM towards a correct reasoning!

๐Ÿ—บ๏ธ ๐๐ž๐š๐ฆ ๐ฌ๐ž๐š๐ซ๐œ๐ก improves on greedy decoding by generating at each step several paths - called beams - instead of one. This allows the generation to explore a much larger space, thus find better completions. In our example, both the "๐™ž๐™จ" and the "๐™˜๐™–๐™ฃ" completion could be tested. โœ…

๐Ÿ‘‰ I've created a tool to let you visualize it, thank you @joaogante for your great help!
๐™๐™ง๐™ฎ ๐™ž๐™ฉ ๐™๐™š๐™ง๐™š: m-ric/beam_search_visualizer
In this post