模板匹配(Template matching, TM)是一种解码端推导方法,用来细化当前CU的运动信息,使得当前CU的MV更准确。
TM主要通过寻找一个MV,使当前图片中的模板(当前 CU 的上侧和/或左侧相邻块)与参考图片中对应模板之间的匹配误差最小。如下图所示,在围绕当前 CU 初始 MV 的 [-8, +8] 像素搜索范围内搜索更优的 MV。TM 的搜索步长由 AMVR 模式决定,并且在 Merge 模式下 TM 可以与双边匹配(bilateral matching, BM)过程级联。
在AMVP模式下,仅对特定的MVP候选进行细化:根据模板匹配误差,选取当前块模板与参考块模板差异最小的MVP候选进行TM细化。TM 首先在 [-8, +8] 像素搜索范围内,以全像素 MVD 精度(4 像素 AMVR 模式下为 4 像素精度)使用迭代菱形搜索细化该 MVP 候选;随后以相同精度(全像素,或 4 像素 AMVR 模式下的 4 像素)进行交叉搜索,对该 AMVP 候选作进一步细化;最后根据表 1 中为各 AMVR 模式指定的精度,依次进行半像素和四分之一像素搜索。这一搜索过程保证 MVP 候选在 TM 之后仍保持与 AMVR 模式所指示的 MV 精度相同。
在Merge模式下,对Merge索引所指示的Merge候选应用类似的搜索方法。如表 1 所示,TM 可以一直搜索到 1/8 像素 MVD 精度,也可以跳过比半像素更精细的精度的搜索,这取决于该Merge候选的运动信息是否使用了替代插值滤波器(alternative interpolation filter, AltIF,在 AMVR 为半像素模式时使用)。此外,当启用 TM 模式时,根据 BM 的启用条件检查结果,模板匹配既可以作为一个独立过程,也可以作为基于块的双边匹配(BM)与基于子块的双边匹配之间的额外 MV 细化过程。
相关代码
ECM中,TM细化MV的入口函数是deriveTMMv。需要注意的是,Merge模式下会对全部MV候选项进行TM细化,而AMVP模式下仅对候选列表中模板匹配误差最小的MV进行TM细化;二者调用的是deriveTMMv的两个不同重载版本,声明如下:
#if TM_MRG
// Merge-mode entry point: TM-refines the MVs of `pu` in place
// (returns immediately unless pu.tmMergeFlag is set).
void deriveTMMv (PredictionUnit& pu);
#endif
// Refines one specific MV (`mv`, list eRefList, index refIdx) by template matching and
// returns the resulting TM cost. A non-null otherMvf selects the bi-prediction variant,
// where the other list's motion (otherMvf) stays fixed.
Distortion deriveTMMv (const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf = nullptr);
这两个函数的代码及注释如下所示:
/**
 * Merge-mode template-matching refinement of `pu`, performed in place.
 *
 * Pass 1 refines every reference list the PU predicts from, each on its own.
 * Pass 2 (B slices, bi-predicted PU, and - with MULTI_PASS_DMVR - only when BDMVR does
 * not apply) keeps the list with the lower uni cost fixed and re-refines the other list
 * in bi-prediction mode. If the resulting cost exceeds the better uni cost by more than
 * 1/8, the PU is converted to uni-prediction on the better list.
 */
