21 lines
581 B
TypeScript
21 lines
581 B
TypeScript
import { isTensor } from '../utils';
|
|
import { ParamMapping } from './types';
|
|
|
|
/**
 * Builds an extractor that pulls validated entries out of `weightMap` and
 * records, for each extracted entry, which parameter path it was mapped to.
 *
 * @param weightMap - Lookup object indexed by original weight path. The value
 *   shape is not declared here; entries are validated with `isTensor` on read.
 * @param paramMappings - Array that the returned extractor appends to (it is
 *   mutated in place) with one `{ originalPath, paramPath }` record per
 *   successful extraction.
 * @returns A function `(originalPath, paramRank, mappedPath?) => T` that:
 *   - reads `weightMap[originalPath]`;
 *   - throws an `Error` if `isTensor(entry, paramRank)` is falsy (presumably a
 *     tensor-of-that-rank check; see `isTensor` in `../utils`);
 *   - otherwise records the mapping and returns the entry.
 *   `T` is not checked at runtime; the caller's choice of `T` is trusted.
 */
export function extractWeightEntryFactory(weightMap: any, paramMappings: ParamMapping[]) {
  function extractWeightEntry<T>(originalPath: string, paramRank: number, mappedPath?: string): T {
    const entry = weightMap[originalPath];

    if (isTensor(entry, paramRank)) {
      // `||` (not `??`) on purpose: an empty-string mappedPath also falls back
      // to the original path, so no mapping ever gets an empty paramPath.
      const paramPath = mappedPath || originalPath;
      paramMappings.push({ originalPath, paramPath });
      return entry;
    }

    throw new Error(`expected weightMap[${originalPath}] to be a Tensor${paramRank}D, instead have ${entry}`);
  }

  return extractWeightEntry;
}
|