27 lines
832 B
TypeScript
27 lines
832 B
TypeScript
import * as tf from '@tensorflow/tfjs-core';
|
|
|
|
import { disposeUnusedWeightTensors, extractWeightEntryFactory, FCParams, ParamMapping } from '../common';
|
|
import { NetParams } from './types';
|
|
|
|
/**
 * Builds the network parameters from a named tensor map.
 *
 * Reads the `fc/weights` and `fc/bias` entries from `weightMap` through the
 * entry reader produced by `extractWeightEntryFactory`. It then calls
 * `disposeUnusedWeightTensors` with the same map and the mappings array that
 * was handed to the factory.
 *
 * Note: the exported name keeps its original spelling ("Weigth") because
 * callers reference it by that name.
 *
 * @param weightMap - tensors keyed by path (e.g. `fc/weights`, `fc/bias`).
 * @returns the extracted `params` and the `paramMappings` array that was passed
 *   to the entry-reader factory (presumably filled by it — confirm in ../common).
 */
export function extractParamsFromWeigthMap(
  weightMap: tf.NamedTensorMap
): { params: NetParams, paramMappings: ParamMapping[] } {
  const mappings: ParamMapping[] = []
  const readEntry = extractWeightEntryFactory(weightMap, mappings)

  // Reads one fully-connected layer. Weights are read before bias to keep the
  // original call order, in case the factory records mappings in call order.
  // The trailing 2 / 1 arguments presumably are the expected tensor ranks
  // (matching Tensor2D / Tensor1D) — TODO confirm in ../common.
  const readFc = (layerPrefix: string): FCParams => {
    const fcWeights = readEntry<tf.Tensor2D>(`${layerPrefix}/weights`, 2)
    const fcBias = readEntry<tf.Tensor1D>(`${layerPrefix}/bias`, 1)
    return { weights: fcWeights, bias: fcBias }
  }

  const params = { fc: readFc('fc') }

  // NOTE(review): presumably disposes tensors in weightMap that no mapping
  // refers to — verify against the helper's implementation.
  disposeUnusedWeightTensors(weightMap, mappings)

  return { params, paramMappings: mappings }
}