void InterPrediction::deriveTMMv(PredictionUnit& pu)
{
  if (!pu.tmMergeFlag)
  {
    return;
  }

  const Distortion noCost   = std::numeric_limits<Distortion>::max();
  const bool       sliceIsB = pu.cu->slice->isInterB();
  const int        numLists = sliceIsB ? NUM_REF_PIC_LIST_01 : 1;

  // Pass 1: independent (uni-directional) TM refinement of each list in use.
  Distortion uniCost[NUM_REF_PIC_LIST_01] = { noCost, noCost };
  for (int list = 0; list < numLists; list++)
  {
    const bool listUsed = (pu.interDir & (list + 1)) != 0;
    if (listUsed)
    {
      uniCost[list] = deriveTMMv(pu, true, noCost, static_cast<RefPicList>(list), pu.refIdx[list],
                                 TM_MAX_NUM_OF_ITERATIONS, pu.mv[list]);
    }
  }

  // Pass 2 is only attempted for bi-predicted PUs in B slices.
  bool tryBi = sliceIsB && pu.interDir == 3;
#if MULTI_PASS_DMVR
  tryBi = tryBi && !PU::checkBDMVRCondition(pu);
#endif
  if (!tryBi)
  {
    return;
  }

  // Either list without a valid template cost: nothing to compare against.
  if (uniCost[0] == noCost || uniCost[1] == noCost)
  {
    return;
  }

  // Keep the cheaper list fixed and refine the other one against it.
  const RefPicList refineList = (uniCost[0] <= uniCost[1]) ? REF_PIC_LIST_1 : REF_PIC_LIST_0;
  const int        fixedList  = 1 - refineList;
  const MvField    fixedMvf(pu.mv[fixedList], pu.refIdx[fixedList]);

  const Distortion biCost = deriveTMMv(pu, true, noCost, refineList, pu.refIdx[refineList],
                                       TM_MAX_NUM_OF_ITERATIONS, pu.mv[refineList], &fixedMvf);

  // Bi-prediction is kept only if it is no worse than the better uni cost plus 12.5%.
  const Distortion biLimit = uniCost[fixedList] + (uniCost[fixedList] >> 3);
  if (biCost > biLimit)
  {
    pu.interDir           = 1 + fixedList;
    pu.mv    [refineList] = Mv();
    pu.refIdx[refineList] = NOT_VALID;
  }
}
#if TM_AMVP || TM_MRG
/**
 * Refine a single MV of `pu` by template matching (TM) and return its TM cost.
 *
 * @param pu               prediction unit whose above/left templates are matched
 * @param fillCurTpl       true: (re)fill the current-picture templates from the reconstruction;
 *                         false: only probe template availability and reuse the contents
 *                         already in m_pcCurTplAbove / m_pcCurTplLeft
 * @param curBestCost      initial best cost; if it is the Distortion maximum, the cost at the
 *                         start MV is computed first
 * @param eRefList,refIdx  reference list / index of the MV being refined
 * @param maxSearchRounds  maximum number of search rounds; 0 means no search, only the cost of
 *                         the initial MV is computed
 * @param mv               in: start MV, out: refined MV (left untouched if no template exists)
 * @param otherMvf         nullptr for uni-prediction; otherwise the fixed motion of the other
 *                         list, used for bi-prediction refinement
 * @return TM cost of the refined MV. For bi-prediction it is scaled by this list's BCW weight.
 *         Returns the Distortion maximum when neither the above nor the left template is available.
 */
