155 lines
4.7 KiB
TypeScript
155 lines
4.7 KiB
TypeScript
|
import * as tf from '@tensorflow/tfjs-core';
|
||
|
|
||
|
import { ConvParams, extractWeightsFactory, ExtractWeightsFunction, ParamMapping } from '../common';
|
||
|
import { isFloat } from '../utils';
|
||
|
import { ConvLayerParams, NetParams, ResidualLayerParams, ScaleLayerParams } from './types';
|
||
|
|
||
|
/**
 * Creates the layer-extraction helpers used to rebuild the face recognition
 * net from a flat weight buffer.
 *
 * Each helper pulls its values through `extractWeights` (presumably consuming
 * the buffer sequentially — confirm in `extractWeightsFactory`). It also appends
 * one `{ paramPath }` entry to `paramMappings` per tensor it creates. Helpers
 * must therefore be called in the same order the weights were serialized.
 *
 * @param extractWeights - returns the next `n` weight values when called with `n`
 * @param paramMappings - collects one path entry per extracted tensor, in extraction order
 * @returns the conv-layer and residual-layer extractors
 */
function extractorsFactory(extractWeights: ExtractWeightsFunction, paramMappings: ParamMapping[]) {

  // Appends `${prefix}/${name}` for every name, preserving the given order.
  const registerParamPaths = (prefix: string, ...names: string[]): void => {
    for (const name of names) {
      paramMappings.push({ paramPath: `${prefix}/${name}` })
    }
  }

  // Reads `numFilterValues` values, infers the input depth, and returns the
  // filter tensor reordered from [numFilters, depth, k, k] to [k, k, depth, numFilters].
  const extractFilterValues = (numFilterValues: number, numFilters: number, filterSize: number): tf.Tensor4D => {
    const values = extractWeights(numFilterValues)
    const inChannels = values.length / (numFilters * filterSize * filterSize)

    if (isFloat(inChannels)) {
      throw new Error(`depth has to be an integer: ${inChannels}, weights.length: ${values.length}, numFilters: ${numFilters}, filterSize: ${filterSize}`)
    }

    // tidy frees the intermediate (pre-transpose) tensor; only the transposed result survives
    return tf.tidy(() => {
      const raw = tf.tensor4d(values, [numFilters, inChannels, filterSize, filterSize])
      return tf.transpose(raw, [2, 3, 1, 0])
    })
  }

  // Filters first, then one bias value per filter.
  const extractConvParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvParams => {
    const filters = extractFilterValues(numFilterValues, numFilters, filterSize)
    const bias = tf.tensor1d(extractWeights(numFilters))
    registerParamPaths(mappedPrefix, 'filters', 'bias')
    return { filters, bias }
  }

  // Two consecutive 1-D vectors of length `numWeights`: weights, then biases.
  const extractScaleLayerParams = (numWeights: number, mappedPrefix: string): ScaleLayerParams => {
    const weights = tf.tensor1d(extractWeights(numWeights))
    const biases = tf.tensor1d(extractWeights(numWeights))
    registerParamPaths(mappedPrefix, 'weights', 'biases')
    return { weights, biases }
  }

  // A conv followed by its scale layer. Object-literal properties evaluate
  // left to right, so `conv` is extracted before `scale`.
  const extractConvLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string
  ): ConvLayerParams => ({
    conv: extractConvParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv`),
    scale: extractScaleLayerParams(numFilters, `${mappedPrefix}/scale`)
  })

  // Two conv layers (conv1, then conv2). When `isDown` is set, conv1 is given
  // half the filter values, i.e. half the input depth of conv2.
  const extractResidualLayerParams = (
    numFilterValues: number,
    numFilters: number,
    filterSize: number,
    mappedPrefix: string,
    isDown: boolean = false
  ): ResidualLayerParams => ({
    conv1: extractConvLayerParams(isDown ? numFilterValues / 2 : numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv1`),
    conv2: extractConvLayerParams(numFilterValues, numFilters, filterSize, `${mappedPrefix}/conv2`)
  })

  return {
    extractConvLayerParams,
    extractResidualLayerParams
  }
}
|
||
|
|
||
|
/**
 * Rebuilds the face recognition net parameters from a flat weight buffer.
 *
 * Layers are extracted in a fixed order that must match the serialized
 * weights. Each residual layer's filter-value count is written as
 * `numFilters * inputDepth * k * k`.
 *
 * @param weights - the complete serialized weight buffer
 * @returns the tensors for every layer, plus the param path of each tensor in extraction order
 * @throws Error if a layer's filter count does not divide evenly (from the extractors),
 *   or if values are left over after all layers are extracted. In the latter case,
 *   every tensor created so far is disposed before throwing.
 */
export function extractParams(weights: Float32Array): { params: NetParams, paramMappings: ParamMapping[] } {
  const { extractWeights, getRemainingWeights } = extractWeightsFactory(weights)

  const paramMappings: ParamMapping[] = []

  const { extractConvLayerParams, extractResidualLayerParams } = extractorsFactory(extractWeights, paramMappings)

  // 7x7 stem conv over an input depth of 3 (4704 values)
  const conv32_down = extractConvLayerParams(32 * 3 * 7 * 7, 32, 7, 'conv32_down')
  const conv32_1 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_1')
  const conv32_2 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_2')
  const conv32_3 = extractResidualLayerParams(32 * 32 * 3 * 3, 32, 3, 'conv32_3')

  const conv64_down = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_down', true)
  const conv64_1 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_1')
  const conv64_2 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_2')
  const conv64_3 = extractResidualLayerParams(64 * 64 * 3 * 3, 64, 3, 'conv64_3')

  const conv128_down = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_down', true)
  const conv128_1 = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_1')
  const conv128_2 = extractResidualLayerParams(128 * 128 * 3 * 3, 128, 3, 'conv128_2')

  const conv256_down = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_down', true)
  const conv256_1 = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_1')
  const conv256_2 = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_2')
  const conv256_down_out = extractResidualLayerParams(256 * 256 * 3 * 3, 256, 3, 'conv256_down_out')

  // fully connected weights are stored as [128, 256] and transposed to [256, 128]
  const fc = tf.tidy(
    () => tf.transpose(tf.tensor2d(extractWeights(256 * 128), [128, 256]), [1, 0])
  )
  paramMappings.push({ paramPath: `fc` })

  const params = {
    conv32_down,
    conv32_1,
    conv32_2,
    conv32_3,
    conv64_down,
    conv64_1,
    conv64_2,
    conv64_3,
    conv128_down,
    conv128_1,
    conv128_2,
    conv256_down,
    conv256_1,
    conv256_2,
    conv256_down_out,
    fc
  }

  const numRemaining = getRemainingWeights().length
  if (numRemaining !== 0) {
    // Every tensor above is already allocated; release them so a mismatched
    // weights file does not leak backend memory when the error is caught.
    tf.dispose(params)
    throw new Error(`weights remaining after extract: ${numRemaining}`)
  }

  return { params, paramMappings }
}
|