jbilcke-hf HF staff commited on
Commit
d64b893
β€’
1 Parent(s): a2c0551

working on improvements

Browse files
src/app/main.tsx CHANGED
@@ -68,7 +68,7 @@ export default function Main() {
68
  const newPanelsPrompts: string[] = []
69
  const newCaptions: string[] = []
70
 
71
- const nbPanelsToGenerate = 2
72
 
73
  for (
74
  let currentPanel = 0;
 
68
  const newPanelsPrompts: string[] = []
69
  const newCaptions: string[] = []
70
 
71
+ const nbPanelsToGenerate = 1
72
 
73
  for (
74
  let currentPanel = 0;
src/app/queries/getStoryContinuation.ts CHANGED
@@ -7,8 +7,8 @@ export const getStoryContinuation = async ({
7
  preset,
8
  stylePrompt = "",
9
  userStoryPrompt = "",
10
- nbPanelsToGenerate = 2,
11
- nbTotalPanels = 8,
12
  existingPanels = [],
13
  }: {
14
  preset: Preset;
 
7
  preset,
8
  stylePrompt = "",
9
  userStoryPrompt = "",
10
+ nbPanelsToGenerate = 1,
11
+ nbTotalPanels = 4,
12
  existingPanels = [],
13
  }: {
14
  preset: Preset;
src/app/queries/predictNextPanels.ts CHANGED
@@ -6,12 +6,13 @@ import { cleanJson } from "@/lib/cleanJson"
6
  import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
7
  import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
8
  import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
 
9
 
10
  export const predictNextPanels = async ({
11
  preset,
12
  prompt = "",
13
- nbPanelsToGenerate = 2,
14
- nbTotalPanels = 8,
15
  existingPanels = [],
16
  }: {
17
  preset: Preset;
@@ -58,17 +59,26 @@ export const predictNextPanels = async ({
58
 
59
  let result = ""
60
 
 
 
 
 
 
 
61
  try {
62
  // console.log(`calling predict(${query}, ${nbTotalPanels})`)
63
- result = `${await predict(query, nbPanelsToGenerate) || ""}`.trim()
64
  console.log("LLM result (1st trial):", result)
65
  if (!result.length) {
66
  throw new Error("empty result on 1st trial!")
67
  }
68
  } catch (err) {
69
  // console.log(`prediction of the story failed, trying again..`)
 
 
 
70
  try {
71
- result = `${await predict(query + " \n ", nbPanelsToGenerate) || ""}`.trim()
72
  console.log("LLM result (2nd trial):", result)
73
  if (!result.length) {
74
  throw new Error("empty result on 2nd trial!")
 
6
  import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
7
  import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
8
  import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
9
+ import { sleep } from "@/lib/sleep"
10
 
11
  export const predictNextPanels = async ({
12
  preset,
13
  prompt = "",
14
+ nbPanelsToGenerate = 1,
15
+ nbTotalPanels = 4,
16
  existingPanels = [],
17
  }: {
18
  preset: Preset;
 
59
 
60
  let result = ""
61
 
62
+ // we don't require a lot of token for our task
63
+ // but to be safe, let's count ~130 tokens per panel
64
+ const nbTokensPerPanel = 130
65
+
66
+ const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
67
+
68
  try {
69
  // console.log(`calling predict(${query}, ${nbTotalPanels})`)
70
+ result = `${await predict(query, nbMaxNewTokens)}`.trim()
71
  console.log("LLM result (1st trial):", result)
72
  if (!result.length) {
73
  throw new Error("empty result on 1st trial!")
74
  }
75
  } catch (err) {
76
  // console.log(`prediction of the story failed, trying again..`)
77
+ // this should help throttle things on a bit on the LLM API side
78
+ await sleep(2000)
79
+
80
  try {
81
+ result = `${await predict(query + " \n ", nbMaxNewTokens)}`.trim()
82
  console.log("LLM result (2nd trial):", result)
83
  if (!result.length) {
84
  throw new Error("empty result on 2nd trial!")
src/app/queries/predictWithGroq.ts CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import Groq from "groq-sdk"
4
 
5
- export async function predict(inputs: string, nbPanels: number): Promise<string> {
6
  const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
7
  const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
8
 
@@ -18,6 +18,9 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
18
  const res = await groq.chat.completions.create({
19
  messages: messages,
20
  model: groqApiModel,
 
 
 
21
  })
22
 
23
  return res.choices[0].message.content || ""
 
2
 
3
  import Groq from "groq-sdk"
4
 
5
+ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
6
  const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
7
  const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
8
 
 
18
  const res = await groq.chat.completions.create({
19
  messages: messages,
20
  model: groqApiModel,
21
+ stream: false,
22
+ temperature: 0.5,
23
+ max_tokens: nbMaxNewTokens,
24
  })
25
 
26
  return res.choices[0].message.content || ""
src/app/queries/predictWithHuggingFace.ts CHANGED
@@ -3,7 +3,7 @@
3
  import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
4
  import { LLMEngine } from "@/types"
5
 
6
- export async function predict(inputs: string, nbPanels: number): Promise<string> {
7
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
8
 
9
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
@@ -12,10 +12,6 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
12
 
13
  let hfie: HfInferenceEndpoint = hf
14
 
15
- // we don't require a lot of token for our task
16
- // but to be safe, let's count ~110 tokens per panel
17
- const nbMaxNewTokens = nbPanels * 130 // 110 isn't enough anymore for long dialogues
18
-
19
  switch (llmEngine) {
20
  case "INFERENCE_ENDPOINT":
21
  if (inferenceEndpoint) {
 
3
  import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
4
  import { LLMEngine } from "@/types"
5
 
6
+ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
7
  const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
8
 
9
  const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
 
12
 
13
  let hfie: HfInferenceEndpoint = hf
14
 
 
 
 
 
15
  switch (llmEngine) {
16
  case "INFERENCE_ENDPOINT":
17
  if (inferenceEndpoint) {
src/app/queries/predictWithOpenAI.ts CHANGED
@@ -3,7 +3,7 @@
3
  import type { ChatCompletionMessage } from "openai/resources/chat"
4
  import OpenAI from "openai"
5
 
6
- export async function predict(inputs: string, nbPanels: number): Promise<string> {
7
  const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
8
  const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
9
  const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
@@ -23,6 +23,8 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
23
  stream: false,
24
  model: openaiApiModel,
25
  temperature: 0.8,
 
 
26
  // TODO: use the nbPanels to define a max token limit
27
  })
28
 
 
3
  import type { ChatCompletionMessage } from "openai/resources/chat"
4
  import OpenAI from "openai"
5
 
6
+ export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
7
  const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
8
  const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
9
  const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
 
23
  stream: false,
24
  model: openaiApiModel,
25
  temperature: 0.8,
26
+ max_tokens: nbMaxNewTokens,
27
+
28
  // TODO: use the nbPanels to define a max token limit
29
  })
30