Distortion InterPrediction::deriveTMMv(const PredictionUnit& pu, bool fillCurTpl, Distortion curBestCost, RefPicList eRefList, int refIdx, int maxSearchRounds, Mv& mv, const MvField* otherMvf)
{
  CHECK(refIdx < 0, "Invalid reference index for TM");

  const CodingUnit&  cu            = *pu.cu;
  const Picture&     refPic        = *cu.slice->getRefPic(eRefList, refIdx)->unscaledPic;
  const RefPicList   otherList     = static_cast<RefPicList>(1 - eRefList);

  // The other list's MV is handed to TplMatchingCtrl only when both lists reference the same POC.
  const bool doSimilarityCheck = otherMvf != nullptr
                              && cu.slice->getRefPOC(eRefList, refIdx) == cu.slice->getRefPOC(otherList, otherMvf->refIdx);

  InterPredResources interRes(m_pcReshape, m_pcRdCost, m_if, m_filteredBlockTmp[0][COMPONENT_Y]
                            , m_filteredBlock[3][1][0], m_filteredBlock[3][0][0]
  );

  // The constructor builds the current templates and pre-interpolates the reference search area.
  TplMatchingCtrl tplCtrl(pu, interRes, refPic, fillCurTpl, COMPONENT_Y, true, maxSearchRounds,
                          m_pcCurTplAbove, m_pcCurTplLeft, m_pcRefTplAbove, m_pcRefTplLeft,
                          mv, (doSimilarityCheck ? &(otherMvf->mv) : nullptr), curBestCost);

  // Neither the above nor the left template is available.
  if (!tplCtrl.getTemplatePresentFlag())
  {
    return std::numeric_limits<Distortion>::max();
  }

  if (otherMvf == nullptr) // uni-prediction
  {
    tplCtrl.deriveMvUni<TM_TPL_SIZE>();
    mv = tplCtrl.getFinalMv();
    return tplCtrl.getMinCost();
  }

  // Bi-prediction: the other list's motion stays fixed. Per the original note, its
  // reference template is subtracted from the current template before searching.
  const int8_t   bcwWeight   = getBcwWeight(cu.BcwIdx, eRefList);
  const Picture& otherRefPic = *cu.slice->getRefPic(otherList, otherMvf->refIdx)->unscaledPic;
  tplCtrl.removeHighFreq<TM_TPL_SIZE>(otherRefPic, otherMvf->mv, bcwWeight);
  tplCtrl.deriveMvUni<TM_TPL_SIZE>();
  mv = tplCtrl.getFinalMv();

  // Scale the cost by this list's BCW weight, which is in units of 1/g_BcwWeightBase, with rounding.
  // The rounding offset is half of g_BcwWeightBase, so the matching shift is log2 of the base
  // (g_BcwLog2WeightBase). Shifting by g_BcwWeightBase itself would divide by 256, not 8.
  // NOTE(review): BCW weights can be negative; with an unsigned Distortion the product would
  // wrap around - confirm that TM bi-refinement is never reached with such a BcwIdx.
  return (tplCtrl.getMinCost() * bcwWeight + (g_BcwWeightBase >> 1)) >> g_BcwLog2WeightBase;
}
TM过程中模板的获取以及搜索过程都是通过TplMatchingCtrl类控制的,代码如下所示:
class TplMatchingCtrl { enum TMSearchMethod { TMSEARCH_DIAMOND, TMSEARCH_CROSS, TMSEARCH_NUMBER_OF_METHODS }; const CodingUnit& m_cu; const PredictionUnit& m_pu; InterPredResources& m_interRes; const Picture& m_refPic; const Mv m_mvStart; Mv m_mvFinal; const Mv* m_otherRefListMv; Distortion m_minCost; bool m_useWeight; int m_maxSearchRounds; ComponentID m_compID; PelBuf m_curTplAbove; PelBuf m_curTplLeft; PelBuf m_refTplAbove; PelBuf m_refTplLeft; PelBuf m_refSrAbove; // pre-filled samples on search area PelBuf m_refSrLeft; // pre-filled samples on search area #if JVET_X0056_DMVD_EARLY_TERMINATION Distortion m_earlyTerminateTh; #endif #if MULTI_PASS_DMVR Distortion m_tmCostArrayDiamond[9]; Distortion m_tmCostArrayCross[5]; #endif public: // 构造函数,获取当前模板和参考模板 TplMatchingCtrl(const PredictionUnit& pu, InterPredResources& interRes, // Bridge required resource from InterPrediction const Picture& refPic, const bool fillCurTpl, const ComponentID compID, const bool useWeight, const int maxSearchRounds, Pel* curTplAbove, Pel* curTplLeft, Pel* refTplAbove, Pel* refTplLeft, const Mv& mvStart, const Mv* otherRefListMv, const Distortion curBestCost ); // 返回模板是否存在 bool getTemplatePresentFlag() { return m_curTplAbove.buf != nullptr || m_curTplLeft.buf != nullptr; } Distortion getMinCost () { return m_minCost; } // 返回最小的cost Mv getFinalMv () { return m_mvFinal; } // 返回最终细化后的MV static int getDeltaMean (const PelBuf& bufCur, const PelBuf& bufRef, const int rowSubShift, const int bd); template <int tplSize> void deriveMvUni (); // 推导单向MV template <int tplSize> void removeHighFreq (const Picture& otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight); private: template <int tplSize, bool TrueA_FalseL> bool xFillCurTemplate (Pel* tpl); template <int tplSize, bool TrueA_FalseL, int sr> PelBuf xGetRefTemplate (const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf); template <int tplSize, bool TrueA_FalseL> void xRemoveHighFreq (const Picture& 
otherRefPic, const Mv& otherRefMv, const uint8_t curRefBcwWeight); template <int tplSize, int searchPattern> void xRefineMvSearch (int maxSearchRounds, int searchStepShift); #if MULTI_PASS_DMVR template <int searchPattern> void xNextTmCostAarray (int bestDirect); template <int searchPattern> void xDeriveCostBasedMv (); template <bool TrueX_FalseY> void xDeriveCostBasedOffset (Distortion costLorA, Distortion costCenter, Distortion costRorB, int log2StepSize); int xBinaryDivision (int64_t numerator, int64_t denominator, int fracBits); #endif template <int tplSize> Distortion xGetTempMatchError (const Mv& mv); template <int tplSize, bool TrueA_FalseL> Distortion xGetTempMatchError (const Mv& mv); };
TM所需的模板在TplMatchingCtrl类的构造函数中准备:调用xFillCurTemplate获取当前模板(上侧和左侧),并在需要搜索时调用xGetRefTemplate,对参考图像中覆盖整个搜索范围的参考模板样本进行预插值。
#if TM_AMVP || TM_MRG
/**
 * Set up a template-matching context for one MV.
 *
 * - Probes (and, if fillCurTpl is set, fills) the current-picture above/left templates.
 * - Creates views on the caller-supplied reference template buffers with the same geometry.
 * - If a search will run (maxSearchRounds > 0), pre-interpolates the reference samples
 *   around mvStart for a window enlarged by TM_SEARCH_RANGE on every side, and keeps a
 *   view positioned at the start MV.
 */
