ente/thirdparty/face-api/faceRecognitionNet/extractParams.ts

155 lines
4.7 KiB
TypeScript
Raw Normal View History

import * as tf from '@tensorflow/tfjs-core';
import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common';
import { isFloat } from '../utils';
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
/**
 * Creates the helpers that turn raw weight values into layer parameter tensors.
 *
 * Every helper obtains its values through `extractWeights` and appends one
 * entry to `paramMappings` for each tensor it creates, in creation order.
 *
 * @param extractWeights Function returning the requested number of weight values.
 * @param paramMappings Array that receives the path of every created tensor (mutated in place).
 * @returns The `extractConvLayerParams` and `extractResidualLayerParams` helpers.
 */
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {

  /**
   * Reads `numFilterValues` values and arranges them as a 4D filter tensor.
   *
   * The values are first interpreted as [numFilters, depth, filterSize, filterSize]
   * and then permuted with [2, 3, 1, 0], i.e. to
   * [filterSize, filterSize, depth, numFilters].
   *
   * @throws Error if the number of values does not yield an integer depth.
   */
  function extractFilterValues(numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D {
    const values = extractWeights(numFilterValues);
    const valuesPerDepthSlice = numFilters * filterSize * filterSize;
    const depth = values.length / valuesPerDepthSlice;

    if (isFloat(depth)) {
      throw new Error(`depth has to be an integer: ${depth}, weights.length: ${values.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`);
    }

    // The un-transposed tensor is created inside tidy, which returns only the transposed result.
    return tf.tidy(() => {
      const packed = tf.tensor4d(values, [numFilters, depth, filterSize, filterSize]);
      return tf.transpose(packed, [2, 3, 1, 0]);
    });
  }

  /**
   * Extracts the filters followed by the bias (one bias value per filter) of a
   * convolution and records both under `mappedPrefix`.
   */
  function extractConvParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvParams {
    // Filters are requested before the bias; keep this order.
    const filters = extractFilterValues(numFilterValues, numFilters, filterSize);
    const bias = tf.tensor1d(extractWeights(numFilters));

    paramMappings.push({ paramPath: `${mappedPrefix}/filters` });
    paramMappings.push({ paramPath: `${mappedPrefix}/bias` });

    return { filters, bias };
  }

  /**
   * Extracts `numWeights` weights followed by `numWeights` biases of a scale
   * layer and records both under `mappedPrefix`.
   */
  function extractScaleLayerParams(numWeights: number, mappedPrefix: string): ScaleLayerParams {
    // Weights are requested before the biases; keep this order.
    const weights = tf.tensor1d(extractWeights(numWeights));
    const biases = tf.tensor1d(extractWeights(numWeights));

    paramMappings.push({ paramPath: `${mappedPrefix}/weights` });
    paramMappings.push({ paramPath: `${mappedPrefix}/biases` });

    return { weights, biases };
  }

  /**
   * Extracts a convolution and then its scale layer, mapped below
   * `<mappedPrefix>/conv` and `<mappedPrefix>/scale` respectively.
   */
  function extractConvLayerParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvLayerParams {
    const convPrefix = `${mappedPrefix}/conv`;
    const scalePrefix = `${mappedPrefix}/scale`;

    const conv = extractConvParams(numFilterValues, numFilters, filterSize, convPrefix);
    const scale = extractScaleLayerParams(numFilters, scalePrefix);

    return { conv, scale };
  }

  /**
   * Extracts the two conv layers (`conv1`, then `conv2`) of a residual layer.
   *
   * When `isDown` is set, `conv1` is extracted with half of `numFilterValues`;
   * `conv2` always uses the full amount.
   */
  function extractResidualLayerParams(
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
    isDown: boolean = false
  ): ResidualLayerParams {
    const conv1FilterValues = isDown ? 0.5 * numFilterValues : numFilterValues;

    const conv1 = extractConvLayerParams(conv1FilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`);
    const conv2 = extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`);

    return { conv1, conv2 };
  }

  return { extractConvLayerParams, extractResidualLayerParams };
}
/**
 * Splits a flat weight array into the named parameter tensors of the network.
 *
 * The extract calls are issued in a fixed order. `extractWeights` is only
 * handed a value count, so it presumably reads sequentially from `weights`
 * (NOTE(review): confirm against `extractWeightsFactory`); do not reorder the
 * calls below.
 *
 * @param weights Flat array holding all weight values of the network.
 * @returns `params`, the extracted tensors grouped by layer name, and
 *          `paramMappings`, one entry per created tensor in creation order.
 * @throws Error if weights are left over once every parameter has been extracted.
 */
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
  const { extractWeights, getRemainingWeights } = extractWeightsFactory(weights);

  const paramMappings: ParamMapping[] = [];

  const { extractConvLayerParams, extractResidualLayerParams } = extractorsFactory(extractWeights, paramMappings);

  // Arguments: number of filter values, number of filters, filter size, mapping prefix.
  // 4704 = 32 filters * 3 * 7 * 7, i.e. the first convolution has depth 3.
  const conv32_down = extractConvLayerParams(4704, 32, 7, 'conv32_down');
  const conv32_1 = extractResidualLayerParams(9216, 32, 3, 'conv32_1');
  const conv32_2 = extractResidualLayerParams(9216, 32, 3, 'conv32_2');
  const conv32_3 = extractResidualLayerParams(9216, 32, 3, 'conv32_3');

  // The trailing `true` (isDown) makes the first conv of the layer use half the filter values.
  const conv64_down = extractResidualLayerParams(36864, 64, 3, 'conv64_down', true);
  const conv64_1 = extractResidualLayerParams(36864, 64, 3, 'conv64_1');
  const conv64_2 = extractResidualLayerParams(36864, 64, 3, 'conv64_2');
  const conv64_3 = extractResidualLayerParams(36864, 64, 3, 'conv64_3');

  const conv128_down = extractResidualLayerParams(147456, 128, 3, 'conv128_down', true);
  const conv128_1 = extractResidualLayerParams(147456, 128, 3, 'conv128_1');
  const conv128_2 = extractResidualLayerParams(147456, 128, 3, 'conv128_2');

  const conv256_down = extractResidualLayerParams(589824, 256, 3, 'conv256_down', true);
  const conv256_1 = extractResidualLayerParams(589824, 256, 3, 'conv256_1');
  const conv256_2 = extractResidualLayerParams(589824, 256, 3, 'conv256_2');
  const conv256_down_out = extractResidualLayerParams(589824, 256, 3, 'conv256_down_out');

  // The final 256 * 128 values are read as a [128, 256] matrix and transposed to [256, 128].
  const fc = tf.tidy(
    () => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0])
  );
  paramMappings.push({ paramPath: `fc` });

  // Any leftover values mean the supplied weights do not match the layout extracted above.
  const numRemaining = getRemainingWeights().length;
  if (numRemaining !== 0) {
    throw new Error(`weights remaining after extract: ${numRemaining}`);
  }

  const params = {
    conv32_down,
    conv32_1,
    conv32_2,
    conv32_3,
    conv64_down,
    conv64_1,
    conv64_2,
    conv64_3,
    conv128_down,
    conv128_1,
    conv128_2,
    conv256_down,
    conv256_1,
    conv256_2,
    conv256_down_out,
    fc
  };

  return { params, paramMappings };
}