jbilcke-hf HF staff commited on
Commit
d243e97
1 Parent(s): 930b21b
.env CHANGED
@@ -25,6 +25,9 @@ VIDEO_HOTSHOT_XL_API_GRADIO="https://jbilcke-hf-hotshot-xl-server-1.hf.space"
25
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
26
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
27
 
 
 
 
28
  # ----------- CENSORSHIP -------
29
  ENABLE_CENSORSHIP=""
30
  FINGERPRINT_KEY=""
 
25
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL="cloneofsimo/hotshot-xl-lora-controlnet"
26
  VIDEO_HOTSHOT_XL_API_REPLICATE_MODEL_VERSION="75e26ffd033a59a78954a3d675632f47f7f8470402aec51c255b9f9b7b62568b"
27
 
28
+ INTERPOLATION_API_REPLICATE_MODEL="zsxkib/st-mfnet"
29
+ INTERPOLATION_API_REPLICATE_MODEL_VERSION="faa7693430b0a4ac95d1b8e25165673c1d7a7263537a7c4bb9be82a3e2d130fb"
30
+
31
  # ----------- CENSORSHIP -------
32
  ENABLE_CENSORSHIP=""
33
  FINGERPRINT_KEY=""
src/app/interface/generate/index.tsx CHANGED
@@ -15,6 +15,7 @@ import { getSDXLModels } from "@/app/server/actions/models"
15
  import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
16
  import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
17
  import { TooltipProvider } from "@radix-ui/react-tooltip"
 
18
 