TplMatchingCtrl::TplMatchingCtrl( const PredictionUnit& pu,
                                  InterPredResources&   interRes,
                                  const Picture&        refPic,
                                  const bool            fillCurTpl,
                                  const ComponentID     compID,
                                  const bool            useWeight,
                                  const int             maxSearchRounds,
                                  Pel*                  curTplAbove,
                                  Pel*                  curTplLeft,
                                  Pel*                  refTplAbove,
                                  Pel*                  refTplLeft,
                                  const Mv&             mvStart,
                                  const Mv*             otherRefListMv,
                                  const Distortion      curBestCost )
: m_cu              (*pu.cu)
, m_pu              (pu)
, m_interRes        (interRes)
, m_refPic          (refPic)
, m_mvStart         (mvStart)
, m_mvFinal         (mvStart)
, m_otherRefListMv  (otherRefListMv)
, m_minCost         (curBestCost)
, m_useWeight       (useWeight)
, m_maxSearchRounds (maxSearchRounds)
, m_compID          (compID)
{
  // Current-picture templates: passing nullptr only probes whether the neighbour exists.
  Pel* const aboveTarget = fillCurTpl ? curTplAbove : nullptr;
  Pel* const leftTarget  = fillCurTpl ? curTplLeft  : nullptr;
  const bool hasAbove    = xFillCurTemplate<TM_TPL_SIZE, true >(aboveTarget);
  const bool hasLeft     = xFillCurTemplate<TM_TPL_SIZE, false>(leftTarget);

  // Above side: width x TM_TPL_SIZE; the reference view reuses the current template's geometry.
  m_curTplAbove = hasAbove ? PelBuf(curTplAbove, pu.lwidth(), TM_TPL_SIZE) : PelBuf();
  m_refTplAbove = hasAbove ? PelBuf(refTplAbove, m_curTplAbove)            : PelBuf();

  // Left side: TM_TPL_SIZE x height.
  m_curTplLeft  = hasLeft  ? PelBuf(curTplLeft,  TM_TPL_SIZE, pu.lheight()) : PelBuf();
  m_refTplLeft  = hasLeft  ? PelBuf(refTplLeft,  m_curTplLeft)              : PelBuf();

#if JVET_X0056_DMVD_EARLY_TERMINATION
  // Equals the number of samples in the available templates.
  m_earlyTerminateTh = TM_TPL_SIZE * ((hasAbove ? m_pu.lwidth() : 0) + (hasLeft ? m_pu.lheight() : 0));
#endif

  // Search windows are only needed when at least one search round will run.
  const bool willSearch = maxSearchRounds > 0;

  m_refSrAbove = (hasAbove && willSearch)
               ? PelBuf(interRes.m_preFillBufA, m_curTplAbove.width + 2 * TM_SEARCH_RANGE, m_curTplAbove.height + 2 * TM_SEARCH_RANGE)
               : PelBuf();
  if (m_refSrAbove.buf != nullptr)
  {
    // Interpolate the enlarged window, then narrow the view to the start-MV position,
    // which sits TM_SEARCH_RANGE samples in from the window's top-left corner.
    m_refSrAbove = xGetRefTemplate<TM_TPL_SIZE, true, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrAbove);
    m_refSrAbove = m_refSrAbove.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplAbove);
  }

  m_refSrLeft = (hasLeft && willSearch)
              ? PelBuf(interRes.m_preFillBufL, m_curTplLeft.width + 2 * TM_SEARCH_RANGE, m_curTplLeft.height + 2 * TM_SEARCH_RANGE)
              : PelBuf();
  if (m_refSrLeft.buf != nullptr)
  {
    m_refSrLeft = xGetRefTemplate<TM_TPL_SIZE, false, TM_SEARCH_RANGE>(m_pu, m_refPic, mvStart, m_refSrLeft);
    m_refSrLeft = m_refSrLeft.subBuf(Position(TM_SEARCH_RANGE, TM_SEARCH_RANGE), m_curTplLeft);
  }
}
xFillCurTemplate函数获取当前模板:
/**
 * Fill (or just probe) the current-picture template on one side of the block.
 *
 * @tparam tplSize       template thickness in samples
 * @tparam TrueA_FalseL  true: template above the block (width x tplSize);
 *                       false: template left of the block (tplSize x height)
 * @param  tpl           destination buffer (row stride = template width), or nullptr to only
 *                       check availability
 * @return true if the template is available (and, when tpl != nullptr, was written).
 */
