个人记录
- tensorrt自定义插件层
- getOutputDimensions的调用接口
新定义的层需要定义getOutputDimensions方法,那这个是在供哪里调用的呢?接口的参数又是?
nvinfer1::Dims ResizeBilinearPlugin::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
assert(nbInputs == 1);
nvinfer1::Dims const& input = inputDims[0];
assert(is_CHW(input));
assert(_ndims == 2);
assert(index == 0);
nvinfer1::Dims output;
output.nbDims = input.nbDims;
int s = 0;
for( int d=0; d<input.nbDims; ++d ) {
output.type[d] = input.type[d];
if( input.type[d] == nvinfer1::DimensionType::kSPATIAL ) {
output.d[d] = int(input.d[d] * _scale[s++]);
} else {
output.d[d] = input.d[d];
}
}
return output;
}
其是用TypeSerializingPlugin进行了封装,
auto* wrapped_plugin = new TypeSerializingPlugin(plugin); // 这里进行了封装
然后调用TypeSerializingPlugin的父类 PluginAdapter 的方法
nvinfer1::Dims PluginAdapter::getOutputDimensions(int index,
const nvinfer1::Dims *inputDims,
int nbInputs) {
return _plugin->getOutputDimensions(index, inputDims, nbInputs); // 这里调用的 具体的类的输出尺寸
}
然后继续追寻PluginAdapter的调用就没有源码了,应该是打到tensorrt的库中了。
同理,距离类的initialize也是在PluginAdapter中调用的
int PluginAdapter::initialize() { return _plugin->initialize(); } // 这里调用了具体类的 initialize
同理是具体层的enqueue
以上三个接口在tensorrt中封装死了,也看不到调用的接口。