// FacePoke — client/src/hooks/useMainStore.ts
// Author: jbilcke-hf (HF staff)
// Last change: fix rotation of selfies (commit 4d0b778)
import { create } from 'zustand'
import * as vision from '@mediapipe/tasks-vision'
import { truncateFileName } from '@/lib/utils'
import { convertImageToBase64 } from '@/lib/convertImageToBase64'
import { type Metadata, type ImageModificationParams, type OnServerResponseParams, type ActionMode, type ImageStateValues, type ClosestLandmark } from '@/types'
import { mapRange } from '@/lib/mapRange'
import { facePoke } from '@/lib/facePoke'
// Full store shape: the plain value fields (ImageStateValues) plus the actions below.
export type ImageState = ImageStateValues & {
// status/error text shown in the UI
setStatus: (status?: string) => void
setError: (error?: string) => void
// MediaPipe face-detection results
setFaceLandmarks: (faceLandmarks: vision.NormalizedLandmark[][]) => void
setBlendShapes: (blendShapes: vision.Classifications[]) => void
// load a user-selected file (converts to base64 and notifies the server)
setImageFile: (file?: File) => Promise<void>
// cursor-driven behavior toggles
setIsFollowingCursor: (isFollowingCursor: boolean) => void
setIsGazingAtCursor: (isGazingAtCursor: boolean) => void
// image data URLs and the server-side uuid of the loaded image
setOriginalImage: (url: string) => void
setOriginalImageUuid: (uuid: string) => void
setPreviewImage: (url: string) => void
// revert the preview to the original image
resetImage: () => void
setAverageLatency: (averageLatency: number) => void
setActiveLandmark: (activeLandmark?: ClosestLandmark) => void
// face crop metadata returned by the server (center/size/bbox/angle)
setMetadata: (metadata?: Metadata) => void
// shallow-merge new modification params over the current ones
setParams: (params: Partial<ImageModificationParams>) => void
// dispatch a server response (error / image / loaded) to the store
handleServerResponse: (params: OnServerResponseParams) => Promise<void>
// composite a modified head crop onto the original image (WIP)
applyModifiedHeadToCanvas: (headImageBlob: Blob) => Promise<string>
// translate a landmark drag into expression params and request a transform
modifyImage: ({ landmark, vector, mode }: {
landmark: ClosestLandmark
vector: { x: number; y: number; z: number }
mode: ActionMode
}) => Promise<void>
}
/**
 * Build a fresh copy of the store's default value fields.
 * Returned as a new object on every call so spreads into `set()`
 * never share nested state (metadata, params, arrays) between resets.
 */