template <int tplSize, bool TrueA_FalseL>
bool TplMatchingCtrl::xFillCurTemplate(Pel* tpl)
{
  // First template sample relative to the block's top-left corner.
  const Position posOffset = TrueA_FalseL ? Position(0, -tplSize) : Position(-tplSize, 0);

  // The template exists only if a CU covers that position in this channel.
  const CodingUnit* const cuNeigh = m_cu.cs->getCU(m_pu.blocks[m_compID].pos().offset(posOffset), toChannelType(m_compID));
  if (cuNeigh == nullptr)
  {
    return false;
  }

  // Availability probe only.
  // NOTE(review): this returns before the GPM template-shape check below, so in probe mode a side
  // excluded by geoTmType is still reported as available - confirm this is intended.
  if (tpl == nullptr)
  {
    return true;
  }

  const Picture&    currPic = *m_cu.cs->picture;
  const CPelBuf     recBuf  = currPic.getRecoBuf(m_cu.cs->picture->blocks[m_compID]);
  std::vector<Pel>& invLUT  = m_interRes.m_pcReshape->getInvLUT();

  // For luma with LMCS enabled in the picture header and the CTU flag set, samples are mapped
  // through the reshaper's inverse LUT.
  const bool useLUT = isLuma(m_compID) && m_cu.cs->picHeader->getLmcsEnabledFlag() && m_interRes.m_pcReshape->getCTUFlag();

  // Logical AND: a bitwise '&' would only match when both macros are exactly 0 or 1.
#if JVET_W0097_GPM_MMVD_TM && TM_MRG
  if (m_cu.geoFlag)
  {
    CHECK(m_pu.geoTmType == GEO_TM_OFF, "invalid geo template type value");
    // GEO_TM_SHAPE_A uses only the above template, GEO_TM_SHAPE_L only the left one.
    if (m_pu.geoTmType == GEO_TM_SHAPE_A && !TrueA_FalseL)
    {
      return false;
    }
    if (m_pu.geoTmType == GEO_TM_SHAPE_L && TrueA_FalseL)
    {
      return false;
    }
  }
#endif

  const Size dstSize = (TrueA_FalseL ? Size(m_pu.lwidth(), tplSize) : Size(tplSize, m_pu.lheight()));
  for (int h = 0; h < (int)dstSize.height; h++)
  {
    const Position recPos = TrueA_FalseL ? Position(0, -tplSize + h) : Position(-tplSize, h);
    const Pel*     rec    = recBuf.bufAt(m_pu.blocks[m_compID].pos().offset(recPos));
    Pel*           dst    = tpl + h * dstSize.width;

    for (int w = 0; w < (int)dstSize.width; w++)
    {
      const int recVal = rec[w];
      dst[w] = useLUT ? invLUT[recVal] : recVal;
    }
  }
  return true;
}
xGetRefTemplate函数获取参考模板:
/**
 * Get the reference template for `_mv`: the region above (TrueA_FalseL) or left of curPu,
 * tplSize samples thick, in refPic, with the size of dstBuf and displaced by `_mv` minus
 * `sr` integer samples in both directions.
 *
 * With sr > 0 (the constructor uses TM_SEARCH_RANGE), the caller passes a dstBuf enlarged
 * by 2*sr in each dimension, so the result covers the whole search window.
 *
 * Fast path (sr == 0): if refPic has the same POC as m_refPic, the pre-filled search area
 * exists, and `_mv` differs from m_mvStart by an integer offset of at most TM_SEARCH_RANGE
 * samples per direction, a view into the pre-filled buffer is returned. In that case no
 * interpolation happens and dstBuf presumably only supplies the size.
 * Otherwise the samples are interpolated into dstBuf, and dstBuf is returned.
 */
