inpaint-web / src /adapters /superResolution.ts
zhou20120904's picture
Upload 48 files
a62d4c5 verified
/* eslint-disable no-console */
/* eslint-disable no-plusplus */
import cv, { Mat } from 'opencv-ts'
import { getCapabilities } from './util'
import { ensureModel } from './cache'
function loadImage(url: string): Promise<HTMLImageElement> {
return new Promise((resolve, reject) => {
const img = new Image()
img.crossOrigin = 'Anonymous'
img.onload = () => resolve(img)
img.onerror = () => reject(new Error(`Failed to load image from ${url}`))
img.src = url
})
}
function imgProcess(img: Mat) {
const channels = new cv.MatVector()
cv.split(img, channels) // 分割通道
const C = channels.size() // 通道数
const H = img.rows // 图像高度
const W = img.cols // 图像宽度
const chwArray = new Float32Array(C * H * W) // 创建新的数组来存储转换后的数据
for (let c = 0; c < C; c++) {
const channelData = channels.get(c).data // 获取单个通道的数据
for (let h = 0; h < H; h++) {
for (let w = 0; w < W; w++) {
chwArray[c * H * W + h * W + w] = channelData[h * W + w] / 255.0
// chwArray[c * H * W + h * W + w] = channelData[h * W + w]
}
}
}
channels.delete() // 清理内存
return chwArray // 返回转换后的数据
}
async function tileProc(
inputTensor: ort.Tensor,
session: ort.InferenceSession,
callback: (progress: number) => void
) {
const inputDims = inputTensor.dims
const imageW = inputDims[3]
const imageH = inputDims[2]
const rOffset = 0
const gOffset = imageW * imageH
const bOffset = imageW * imageH * 2
const outputDims = [
inputDims[0],
inputDims[1],
inputDims[2] * 4,
inputDims[3] * 4,
]
const outputTensor = new ort.Tensor(
'float32',
new Float32Array(
outputDims[0] * outputDims[1] * outputDims[2] * outputDims[3]
),
outputDims
)
const outImageW = outputDims[3]
const outImageH = outputDims[2]
const outROffset = 0
const outGOffset = outImageW * outImageH
const outBOffset = outImageW * outImageH * 2
const tileSize = 64
const tilePadding = 6
const tileSizePre = tileSize - tilePadding * 2
const tilesx = Math.ceil(inputDims[3] / tileSizePre)
const tilesy = Math.ceil(inputDims[2] / tileSizePre)
const { data } = inputTensor
console.log(inputTensor)
const numTiles = tilesx * tilesy
let currentTile = 0
for (let i = 0; i < tilesx; i++) {
for (let j = 0; j < tilesy; j++) {
const ti = Date.now()
const tileW = Math.min(tileSizePre, imageW - i * tileSizePre)
const tileH = Math.min(tileSizePre, imageH - j * tileSizePre)
console.log(`tileW: ${tileW} tileH: ${tileH}`)
const tileROffset = 0
const tileGOffset = tileSize * tileSize
const tileBOffset = tileSize * tileSize * 2
// padding tile 转移到上面的数据上
const tileData = new Float32Array(tileSize * tileSize * 3)
for (let xp = -tilePadding; xp < tileSizePre + tilePadding; xp++) {
for (let yp = -tilePadding; yp < tileSizePre + tilePadding; yp++) {
// 计算在data中的一维坐标,防止边缘溢出
let xim = i * tileSizePre + xp
if (xim < 0) xim = 0
else if (xim >= imageW) xim = imageW - 1
// 计算在data中的一维坐标,防止边缘溢出
let yim = j * tileSizePre + yp
if (yim < 0) yim = 0
else if (yim >= imageH) yim = imageH - 1
const idx = xim + yim * imageW
const xt = xp + tilePadding
const yt = yp + tilePadding
// const idx = (i * tileSize + x) + (j * tileSize + y) * imageW;
// 主要转化到一维的坐标上,
tileData[xt + yt * tileSize + tileROffset] = data[idx + rOffset]
tileData[xt + yt * tileSize + tileGOffset] = data[idx + gOffset]
tileData[xt + yt * tileSize + tileBOffset] = data[idx + bOffset]
}
}
const tile = new ort.Tensor('float32', tileData, [
1,
3,
tileSize,
tileSize,
])
const r = await session.run({ 'input.1': tile })
const results = {
output: r['1895'],
}
console.log(`pre dims:${results.output.dims}`)
const outTileW = tileW * 4
const outTileH = tileH * 4
const outTileSize = tileSize * 4
const outTileSizePre = tileSizePre * 4
const outTileROffset = 0
const outTileGOffset = outTileSize * outTileSize
const outTileBOffset = outTileSize * outTileSize * 2
// add tile to output,直接输出
for (let x = 0; x < outTileW; x++) {
for (let y = 0; y < outTileH; y++) {
const xim = i * outTileSizePre + x
const yim = j * outTileSizePre + y
const idx = xim + yim * outImageW
const xt = x + tilePadding * 4
const yt = y + tilePadding * 4
outputTensor.data[idx + outROffset] =
results.output.data[xt + yt * outTileSize + outTileROffset]
outputTensor.data[idx + outGOffset] =
results.output.data[xt + yt * outTileSize + outTileGOffset]
outputTensor.data[idx + outBOffset] =
results.output.data[xt + yt * outTileSize + outTileBOffset]
}
}
currentTile++
const dt = Date.now() - ti
const remTime = (numTiles - currentTile) * dt
console.log(
`tile ${currentTile} of ${numTiles} took ${dt} ms, remaining time: ${remTime} ms`
)
callback(Math.round(100 * (currentTile / numTiles)))
}
}
console.log(`output dims:${outputTensor.dims}`)
return outputTensor
}
function processImage(
img: HTMLImageElement,
canvasId?: string
): Promise<Uint8Array> {
return new Promise((resolve, reject) => {
try {
const src = cv.imread(img)
// eslint-disable-next-line camelcase
const src_rgb = new cv.Mat()
// 将图像从RGBA转换为RGB
cv.cvtColor(src, src_rgb, cv.COLOR_RGBA2RGB)
if (canvasId) {
cv.imshow(canvasId, src_rgb)
}
resolve(imgProcess(src_rgb))
src.delete()
src_rgb.delete()
} catch (error) {
reject(error)
}
})
}
function configEnv(capabilities: {
webgpu: any
wasm?: boolean
simd: any
threads: any
}) {
ort.env.wasm.wasmPaths =
'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.16.3/dist/'
if (capabilities.webgpu) {
ort.env.wasm.numThreads = 1
} else {
if (capabilities.threads) {
ort.env.wasm.numThreads = navigator.hardwareConcurrency ?? 4
}
if (capabilities.simd) {
ort.env.wasm.simd = true
}
ort.env.wasm.proxy = true
}
console.log('env', ort.env.wasm)
}
function postProcess(floatData: Float32Array, width: number, height: number) {
const chwToHwcData = []
const size = width * height
for (let h = 0; h < height; h++) {
for (let w = 0; w < width; w++) {
for (let c = 0; c < 3; c++) {
// RGB通道
const chwIndex = c * size + h * width + w
const pixelVal = floatData[chwIndex]
let newPiex = pixelVal
if (pixelVal > 1) {
newPiex = 1
} else if (pixelVal < 0) {
newPiex = 0
}
chwToHwcData.push(newPiex * 255) // 归一化反转
}
chwToHwcData.push(255) // Alpha通道
}
}
return chwToHwcData
}
function imageDataToDataURL(imageData: ImageData) {
// 创建 canvas
const canvas = document.createElement('canvas')
canvas.width = imageData.width
canvas.height = imageData.height
// 绘制 imageData 到 canvas
const ctx = canvas.getContext('2d')
ctx.putImageData(imageData, 0, 0)
// 导出为数据 URL
return canvas.toDataURL()
}
let model: ArrayBuffer | null = null
export default async function superResolution(
imageFile: File | HTMLImageElement,
callback: (progress: number) => void
) {
console.time('sessionCreate')
if (!model) {
const capabilities = await getCapabilities()
configEnv(capabilities)
const modelBuffer = await ensureModel('superResolution')
model = await ort.InferenceSession.create(modelBuffer, {
executionProviders: [capabilities.webgpu ? 'webgpu' : 'wasm'],
})
}
console.timeEnd('sessionCreate')
const img =
imageFile instanceof HTMLImageElement
? imageFile
: await loadImage(URL.createObjectURL(imageFile))
const imageTersorData = await processImage(img)
const imageTensor = new ort.Tensor('float32', imageTersorData, [
1,
3,
img.height,
img.width,
])
const result = await tileProc(imageTensor, model, callback)
console.time('postProcess')
const outsTensor = result
const chwToHwcData = postProcess(
outsTensor.data,
img.width * 4,
img.height * 4
)
const imageData = new ImageData(
new Uint8ClampedArray(chwToHwcData),
img.width * 4,
img.height * 4
)
console.log(imageData, 'imageData')
const url = imageDataToDataURL(imageData)
console.timeEnd('postProcess')
return url
}