19
  const qualityOptions = [
20
  {
@@ -27,6 +28,8 @@ const qualityOptions = [
27
  }
28
  ] as QualityOption[]
29
 
 
 
30
  export function Generate() {
31
  const router = useRouter()
32
  const pathname = usePathname()
@@ -49,12 +52,14 @@ export function Generate() {
49
 
50
  const [communityRoll, setCommunityRoll] = useState<Post[]>([])
51
 
 
 
52
  const [qualityLevel, setQualityLevel] = useState<QualityLevel>("low")
53
 
54
  const { progressPercent, remainingTimeInSec } = useCountdown({
55
  isActive: isLocked,
56
  timerId: runs, // everytime we change this, the timer will reset
57
- durationInSec: 80, // it usually takes 40 seconds, but there might be lag
58
  onEnd: () => {}
59
  })
60
 
@@ -83,8 +88,9 @@ export function Generate() {
83
  if (!promptDraft) { return }
84
 
85
  setShowModels(false)
86
- setRuns(runs + 1)
87
  setLocked(true)
 
88
 
89
  scrollRef.current?.scroll({
90
  top: 0,
@@ -118,44 +124,73 @@ export function Generate() {
118
  size
119
  }
120
 
121
- let newAssetUrl = ""
122
  try {
123
  // console.log("starting transition, calling generateAnimation")
124
- newAssetUrl = await generateAnimation(params)
125
- setAssetUrl(newAssetUrl)
 
 
 
 
 
126
 
127
  } catch (err) {
128
  console.log("generation failed! probably just a Gradio failure, so let's just run the round robin again!")
129
 
130
  try {
131
- newAssetUrl = await generateAnimation(params)
132
- setAssetUrl(newAssetUrl)
133
  } catch (err) {
134
  console.error(`generation failed again! ${err}`)
135
  }
136
- } finally {
 
 
 
137
  setLocked(false)
 
 
 
 
138
 
139
- if (newAssetUrl) {
140
- try {
141
- const post = await postToCommunity({
142
- prompt: promptDraft,
143
- model: huggingFaceLora,
144
- assetUrl: newAssetUrl,
145
- })
146
- console.log("successfully submitted to the community!", post)
147
-
148
- // now you got a read/write object
149
- const current = new URLSearchParams(Array.from(searchParams.entries()))
150
- current.set("postId", post.postId.trim())
151
- current.set("prompt", post.prompt.trim())
152
- current.set("model", post.model.trim())
153
- const search = current.toString()
154
- router.push(`${pathname}${search ? `?${search}` : ""}`)
155
- } catch (err) {
156
- console.error(`not a blocker, but we failed to post to the community (reason: ${err})`)
157
- }
158
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  }
160
  })
161
  }
@@ -323,7 +358,7 @@ export function Generate() {
323
  `space-y-3 md:space-y-6`,
324
  `items-center`,
325
  )}>
326
- {assetUrl.startsWith("data:video/mp4")
327
  ? <video
328
  muted
329
  autoPlay
@@ -420,7 +455,10 @@ export function Generate() {
420
  disabled={isLocked}
421
  onClick={handleSubmit}
422
  >
423
- {isLocked ? `Loading..` : "Generate"}
 
 
 
424
  </animated.button>
425
  </div>
426
  </div>
 
15
  import { HotshotImageInferenceSize, Post, QualityLevel, QualityOption, SDXLModel } from "@/types"
16
  import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"
17
  import { TooltipProvider } from "@radix-ui/react-tooltip"
18
+ import { interpolate } from "@/app/server/actions/interpolate"
19
 
20
  const qualityOptions = [
21
  {
 
28
  }
29
  ] as QualityOption[]
30
 
31
+ type Stage = "generate" | "interpolate" | "finished"
32
+
33
  export function Generate() {
34
  const router = useRouter()
35
  const pathname = usePathname()
 
52
 
53
  const [communityRoll, setCommunityRoll] = useState<Post[]>([])
54
 
55
+ const [stage, setStage] = useState<Stage>("generate")
56
+
57
  const [qualityLevel, setQualityLevel] = useState<QualityLevel>("low")
58
 
59
  const { progressPercent, remainingTimeInSec } = useCountdown({
60
  isActive: isLocked,
61
  timerId: runs, // everytime we change this, the timer will reset
62
+ durationInSec: /*stage === "interpolate" ? 30 :*/ 80, // it usually takes 40 seconds, but there might be lag
63
  onEnd: () => {}
64
  })
65
 
 
88
  if (!promptDraft) { return }
89
 
90
  setShowModels(false)
91
+ setRuns(runsRef.current + 1)
92
  setLocked(true)
93
+ setStage("generate")
94
 
95
  scrollRef.current?.scroll({
96
  top: 0,
 
124
  size
125
  }
126
 
127
+ let rawAssetUrl = ""
128
  try {
129
  // console.log("starting transition, calling generateAnimation")
130
+ rawAssetUrl = await generateAnimation(params)
131
+
132
+ if (!rawAssetUrl) {
133
+ throw new Error("invalid asset url")
134
+ }
135
+
136
+ setAssetUrl(rawAssetUrl)
137
 
138
  } catch (err) {
139
  console.log("generation failed! probably just a Gradio failure, so let's just run the round robin again!")
140
 
141
  try {
142
+ rawAssetUrl = await generateAnimation(params)
 
143
  } catch (err) {
144
  console.error(`generation failed again! ${err}`)
145
  }
146
+ }
147
+
148
+ if (!rawAssetUrl) {
149
+ console.log("failed to generate the video, aborting")
150
  setLocked(false)
151
+ return
152
+ }
153
+
154
+ setAssetUrl(rawAssetUrl)
155
 
156
+
157
+ let assetUrl = rawAssetUrl
158
+
159
+ setStage("interpolate")
160
+ setRuns(runsRef.current + 1)
161
+
162
+ try {
163
+ assetUrl = await interpolate(rawAssetUrl)
164
+
165
+ if (!assetUrl) {
166
+ throw new Error("invalid interpolated asset url")
 
 
 
 
 
 
 
 
167
  }
168
+
169
+ setAssetUrl(assetUrl)
170
+ } catch (err) {
171
+ console.log(`failed to interpolate the video, but this is not a blocker: ${err}`)
172
+ }
173
+
174
+ setLocked(false)
175
+ setStage("generate")
176
+
177
+ try {
178
+ const post = await postToCommunity({
179
+ prompt: promptDraft,
180
+ model: huggingFaceLora,
181
+ assetUrl,
182
+ })
183
+ console.log("successfully submitted to the community!", post)
184
+
185
+ // now you got a read/write object
186
+ const current = new URLSearchParams(Array.from(searchParams.entries()))
187
+ current.set("postId", post.postId.trim())
188
+ current.set("prompt", post.prompt.trim())
189
+ current.set("model", post.model.trim())
190
+ const search = current.toString()
191
+ router.push(`${pathname}${search ? `?${search}` : ""}`)
192
+ } catch (err) {
193
+ console.error(`not a blocker, but we failed to post to the community (reason: ${err})`)
194
  }
195
  })
196
  }
 
358
  `space-y-3 md:space-y-6`,
359
  `items-center`,
360
  )}>
361
+ {assetUrl.startsWith("data:video/mp4") || assetUrl.endsWith(".mp4")
362
  ? <video
363
  muted
364
  autoPlay
 
455
  disabled={isLocked}
456
  onClick={handleSubmit}
457
  >
458
+ {isLocked
459
+ ? (stage === "generate" ? `Generating..` : `Smoothing..`)
460
+ : "Generate"
461
+ }
462
  </animated.button>
463
  </div>
464
  </div>
src/app/server/actions/interpolate.ts ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use server"
2
+
3
+ import Replicate from "replicate"
4
+
5
+ import { sleep } from "@/lib/sleep"
6
+
7
+ const replicateToken = `${process.env.AUTH_REPLICATE_API_TOKEN || ""}`
8
+ const replicateModel = `${process.env.INTERPOLATION_API_REPLICATE_MODEL || ""}`
9
+ const replicateModelVersion = `${process.env.INTERPOLATION_API_REPLICATE_MODEL_VERSION || ""}`
10
+
11
+ export async function interpolate(input: string): Promise<string> {
12
+ if (!replicateToken) {
13
+ throw new Error(`you need to configure your AUTH_REPLICATE_API_TOKEN in order to use interpolation`)
14
+ }
15
+ if (!replicateModel) {
16
+ throw new Error(`you need to configure your INTERPOLATION_API_REPLICATE_MODEL in order to use interpolation`)
17
+ }
18
+
19
+ if (!replicateModelVersion) {
20
+ throw new Error(`you need to configure your INTERPOLATION_API_REPLICATE_MODEL_VERSION in order to use interpolation`)
21
+ }
22
+ const replicate = new Replicate({ auth: replicateToken })
23
+
24
+ const prediction = await replicate.predictions.create({
25
+ version: replicateModelVersion,
26
+ input: {
27
+ mp4: input,
28
+ framerate_multiplier: 8,
29
+ keep_original_duration: true,
30
+ }
31
+ })
32
+
33
+ let res: Response
34
+ let pollingCount = 0
35
+ do {
36
+ // This is normally a fast model, so let's check every 4 seconds
37
+ await sleep(4000)
38
+
39
+ res = await fetch(`https://api.replicate.com/v1/predictions/${prediction.id}`, {
40
+ method: "GET",
41
+ headers: {
42
+ Authorization: `Token ${replicateToken}`,
43
+ },
44
+ cache: 'no-store',
45
+ })
46
+
47
+ // console.log("res:", res)
48
+
49
+ /*
50
+ try {
51
+ const text = await res.text()
52
+ console.log("res.text:", text)
53
+ } catch (err) {
54
+ console.error("res.text() error:", err)
55
+ }
56
+ */
57
+
58
+ if (res.status === 200) {
59
+ try {
60
+ const response = (await res.json()) as any
61
+ // console.log("response:", response)
62
+ const error = `${response?.error || ""}`
63
+ if (error) {
64
+ throw new Error(error)
65
+ }
66
+ if (response.status === "succeeded") {
67
+ return response.output[1]
68
+ }
69
+ } catch (err) {
70
+ console.error("res.json() error:", err)
71
+ }
72
+ }
73
+
74
+ pollingCount++
75
+
76
+ // To prevent indefinite polling, we can stop after a certain number
77
+ if (pollingCount >= 20) {
78
+ throw new Error('Request time out.')
79
+ }
80
+ } while (true)
81
+ }