export const getDefaultState = (): ImageStateValues => {
  const defaults: ImageStateValues = {
    // UI text
    status: '',
    error: '',
    // loaded image state
    imageFile: null,
    originalImage: '',
    originalImageUuid: '',
    previewImage: '',
    // cursor-driven behavior toggles
    isFollowingCursor: false,
    isGazingAtCursor: false,
    // request throttling bounds (ms)
    minLatency: 20, // min time between requests
    averageLatency: 190, // this should be the average for most people
    maxLatency: 4000, // max time between requests
    // face interaction state
    activeLandmark: undefined,
    metadata: {
      center: [0, 0],
      size: 0,
      bbox: [[0, 0], [0, 0], [0, 0], [0, 0]],
      angle: 0,
    },
    params: {},
    faceLandmarks: [],
    blendShapes: [],
  }
  return defaults
}
export const useMainStore = create<ImageState>((set, get) => ({
...getDefaultState(),
// Simple field setters for the UI text and MediaPipe detection results.
setStatus: (status: string = '') => { set({ status }) },
setError: (error: string = '') => { set({ error }) },
setFaceLandmarks: (faceLandmarks: vision.NormalizedLandmark[][]) => { set({ faceLandmarks }) },
setBlendShapes: (blendShapes: vision.Classifications[]) => { set({ blendShapes }) },
/**
 * Load a user-selected image file into the store.
 * Resets all derived state, converts the file to a base64 data URL,
 * then asks the server to load it (the reply comes back through
 * handleServerResponse).
 */
setImageFile: async (file?: File) => {
  // no file: wipe everything back to defaults
  if (!file) {
    set({
      ...getDefaultState(),
      status: 'No file selected',
    })
    return;
  }
  try {
    const image = await convertImageToBase64(file);
    set({
      ...getDefaultState(),
      imageFile: file,
      status: `File selected: ${truncateFileName(file.name, 16)}`,
      previewImage: image,
      originalImage: image,
    })
    // fire-and-forget: the server answers via handleServerResponse
    facePoke.loadImage(image);
  } catch (err) {
    // fix: log at error level (was console.log) and also surface the
    // failure through the store's `error` field, not just `status`
    console.error(`failed to load the image: `, err);
    set({
      ...getDefaultState(),
      status: 'Failed to load the image',
      error: 'Failed to load the image',
    })
  }
},
// Cursor-behavior toggles and raw image-url / uuid setters.
setIsFollowingCursor: (isFollowingCursor: boolean) => { set({ isFollowingCursor }) },
setIsGazingAtCursor: (isGazingAtCursor: boolean) => { set({ isGazingAtCursor }) },
setOriginalImage: (url) => { set({ originalImage: url }) },
setOriginalImageUuid: (originalImageUuid) => { set({ originalImageUuid }) },
setPreviewImage: (url) => { set({ previewImage: url }) },
// Revert the preview back to the untouched original image.
// No-op when no image has been loaded yet.
resetImage: () => {
  const originalImage = get().originalImage
  if (!originalImage) {
    return
  }
  set({ previewImage: originalImage })
},
setAverageLatency: (averageLatency: number) => { set({ averageLatency }) },
setActiveLandmark: (activeLandmark?: ClosestLandmark) => { set({ activeLandmark }) },
// Store the face metadata; calling with no argument restores the
// default (zeroed) metadata instead.
setMetadata: (metadata?: Metadata) => {
  if (metadata) {
    set({ metadata })
  } else {
    set({ metadata: getDefaultState().metadata })
  }
},
// Shallow-merge new modification params over the existing ones,
// so successive pokes accumulate instead of replacing each other.
setParams: (params: Partial<ImageModificationParams>) => {
  const merged = { ...get().params, ...params }
  set({ params: merged })
},
/**
 * Dispatch a server response to the right store update:
 * - error:  log it, roll the preview back to the original, clear the uuid
 * - image:  display the freshly modified frame
 * - loaded: record the server-side uuid + face metadata, then fire a
 *   blank modification request to warm up the pipeline
 */
handleServerResponse: async (params: OnServerResponseParams) => {
  // fix: read state through the closure's get() like every other action
  // (was useMainStore.getState()), and drop the unused
  // applyModifiedHeadToCanvas binding
  const { originalImage, setMetadata, setPreviewImage, setOriginalImageUuid, modifyImage } = get();
  if (typeof params.error === "string") {
    console.error(`handleServerResponse: failed to perform the request, resetting the app (${params.error})`)
    setPreviewImage(originalImage)
    setOriginalImageUuid('')
  } else if (typeof params.image !== "undefined") {
    // this is where we decide to paste back the image as a whole,
    // or apply some shenanigans to only paste back the head.
    // the part about the head is not done yet, so we do it all for now.
    // --- old way: use it whole ---
    const image = await convertImageToBase64(params.image);
    // --- future way: try to only apply the head ---
    // const image = await get().applyModifiedHeadToCanvas(params.image);
    setPreviewImage(image);
  } else if (typeof params.loaded !== "undefined") {
    setOriginalImageUuid(params.loaded.u)
    setMetadata({
      center: params.loaded.c, // center - 2x1
      size: params.loaded.s, // size - scalar
      bbox: params.loaded.b, // bbox - 4x2
      angle: params.loaded.a, // angle - rad, counterclockwise
    })
    // right after we received the id, we perform a first blank request
    await modifyImage({
      landmark: {
        group: 'background',
        distance: 0,
        vector: { x: 0.5, y: 0.5, z: 0 }
      },
      vector: { x: 0, y: 0, z: 0 },
      mode: 'PRIMARY'
    })
  } else {
    console.log(`handleServerResponse: received an unknown json`, params)
  }
},
/**
 * Composite the server-modified head crop back onto the original image.
 *
 * NOTE(review): the actual paste-back is still unfinished (see the TODO
 * below) — today this draws only the original image and resolves with it
 * as a PNG data URL. The gradient-masked head canvas is prepared but not
 * yet blitted onto the destination.
 *
 * @param headImageBlob the modified head region received from the server
 * @returns a PNG data URL of the composited canvas
 * @throws rejects when the original image fails to load, a 2D context
 *         cannot be created, or decoding the head blob fails
 */
applyModifiedHeadToCanvas: async (headImageBlob: Blob): Promise<string> => {
  // fix: read state through get() like the other actions
  const { originalImage, metadata } = get();
  // fix: no async executor — the previous new Promise(async ...) could
  // leak unhandled rejections and never settle
  return new Promise<string>((resolve, reject) => {
    const originalImg = new Image();
    // fix: without an onerror handler the promise would hang forever
    // when the original image fails to load
    originalImg.onerror = () => reject(new Error('Failed to load the original image'));
    originalImg.onload = async () => {
      try {
        const canvas = document.createElement('canvas');
        const ctx = canvas.getContext('2d');
        if (!ctx) {
          reject(new Error('Failed to get 2D context'));
          return;
        }
        canvas.width = originalImg.width;
        canvas.height = originalImg.height;
        // Draw the original image as the base layer
        ctx.drawImage(originalImg, 0, 0);
        const headImageBitmap = await createImageBitmap(headImageBlob, {
          resizeQuality: 'high'
        });
        // Create a temporary canvas for the head image with gradient
        const tempCanvas = document.createElement('canvas');
        const tempCtx = tempCanvas.getContext('2d');
        if (!tempCtx) {
          reject(new Error('Failed to get 2D context for temporary canvas'));
          return;
        }
        tempCanvas.width = headImageBitmap.width;
        tempCanvas.height = headImageBitmap.height;
        // Draw the head image on the temporary canvas
        tempCtx.drawImage(headImageBitmap, 0, 0);
        // Create a radial gradient mask that fades the head's edges out
        const gradientSize = 20; // Size of the gradient in pixels
        const gradient = tempCtx.createRadialGradient(
          tempCanvas.width / 2, tempCanvas.height / 2, Math.min(tempCanvas.width, tempCanvas.height) / 2 - gradientSize,
          tempCanvas.width / 2, tempCanvas.height / 2, Math.min(tempCanvas.width, tempCanvas.height) / 2
        );
        gradient.addColorStop(0, 'rgba(0, 0, 0, 1)');
        gradient.addColorStop(1, 'rgba(0, 0, 0, 0)');
        // Apply gradient mask (keep only head pixels inside the gradient)
        tempCtx.globalCompositeOperation = 'destination-in';
        tempCtx.fillStyle = gradient;
        tempCtx.fillRect(0, 0, tempCanvas.width, tempCanvas.height);
        console.log("metadata:", metadata);
        ctx.save();
        ctx.rotate(metadata.angle);
        // TODO: this is where we need to grab the coordinates from the metadata and start drawing our image
        // note that there is an angle, so we need to see how this impacts the transformation
        ctx.restore();
        /*
        TODO finish the implementation
        ctx.save();
        ctx.rotate(metadata.angle);
        // TODO we need the coordinate inside the final image
        ctx.drawImage(
          tempCanvas,
          topLeftCornerX,
          topLeftCornerY,
          width,
          height,
        );
        ctx.restore();
        */
        // Our head image already contains compression artifacts,
        // so let's avoid double-encoding in here and use PNG
        resolve(canvas.toDataURL('image/png'));
      } catch (err) {
        // fix: settle the promise (e.g. createImageBitmap rejecting)
        // instead of leaving an unhandled rejection in the async handler
        reject(err);
      }
    };
    originalImg.src = originalImage;
  })
},
/**
 * Translate a cursor interaction on a facial landmark into expression
 * parameters (pitch/yaw/roll, pupils, eyes, eyebrow, mouth) and send a
 * transform request to the server.
 *
 * @param landmark the closest detected landmark group under the cursor
 * @param vector   normalized cursor displacement (x/y roughly in [-0.5, 0.5])
 * @param mode     HOVERING only applies the global follow/gaze effects;
 *                 PRIMARY/SECONDARY also apply landmark-specific controls
 */
modifyImage: async ({ landmark, vector, mode }: {
  landmark: ClosestLandmark
  vector: { x: number; y: number; z: number }
  mode: ActionMode
}): Promise<void> => {
  const {
    originalImage,
    originalImageUuid,
    params: previousParams,
    setParams,
    setError,
    isFollowingCursor,
    isGazingAtCursor
  } = get()
  if (!originalImage) {
    console.error('Image file or facePoke not available');
    return;
  }
  // start from the previous params so successive pokes accumulate
  const params = {
    ...previousParams
  }
  // input ranges of the normalized cursor displacement, per control family
  const generalControl = {
    minX: -0.30,
    maxX: 0.30,
    minY: -0.30,
    maxY: 0.30
  }
  const pupilControl = {
    minX: -0.50,
    maxX: 0.50,
    minY: -0.50,
    maxY: 0.50
  }
  const eyeControl = {
    minX: -0.50,
    maxX: 0.50,
    minY: -0.50,
    maxY: 0.50
  }
  // for the output min and max values, see:
  // https://github.com/fofr/cog-expression-editor/blob/main/predict.py
  // regardless of the landmark, those rules will apply
  if (isFollowingCursor) {
    // displacing the face horizontally by moving the mouse on the X axis
    // should perform a yaw rotation
    // rotate_yaw (min: -20, max: 20, default: 0)
    const yawMin = -40
    const yawMax = 40
    // note: we invert the axis here
    params.rotate_yaw = mapRange(-vector.x, generalControl.minX, generalControl.maxX, yawMin, yawMax);
    // displacing the face vertically by moving the mouse on the Y axis
    // should perform a pitch rotation
    // rotate_pitch (min: -20, max: 20, default: 0)
    const pitchMin = -40
    const pitchMax = 40
    params.rotate_pitch = mapRange(vector.y, generalControl.minY, generalControl.maxY, pitchMin, pitchMax);
  }
  if (isGazingAtCursor) {
    const pupilsXMin = -15
    const pupilsXMax = 15
    params.pupil_x = mapRange(vector.x, pupilControl.minX, pupilControl.maxX, pupilsXMin, pupilsXMax);
    const pupilsYMin = -2 // -15
    const pupilsYMax = 8 // 15
    params.pupil_y = mapRange(-vector.y, pupilControl.minY, pupilControl.maxY, pupilsYMin, pupilsYMax);
  }
  // if the user clicked on the primary or secondary button,
  // then it triggers some more things
  if (mode !== 'HOVERING') {
    // fix: case bodies are wrapped in braces so the const declarations are
    // scoped per case (avoids the no-case-declarations lint error)
    switch (landmark.group) {
      case 'leftEye':
      case 'rightEye': {
        const pupilsXMin = -15
        const pupilsXMax = 15
        params.pupil_x = mapRange(vector.x, pupilControl.minX, pupilControl.maxX, pupilsXMin, pupilsXMax);
        const eyesMin = -20
        const eyesMax = 5
        // fix: map the vertical input through the Y range (was minX/maxX;
        // the X/Y ranges are identical today, so behavior is unchanged)
        params.eyes = mapRange(-vector.y, eyeControl.minY, eyeControl.maxY, eyesMin, eyesMax);
        break;
      }
      case 'leftEyebrow':
      case 'rightEyebrow': {
        // moving the mouse vertically moves the eyebrows up/down
        // eyebrow (min: -10, max: 15, default: 0)
        const eyebrowMin = -10
        const eyebrowMax = 15
        params.eyebrow = mapRange(-vector.y, eyeControl.minY, eyeControl.maxY, eyebrowMin, eyebrowMax);
        break;
      }
      case 'lips': {
        // aaa (mouth open) (min: -30, max: 120, default: 0)
        const aaaMin = -30
        const aaaMax = 120
        params.aaa = mapRange(-vector.y, eyeControl.minY, eyeControl.maxY, aaaMin, aaaMax);
        // eee (min: -20, max: 15, default: 0)
        const eeeMin = -20
        const eeeMax = 15
        params.eee = mapRange(vector.x, eyeControl.minX, eyeControl.maxX, eeeMin, eeeMax);
        // woo (min: -20, max: 15, default: 0)
        //const wooMin = -20
        //const wooMax = 15
        //params.woo = mapRange(-vector.x, eyeControl.minX, eyeControl.maxX, wooMin, wooMax);
        break;
      }
      case 'faceOval': {
        // horizontal displacement on the face oval performs a roll rotation
        // rotate_roll (min: -20, max: 20, default: 0)
        const rollMin = -40
        const rollMax = 40
        params.rotate_roll = mapRange(vector.x, eyeControl.minX, eyeControl.maxX, rollMin, rollMax);
        break;
      }
      case 'background': {
        // dragging the background rotates the whole head
        // rotate_yaw (min: -20, max: 20, default: 0)
        const yawMin = -40
        const yawMax = 40
        // note: we invert the axis here
        params.rotate_yaw = mapRange(-vector.x, generalControl.minX, generalControl.maxX, yawMin, yawMax);
        // rotate_pitch (min: -20, max: 20, default: 0)
        const pitchMin = -40
        const pitchMax = 40
        params.rotate_pitch = mapRange(vector.y, eyeControl.minY, eyeControl.maxY, pitchMin, pitchMax);
        break;
      }
      default:
        return
    }
  }
  // abort on any non-finite parameter (NaN or ±Infinity)
  for (const [key, value] of Object.entries(params)) {
    // fix: Number.isFinite replaces isNaN(value as any) || !isFinite(value as any)
    if (!Number.isFinite(value)) {
      console.log(`${key} is NaN, aborting`)
      return
    }
  }
  setParams(params)
  try {
    if (originalImageUuid) {
      // NOTE(review): not awaited — if transformImage returns a promise,
      // rejections bypass this try/catch; confirm fire-and-forget is intended
      facePoke.transformImage(originalImageUuid, params);
    }
  } catch (error) {
    setError('Failed to modify image');
  }
},
}))