template <int tplSize, bool TrueA_FalseL, int sr>
PelBuf TplMatchingCtrl::xGetRefTemplate(const PredictionUnit& curPu, const Picture& refPic, const Mv& _mv, PelBuf& dstBuf)
{
  // read from pre-interpolated buffer
  PelBuf& refSrBuf = TrueA_FalseL ? m_refSrAbove : m_refSrLeft;
  // sr == 0: try to read the samples directly from the pre-interpolated buffer
  if (sr == 0 && refPic.getPOC() == m_refPic.getPOC() && refSrBuf.buf != nullptr)
  {
    Mv mvDiff = _mv - m_mvStart;
    // Usable only if the offset from the start MV has no fractional part...
    if ((mvDiff.getAbsHor() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0 && (mvDiff.getAbsVer() & ((1 << MV_FRACTIONAL_BITS_INTERNAL) - 1)) == 0)
    {
      mvDiff >>= MV_FRACTIONAL_BITS_INTERNAL;
      // ...and stays inside the pre-filled +/-TM_SEARCH_RANGE window.
      if (mvDiff.getAbsHor() <= TM_SEARCH_RANGE && mvDiff.getAbsVer() <= TM_SEARCH_RANGE)
      {
        return refSrBuf.subBuf(Position(mvDiff.getHor(), mvDiff.getVer()), dstBuf);
      }
    }
  }

  // Do interpolation on the fly
  // Top-left of the template region in the current picture.
  Position blkPos  = ( TrueA_FalseL ? Position(curPu.lx(), curPu.ly() - tplSize) : Position(curPu.lx() - tplSize, curPu.ly()) );
  Size     blkSize = Size(dstBuf.width, dstBuf.height);
  // Shift by -sr integer samples so an enlarged dstBuf is centred on `_mv`.
  Mv       mv      = _mv - Mv(sr << MV_FRACTIONAL_BITS_INTERNAL, sr << MV_FRACTIONAL_BITS_INTERNAL);
  clipMv( mv, blkPos, blkSize, *m_cu.cs->sps, *m_cu.cs->pps );

  // Split the MV into integer-sample and fractional parts for this component.
  const int lumaShift = 2 + MV_FRACTIONAL_BITS_DIFF;
  const int horShift  = (lumaShift + ::getComponentScaleX(m_compID, m_cu.chromaFormat));
  const int verShift  = (lumaShift + ::getComponentScaleY(m_compID, m_cu.chromaFormat));

  const int xInt  = mv.getHor() >> horShift;
  const int yInt  = mv.getVer() >> verShift;
  const int xFrac = mv.getHor() & ((1 << horShift) - 1);
  const int yFrac = mv.getVer() & ((1 << verShift) - 1);

  const CPelBuf refBuf = refPic.getRecoBuf(refPic.blocks[m_compID]);
  const Pel* ref = refBuf.bufAt(blkPos.offset(xInt, yInt));
  Pel* dst = dstBuf.buf;
  int refStride = refBuf.stride;
  int dstStride = dstBuf.stride;
  int bw = (int)blkSize.width;
  int bh = (int)blkSize.height;

  // Interpolation filter index 1; for luma the vertical margin below uses NTAPS_BILINEAR,
  // so this is presumably the bilinear filter - confirm against InterpolationFilter.
  const int  nFilterIdx   = 1;
  const bool useAltHpelIf = false;
  const bool biMCForDMVR  = false;

  if ( yFrac == 0 )
  {
    // Horizontal-only filtering.
    m_interRes.m_if.filterHor( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, xFrac, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else if ( xFrac == 0 )
  {
    // Vertical-only filtering.
    m_interRes.m_if.filterVer( m_compID, (Pel*) ref, refStride, dst, dstStride, bw, bh, yFrac, true, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
  }
  else
  {
    // Separable filtering: a horizontal pass into m_ifBuf covering (vFilterSize - 1) extra rows,
    // starting (vFilterSize/2 - 1) rows above, followed by a vertical pass into dst.
    const int vFilterSize = isLuma(m_compID) ? NTAPS_BILINEAR : NTAPS_CHROMA;
    PelBuf tmpBuf = PelBuf(m_interRes.m_ifBuf, Size(bw, bh+vFilterSize-1));
    m_interRes.m_if.filterHor( m_compID, (Pel*)ref - ((vFilterSize>>1) -1)*refStride, refStride, tmpBuf.buf, tmpBuf.stride, bw, bh+vFilterSize-1, xFrac, false, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    // The JVET_J0090 cache hook is disabled around the vertical pass over the temporary buffer.
    JVET_J0090_SET_CACHE_ENABLE( false );
    m_interRes.m_if.filterVer( m_compID, tmpBuf.buf + ((vFilterSize>>1) -1)*tmpBuf.stride, tmpBuf.stride, dst, dstStride, bw, bh, yFrac, false, true, m_cu.chromaFormat, m_cu.slice->clpRng(m_compID), nFilterIdx, biMCForDMVR, useAltHpelIf );
    JVET_J0090_SET_CACHE_ENABLE( true );
  }
  return dstBuf;
}
在deriveMvUni函数中进行单向MV的细化:
template <int tplSize> void TplMatchingCtrl::deriveMvUni() { if (m_minCost == std::numeric_limits<Distortion>::max()) { m_minCost = xGetTempMatchError<tplSize>(m_mvStart); // 计算初始位置处模板的Cost } if (m_maxSearchRounds <= 0) { return; } // 搜索步长 int searchStepShift = (m_cu.imv == IMV_4PEL ? MV_FRACTIONAL_BITS_INTERNAL + 2 : MV_FRACTIONAL_BITS_INTERNAL); xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_DIAMOND>(m_maxSearchRounds, searchStepShift); xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS >( 1, searchStepShift); xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS >( 1, searchStepShift - 1); #if MULTI_PASS_DMVR if (!m_pu.bdmvrRefine) { #endif xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS >( 1, searchStepShift - 2); xRefineMvSearch<tplSize, TplMatchingCtrl::TMSEARCH_CROSS >( 1, searchStepShift - 3); #if MULTI_PASS_DMVR } else { xDeriveCostBasedMv<TplMatchingCtrl::TMSEARCH_CROSS>(); } #endif }