Commit
•
d243e97
1
Parent(s):
930b21b
up
Browse files- .env +3 -0
- src/app/interface/generate/index.tsx +67 -29
- src/app/server/actions/interpolate.ts +81 -0
.env
CHANGED
@@ -25,6 +25,9 @@ VIDEO_HOTSHOT_XL_API_GRADIO="https://jbilcke-hf-hotshot-xl-server-1.hf.space"
|
|
25 |
VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
|
26 |
VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
|
27 |
|
|
|
|
|
|
|
28 |
# ----------- CENSORSHIP -------
|
29 |
ENABLE_CENSORSHIP=""
|
30 |
FINGERPRINT_KEY=""
|
|
|
25 |
VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
|
26 |
VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
|
27 |
|
28 |
+
INTERPOLATION_API_REPLICATE_MODEL="zsxkib/st-mfnet"
|
29 |
+
INTERPOLATION_API_REPLICATE_MODEL_VERSION="faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
|
30 |
+
|
31 |
# ----------- CENSORSHIP -------
|
32 |
ENABLE_CENSORSHIP=""
|
33 |
FINGERPRINT_KEY=""
|
src/app/interface/generate/index.tsx
CHANGED
@@ -15,6 +15,7 @@ import { getSDXLModels } from "@/app/server/actions/models"
|
|
15 |
import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
|
16 |
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
|
17 |
import { TooltipProvider } from "@radix-ui/react-tooltip"
|
|
|
18 |
|
19 |
const qualityOptions = [
|
20 |
{
|
@@ -27,6 +28,8 @@ const qualityOptions = [
|
|
27 |
}
|
28 |
] as QualityOption[]
|
29 |
|
|
|
|
|
30 |
export function Generate() {
|
31 |
const router = useRouter()
|
32 |
const pathname = usePathname()
|
@@ -49,12 +52,14 @@ export function Generate() {
|
|
49 |
|
50 |
const [communityRoll, setCommunityRoll] = useState<Post[]>([])
|
51 |
|
|
|
|
|
52 |
const [qualityLevel, setQualityLevel] = useState<QualityLevel>("low")
|
53 |
|
54 |
const { progressPercent, remainingTimeInSec } = useCountdown({
|
55 |
isActive: isLocked,
|
56 |
timerId: runs, // everytime we change this, the timer will reset
|
57 |
-
durationInSec: 80, // it usually takes 40 seconds, but there might be lag
|
58 |
onEnd: () => {}
|
59 |
})
|
60 |
|
@@ -83,8 +88,9 @@ export function Generate() {
|
|
83 |
if (!promptDraft) { return }
|
84 |
|
85 |
setShowModels(false)
|
86 |
-
setRuns(
|
87 |
setLocked(true)
|
|
|
88 |
|
89 |
scrollRef.current?.scroll({
|
90 |
top: 0,
|
@@ -118,44 +124,73 @@ export function Generate() {
|
|
118 |
size
|
119 |
}
|
120 |
|
121 |
-
let
|
122 |
try {
|
123 |
// console.log("starting transition, calling generateAnimation")
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
} catch (err) {
|
128 |
console.log("generation failed! probably just a Gradio failure, so let's just run the round robin again!")
|
129 |
|
130 |
try {
|
131 |
-
|
132 |
-
setAssetUrl(newAssetUrl)
|
133 |
} catch (err) {
|
134 |
console.error(`generation failed again! ${err}`)
|
135 |
}
|
136 |
-
}
|
|
|
|
|
|
|
137 |
setLocked(false)
|
|
|
|
|
|
|
|
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
current.set("postId", post.postId.trim())
|
151 |
-
current.set("prompt", post.prompt.trim())
|
152 |
-
current.set("model", post.model.trim())
|
153 |
-
const search = current.toString()
|
154 |
-
router.push(`${pathname}${search ? `?${search}` : ""}`)
|
155 |
-
} catch (err) {
|
156 |
-
console.error(`not a blocker, but we failed to post to the community (reason: ${err})`)
|
157 |
-
}
|
158 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
}
|
160 |
})
|
161 |
}
|
@@ -323,7 +358,7 @@ export function Generate() {
|
|
323 |
`space-y-3 md:space-y-6`,
|
324 |
`items-center`,
|
325 |
)}>
|
326 |
-
{assetUrl.startsWith("data:video/mp4")
|
327 |
? <video
|
328 |
muted
|
329 |
autoPlay
|
@@ -420,7 +455,10 @@ export function Generate() {
|
|
420 |
disabled={isLocked}
|
421 |
onClick={handleSubmit}
|
422 |
>
|
423 |
-
{isLocked
|
|
|
|
|
|
|
424 |
</animated.button>
|
425 |
</div>
|
426 |
</div>
|
|
|
15 |
import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
|
16 |
import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
|
17 |
import { TooltipProvider } from "@radix-ui/react-tooltip"
|
18 |
+
import { interpolate } from "@/app/server/actions/interpolate"
|
19 |
|
20 |
const qualityOptions = [
|
21 |
{
|
|
|
28 |
}
|
29 |
] as QualityOption[]
|
30 |
|
31 |
+
type Stage = "generate" | "interpolate" | "finished"
|
32 |
+
|
33 |
export function Generate() {
|
34 |
const router = useRouter()
|
35 |
const pathname = usePathname()
|
|
|
52 |
|
53 |
const [communityRoll, setCommunityRoll] = useState<Post[]>([])
|
54 |
|
55 |
+
const [stage, setStage] = useState<Stage>("generate")
|
56 |
+
|
57 |
const [qualityLevel, setQualityLevel] = useState<QualityLevel>("low")
|
58 |
|
59 |
const { progressPercent, remainingTimeInSec } = useCountdown({
|
60 |
isActive: isLocked,
|
61 |
timerId: runs, // everytime we change this, the timer will reset
|
62 |
+
durationInSec: /*stage === "interpolate" ? 30 :*/ 80, // it usually takes 40 seconds, but there might be lag
|
63 |
onEnd: () => {}
|
64 |
})
|
65 |
|
|
|
88 |
if (!promptDraft) { return }
|
89 |
|
90 |
setShowModels(false)
|
91 |
+
setRuns(runsRef.current + 1)
|
92 |
setLocked(true)
|
93 |
+
setStage("generate")
|
94 |
|
95 |
scrollRef.current?.scroll({
|
96 |
top: 0,
|
|
|
124 |
size
|
125 |
}
|
126 |
|
127 |
+
let rawAssetUrl = ""
|
128 |
try {
|
129 |
// console.log("starting transition, calling generateAnimation")
|
130 |
+
rawAssetUrl = await generateAnimation(params)
|
131 |
+
|
132 |
+
if (!rawAssetUrl) {
|
133 |
+
throw new Error("invalid asset url")
|
134 |
+
}
|
135 |
+
|
136 |
+
setAssetUrl(rawAssetUrl)
|
137 |
|
138 |
} catch (err) {
|
139 |
console.log("generation failed! probably just a Gradio failure, so let's just run the round robin again!")
|
140 |
|
141 |
try {
|
142 |
+
rawAssetUrl = await generateAnimation(params)
|
|
|
143 |
} catch (err) {
|
144 |
console.error(`generation failed again! ${err}`)
|
145 |
}
|
146 |
+
}
|
147 |
+
|
148 |
+
if (!rawAssetUrl) {
|
149 |
+
console.log("failed to generate the video, aborting")
|
150 |
setLocked(false)
|
151 |
+
return
|
152 |
+
}
|
153 |
+
|
154 |
+
setAssetUrl(rawAssetUrl)
|
155 |
|
156 |
+
|
157 |
+
let assetUrl = rawAssetUrl
|
158 |
+
|
159 |
+
setStage("interpolate")
|
160 |
+
setRuns(runsRef.current + 1)
|
161 |
+
|
162 |
+
try {
|
163 |
+
assetUrl = await interpolate(rawAssetUrl)
|
164 |
+
|
165 |
+
if (!assetUrl) {
|
166 |
+
throw new Error("invalid interpolated asset url")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
}
|
168 |
+
|
169 |
+
setAssetUrl(assetUrl)
|
170 |
+
} catch (err) {
|
171 |
+
console.log(`failed to interpolate the video, but this is not a blocker: ${err}`)
|
172 |
+
}
|
173 |
+
|
174 |
+
setLocked(false)
|
175 |
+
setStage("generate")
|
176 |
+
|
177 |
+
try {
|
178 |
+
const post = await postToCommunity({
|
179 |
+
prompt: promptDraft,
|
180 |
+
model: huggingFaceLora,
|
181 |
+
assetUrl,
|
182 |
+
})
|
183 |
+
console.log("successfully submitted to the community!", post)
|
184 |
+
|
185 |
+
// now you got a read/write object
|
186 |
+
const current = new URLSearchParams(Array.from(searchParams.entries()))
|
187 |
+
current.set("postId", post.postId.trim())
|
188 |
+
current.set("prompt", post.prompt.trim())
|
189 |
+
current.set("model", post.model.trim())
|
190 |
+
const search = current.toString()
|
191 |
+
router.push(`${pathname}${search ? `?${search}` : ""}`)
|
192 |
+
} catch (err) {
|
193 |
+
console.error(`not a blocker, but we failed to post to the community (reason: ${err})`)
|
194 |
}
|
195 |
})
|
196 |
}
|
|
|
358 |
`space-y-3 md:space-y-6`,
|
359 |
`items-center`,
|
360 |
)}>
|
361 |
+
{assetUrl.startsWith("data:video/mp4") || assetUrl.endsWith(".mp4")
|
362 |
? <video
|
363 |
muted
|
364 |
autoPlay
|
|
|
455 |
disabled={isLocked}
|
456 |
onClick={handleSubmit}
|
457 |
>
|
458 |
+
{isLocked
|
459 |
+
? (stage === "generate" ? `Generating..` : `Smoothing..`)
|
460 |
+
: "Generate"
|
461 |
+
}
|
462 |
</animated.button>
|
463 |
</div>
|
464 |
</div>
|
src/app/server/actions/interpolate.ts
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"use server"
|
2 |
+
|
3 |
+
import Replicate from "replicate"
|
4 |
+
|
5 |
+
import { sleep } from "@/lib/sleep"
|
6 |
+
|
7 |
+
const replicateToken = `${process.env.AUTH_REPLICATE_API_TOKEN || ""}`
|
8 |
+
const replicateModel = `${process.env.INTERPOLATION_API_REPLICATE_MODEL || ""}`
|
9 |
+
const replicateModelVersion = `${process.env.INTERPOLATION_API_REPLICATE_MODEL_VERSION || ""}`
|
10 |
+
|
11 |
+
export async function interpolate(input: string): Promise<string> {
|
12 |
+
if (!replicateToken) {
|
13 |
+
throw new Error(`you need to configure your AUTH_REPLICATE_API_TOKEN in order to use interpolation`)
|
14 |
+
}
|
15 |
+
if (!replicateModel) {
|
16 |
+
throw new Error(`you need to configure your INTERPOLATION_API_REPLICATE_MODEL in order to use interpolation`)
|
17 |
+
}
|
18 |
+
|
19 |
+
if (!replicateModelVersion) {
|
20 |
+
throw new Error(`you need to configure your INTERPOLATION_API_REPLICATE_MODEL_VERSION in order to use interpolation`)
|
21 |
+
}
|
22 |
+
const replicate = new Replicate({ auth: replicateToken })
|
23 |
+
|
24 |
+
const prediction = await replicate.predictions.create({
|
25 |
+
version: replicateModelVersion,
|
26 |
+
input: {
|
27 |
+
mp4: input,
|
28 |
+
framerate_multiplier: 8,
|
29 |
+
keep_original_duration: true,
|
30 |
+
}
|
31 |
+
})
|
32 |
+
|
33 |
+
let res: Response
|
34 |
+
let pollingCount = 0
|
35 |
+
do {
|
36 |
+
// This is normally a fast model, so let's check every 4 seconds
|
37 |
+
await sleep(4000)
|
38 |
+
|
39 |
+
res = await fetch(`https://api.replicate.com/v1/predictions/${prediction.id}`, {
|
40 |
+
method: "GET",
|
41 |
+
headers: {
|
42 |
+
Authorization: `Token ${replicateToken}`,
|
43 |
+
},
|
44 |
+
cache: 'no-store',
|
45 |
+
})
|
46 |
+
|
47 |
+
// console.log("res:", res)
|
48 |
+
|
49 |
+
/*
|
50 |
+
try {
|
51 |
+
const text = await res.text()
|
52 |
+
console.log("res.text:", text)
|
53 |
+
} catch (err) {
|
54 |
+
console.error("res.text() error:", err)
|
55 |
+
}
|
56 |
+
*/
|
57 |
+
|
58 |
+
if (res.status === 200) {
|
59 |
+
try {
|
60 |
+
const response = (await res.json()) as any
|
61 |
+
// console.log("response:", response)
|
62 |
+
const error = `${response?.error || ""}`
|
63 |
+
if (error) {
|
64 |
+
throw new Error(error)
|
65 |
+
}
|
66 |
+
if (response.status === "succeeded") {
|
67 |
+
return response.output[1]
|
68 |
+
}
|
69 |
+
} catch (err) {
|
70 |
+
console.error("res.json() error:", err)
|
71 |
+
}
|
72 |
+
}
|
73 |
+
|
74 |
+
pollingCount++
|
75 |
+
|
76 |
+
// To prevent indefinite polling, we can stop after a certain number
|
77 |
+
if (pollingCount >= 20) {
|
78 |
+
throw new Error('Request time out.')
|
79 |
+
}
|
80 |
+
} while (true)
|
81 |
+
}
|