Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
β’
d64b893
1
Parent(s):
a2c0551
working on improvements
Browse files
src/app/main.tsx
CHANGED
@@ -68,7 +68,7 @@ export default function Main() {
|
|
68 |
const newPanelsPrompts: string[] = []
|
69 |
const newCaptions: string[] = []
|
70 |
|
71 |
-
const nbPanelsToGenerate =
|
72 |
|
73 |
for (
|
74 |
let currentPanel = 0;
|
|
|
68 |
const newPanelsPrompts: string[] = []
|
69 |
const newCaptions: string[] = []
|
70 |
|
71 |
+
const nbPanelsToGenerate = 1
|
72 |
|
73 |
for (
|
74 |
let currentPanel = 0;
|
src/app/queries/getStoryContinuation.ts
CHANGED
@@ -7,8 +7,8 @@ export const getStoryContinuation = async ({
|
|
7 |
preset,
|
8 |
stylePrompt = "",
|
9 |
userStoryPrompt = "",
|
10 |
-
nbPanelsToGenerate =
|
11 |
-
nbTotalPanels =
|
12 |
existingPanels = [],
|
13 |
}: {
|
14 |
preset: Preset;
|
|
|
7 |
preset,
|
8 |
stylePrompt = "",
|
9 |
userStoryPrompt = "",
|
10 |
+
nbPanelsToGenerate = 1,
|
11 |
+
nbTotalPanels = 4,
|
12 |
existingPanels = [],
|
13 |
}: {
|
14 |
preset: Preset;
|
src/app/queries/predictNextPanels.ts
CHANGED
@@ -6,12 +6,13 @@ import { cleanJson } from "@/lib/cleanJson"
|
|
6 |
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
|
|
9 |
|
10 |
export const predictNextPanels = async ({
|
11 |
preset,
|
12 |
prompt = "",
|
13 |
-
nbPanelsToGenerate =
|
14 |
-
nbTotalPanels =
|
15 |
existingPanels = [],
|
16 |
}: {
|
17 |
preset: Preset;
|
@@ -58,17 +59,26 @@ export const predictNextPanels = async ({
|
|
58 |
|
59 |
let result = ""
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
try {
|
62 |
// console.log(`calling predict(${query}, ${nbTotalPanels})`)
|
63 |
-
result = `${await predict(query,
|
64 |
console.log("LLM result (1st trial):", result)
|
65 |
if (!result.length) {
|
66 |
throw new Error("empty result on 1st trial!")
|
67 |
}
|
68 |
} catch (err) {
|
69 |
// console.log(`prediction of the story failed, trying again..`)
|
|
|
|
|
|
|
70 |
try {
|
71 |
-
result = `${await predict(query + " \n ",
|
72 |
console.log("LLM result (2nd trial):", result)
|
73 |
if (!result.length) {
|
74 |
throw new Error("empty result on 2nd trial!")
|
|
|
6 |
import { createZephyrPrompt } from "@/lib/createZephyrPrompt"
|
7 |
import { dirtyGeneratedPanelCleaner } from "@/lib/dirtyGeneratedPanelCleaner"
|
8 |
import { dirtyGeneratedPanelsParser } from "@/lib/dirtyGeneratedPanelsParser"
|
9 |
+
import { sleep } from "@/lib/sleep"
|
10 |
|
11 |
export const predictNextPanels = async ({
|
12 |
preset,
|
13 |
prompt = "",
|
14 |
+
nbPanelsToGenerate = 1,
|
15 |
+
nbTotalPanels = 4,
|
16 |
existingPanels = [],
|
17 |
}: {
|
18 |
preset: Preset;
|
|
|
59 |
|
60 |
let result = ""
|
61 |
|
62 |
+
// we don't require a lot of token for our task
|
63 |
+
// but to be safe, let's count ~130 tokens per panel
|
64 |
+
const nbTokensPerPanel = 130
|
65 |
+
|
66 |
+
const nbMaxNewTokens = nbPanelsToGenerate * nbTokensPerPanel
|
67 |
+
|
68 |
try {
|
69 |
// console.log(`calling predict(${query}, ${nbTotalPanels})`)
|
70 |
+
result = `${await predict(query, nbMaxNewTokens)}`.trim()
|
71 |
console.log("LLM result (1st trial):", result)
|
72 |
if (!result.length) {
|
73 |
throw new Error("empty result on 1st trial!")
|
74 |
}
|
75 |
} catch (err) {
|
76 |
// console.log(`prediction of the story failed, trying again..`)
|
77 |
+
// this should help throttle things on a bit on the LLM API side
|
78 |
+
await sleep(2000)
|
79 |
+
|
80 |
try {
|
81 |
+
result = `${await predict(query + " \n ", nbMaxNewTokens)}`.trim()
|
82 |
console.log("LLM result (2nd trial):", result)
|
83 |
if (!result.length) {
|
84 |
throw new Error("empty result on 2nd trial!")
|
src/app/queries/predictWithGroq.ts
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
|
3 |
import Groq from "groq-sdk"
|
4 |
|
5 |
-
export async function predict(inputs: string,
|
6 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
7 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
8 |
|
@@ -18,6 +18,9 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
18 |
const res = await groq.chat.completions.create({
|
19 |
messages: messages,
|
20 |
model: groqApiModel,
|
|
|
|
|
|
|
21 |
})
|
22 |
|
23 |
return res.choices[0].message.content || ""
|
|
|
2 |
|
3 |
import Groq from "groq-sdk"
|
4 |
|
5 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
6 |
const groqApiKey = `${process.env.AUTH_GROQ_API_KEY || ""}`
|
7 |
const groqApiModel = `${process.env.LLM_GROQ_API_MODEL || "mixtral-8x7b-32768"}`
|
8 |
|
|
|
18 |
const res = await groq.chat.completions.create({
|
19 |
messages: messages,
|
20 |
model: groqApiModel,
|
21 |
+
stream: false,
|
22 |
+
temperature: 0.5,
|
23 |
+
max_tokens: nbMaxNewTokens,
|
24 |
})
|
25 |
|
26 |
return res.choices[0].message.content || ""
|
src/app/queries/predictWithHuggingFace.ts
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
4 |
import { LLMEngine } from "@/types"
|
5 |
|
6 |
-
export async function predict(inputs: string,
|
7 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
8 |
|
9 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
@@ -12,10 +12,6 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
12 |
|
13 |
let hfie: HfInferenceEndpoint = hf
|
14 |
|
15 |
-
// we don't require a lot of token for our task
|
16 |
-
// but to be safe, let's count ~110 tokens per panel
|
17 |
-
const nbMaxNewTokens = nbPanels * 130 // 110 isn't enough anymore for long dialogues
|
18 |
-
|
19 |
switch (llmEngine) {
|
20 |
case "INFERENCE_ENDPOINT":
|
21 |
if (inferenceEndpoint) {
|
|
|
3 |
import { HfInference, HfInferenceEndpoint } from "@huggingface/inference"
|
4 |
import { LLMEngine } from "@/types"
|
5 |
|
6 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
7 |
const hf = new HfInference(process.env.AUTH_HF_API_TOKEN)
|
8 |
|
9 |
const llmEngine = `${process.env.LLM_ENGINE || ""}` as LLMEngine
|
|
|
12 |
|
13 |
let hfie: HfInferenceEndpoint = hf
|
14 |
|
|
|
|
|
|
|
|
|
15 |
switch (llmEngine) {
|
16 |
case "INFERENCE_ENDPOINT":
|
17 |
if (inferenceEndpoint) {
|
src/app/queries/predictWithOpenAI.ts
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
import type { ChatCompletionMessage } from "openai/resources/chat"
|
4 |
import OpenAI from "openai"
|
5 |
|
6 |
-
export async function predict(inputs: string,
|
7 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
8 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
9 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
@@ -23,6 +23,8 @@ export async function predict(inputs: string, nbPanels: number): Promise<string>
|
|
23 |
stream: false,
|
24 |
model: openaiApiModel,
|
25 |
temperature: 0.8,
|
|
|
|
|
26 |
// TODO: use the nbPanels to define a max token limit
|
27 |
})
|
28 |
|
|
|
3 |
import type { ChatCompletionMessage } from "openai/resources/chat"
|
4 |
import OpenAI from "openai"
|
5 |
|
6 |
+
export async function predict(inputs: string, nbMaxNewTokens: number): Promise<string> {
|
7 |
const openaiApiKey = `${process.env.AUTH_OPENAI_API_KEY || ""}`
|
8 |
const openaiApiBaseUrl = `${process.env.LLM_OPENAI_API_BASE_URL || "https://api.openai.com/v1"}`
|
9 |
const openaiApiModel = `${process.env.LLM_OPENAI_API_MODEL || "gpt-3.5-turbo"}`
|
|
|
23 |
stream: false,
|
24 |
model: openaiApiModel,
|
25 |
temperature: 0.8,
|
26 |
+
max_tokens: nbMaxNewTokens,
|
27 |
+
|
28 |
// TODO: use the nbPanels to define a max token limit
|
29 |
})
|
30 |
|