Skip to content

Commit

Permalink
Avx512 extract most significant bits (#82731)
Browse files Browse the repository at this point in the history
* Add `TYP_MASK` and `Vector512.ExtractMostSignificantBits`.

* Rebase / rename error fix.

* Review edits.

* Formatting.

* Review edits.

* Review cleanup.

* Build fixes.

* Address throughput issues pertaining to `availableRegCount`.

* kmov RR refactor.

* Split kmov into kmov_msk and kmov_gpr.

* Fix thread.

* Review edits.
  • Loading branch information
anthonycanino authored Mar 12, 2023
1 parent 5265218 commit 8599161
Show file tree
Hide file tree
Showing 18 changed files with 323 additions and 23 deletions.
3 changes: 2 additions & 1 deletion src/coreclr/jit/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8293,12 +8293,13 @@ void emitter::emitDispDataSec(dataSecDsc* section, BYTE* dst)
i += j;
break;

case 64:
case 32:
case 16:
case 8:
assert((data->dsSize % 8) == 0);
printf("\tdq\t%016llXh", *reinterpret_cast<uint64_t*>(&data->dsCont[i]));
for (j = 8; j < 32; j += 8)
for (j = 8; j < 64; j += 8)
{
if (i + j >= data->dsSize)
break;
Expand Down
112 changes: 109 additions & 3 deletions src/coreclr/jit/emitxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,15 @@ bool emitter::IsSSEOrAVXInstruction(instruction ins)
return (ins >= INS_FIRST_SSE_INSTRUCTION) && (ins <= INS_LAST_AVX_INSTRUCTION);
}

bool emitter::IsKInstruction(instruction ins)
{
return (ins >= INS_FIRST_K_INSTRUCTION) && (ins <= INS_LAST_K_INSTRUCTION);
}

//------------------------------------------------------------------------
// IsAvx512OrPriorInstruction: Is this an Avx512 or Avx or Sse instruction.
// IsAvx512OrPriorInstruction: Is this an Avx512 or Avx or Sse or K (opmask) instruction.
// Technically, K instructions would be considered under the VEX encoding umbrella, but due to
// the instruction table encoding had to be pulled out with the rest of the `INST5` definitions.
//
// Arguments:
// ins - The instruction to check.
Expand All @@ -46,7 +53,7 @@ bool emitter::IsSSEOrAVXInstruction(instruction ins)
bool emitter::IsAvx512OrPriorInstruction(instruction ins)
{
// TODO-XArch-AVX512: Fix check once AVX512 instructions are added.
return (ins >= INS_FIRST_SSE_INSTRUCTION) && (ins <= INS_LAST_AVX512_INSTRUCTION);
return ((ins >= INS_FIRST_SSE_INSTRUCTION) && (ins <= INS_LAST_AVX512_INSTRUCTION));
}

bool emitter::IsAVXOnlyInstruction(instruction ins)
Expand Down Expand Up @@ -263,6 +270,15 @@ bool emitter::IsEvexEncodedInstruction(instruction ins) const
case INS_vbroadcastf128: // INS_vbroadcastf32x4, INS_vbroadcastf64x2.
case INS_vbroadcasti128: // INS_vbroadcasti32x4, INS_vbroadcasti64x2.

case INS_kmovb_msk:
case INS_kmovw_msk:
case INS_kmovd_msk:
case INS_kmovq_msk:
case INS_kmovb_gpr:
case INS_kmovw_gpr:
case INS_kmovd_gpr:
case INS_kmovq_gpr:

// TODO-XARCH-AVX512 these need to be encoded with the proper individual EVEX instructions (movdqu8,
// movdqu16 etc)
// For implementation speed, I have set it up so the standing instruction will default to the 32-bit operand
Expand Down Expand Up @@ -1248,6 +1264,8 @@ bool emitter::TakesRexWPrefix(instruction ins, emitAttr attr)
case INS_vpgatherqq:
case INS_vgatherdpd:
case INS_vgatherqpd:
case INS_vpmovw2m:
case INS_vpmovq2m:
return true;
default:
break;
Expand Down Expand Up @@ -1294,6 +1312,9 @@ bool emitter::TakesRexWPrefix(instruction ins, emitAttr attr)
case INS_shlx:
case INS_sarx:
case INS_shrx:
case INS_kmovq_msk:
case INS_kmovq_gpr:
case INS_kmovd_msk:
return true;
default:
return false;
Expand Down Expand Up @@ -3478,6 +3499,10 @@ inline UNATIVE_OFFSET emitter::emitInsSizeRR(instrDesc* id)
// Otherwise, it will be placed after the 4 byte encoding, making the total 5 bytes.
// This would probably be better expressed as a different format or something?
code_t code = insCodeRM(ins);
if (IsKInstruction(ins))
{
code = AddVexPrefix(ins, code, EA_SIZE(id->idOpSize()));
}

UNATIVE_OFFSET sz = emitGetAdjustedSize(id, code);

Expand Down Expand Up @@ -5856,6 +5881,14 @@ bool emitter::IsMovInstruction(instruction ins)
case INS_movupd:
case INS_movups:
case INS_movzx:
case INS_kmovb_msk:
case INS_kmovw_msk:
case INS_kmovd_msk:
case INS_kmovq_msk:
case INS_kmovb_gpr:
case INS_kmovw_gpr:
case INS_kmovd_gpr:
case INS_kmovq_gpr:
{
return true;
}
Expand Down Expand Up @@ -6006,6 +6039,19 @@ bool emitter::HasSideEffect(instruction ins, emitAttr size)
}
#endif // TARGET_AMD64

case INS_kmovb_msk:
case INS_kmovw_msk:
case INS_kmovd_msk:
case INS_kmovq_msk:
case INS_kmovb_gpr:
case INS_kmovw_gpr:
case INS_kmovd_gpr:
case INS_kmovq_gpr:
{
hasSideEffect = true;
break;
}

default:
{
unreached();
Expand Down Expand Up @@ -6223,6 +6269,25 @@ void emitter::emitIns_Mov(instruction ins, emitAttr attr, regNumber dstReg, regN
}
#endif // TARGET_AMD64

case INS_kmovb_msk:
case INS_kmovw_msk:
case INS_kmovd_msk:
case INS_kmovq_msk:
{
assert((isMaskReg(dstReg) || isMaskReg(srcReg)) && !isGeneralRegister(dstReg) &&
!isGeneralRegister(srcReg));
break;
}

case INS_kmovb_gpr:
case INS_kmovw_gpr:
case INS_kmovd_gpr:
case INS_kmovq_gpr:
{
assert(isGeneralRegister(dstReg) || isGeneralRegister(srcReg));
break;
}

default:
{
unreached();
Expand Down Expand Up @@ -9619,6 +9684,11 @@ const char* emitter::emitRegName(regNumber reg, emitAttr attr, bool varName)
#ifdef TARGET_AMD64
char suffix = '\0';

if (isMaskReg(reg))
{
return rn;
}

switch (EA_SIZE(attr))
{
case EA_64BYTE:
Expand Down Expand Up @@ -13843,7 +13913,18 @@ BYTE* emitter::emitOutputRR(BYTE* dst, instrDesc* id)
{
assert((ins != INS_movd) || (isFloatReg(reg1) != isFloatReg(reg2)));

if ((ins != INS_movd) || isFloatReg(reg1))
if (ins == INS_kmovb_gpr || ins == INS_kmovw_gpr || ins == INS_kmovd_gpr || ins == INS_kmovq_gpr)
{
assert(!(isGeneralRegister(reg1) && isGeneralRegister(reg2)));

code = insCodeRM(ins);
if (isGeneralRegister(reg1))
{
// kmov r, k form, flip last byte of opcode from 0x92 to 0x93
code |= 0x01;
}
}
else if ((ins != INS_movd) || isFloatReg(reg1))
{
code = insCodeRM(ins);
}
Expand Down Expand Up @@ -18150,6 +18231,31 @@ emitter::insExecutionCharacteristics emitter::getInsExecutionCharacteristics(ins
break;
}
#endif

case INS_vpmovb2m:
case INS_vpmovw2m:
case INS_vpmovd2m:
case INS_vpmovq2m:
{
result.insLatency += PERFSCORE_LATENCY_1C;
result.insThroughput = PERFSCORE_THROUGHPUT_1C;
break;
}

case INS_kmovb_msk:
case INS_kmovw_msk:
case INS_kmovd_msk:
case INS_kmovq_msk:
case INS_kmovb_gpr:
case INS_kmovw_gpr:
case INS_kmovd_gpr:
case INS_kmovq_gpr:
{
result.insLatency += PERFSCORE_LATENCY_3C;
result.insThroughput = PERFSCORE_THROUGHPUT_1C;
break;
}

default:
// unhandled instruction insFmt combination
perfScoreUnhandledInstruction(id, &result);
Expand Down
6 changes: 6 additions & 0 deletions src/coreclr/jit/emitxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ inline static bool isDoubleReg(regNumber reg)
return isFloatReg(reg);
}

inline static bool isMaskReg(regNumber reg)
{
return (reg >= REG_MASK_FIRST && reg <= REG_MASK_LAST);
}

/************************************************************************/
/* Routines that compute the size of / encode instructions */
/************************************************************************/
Expand Down Expand Up @@ -96,6 +101,7 @@ static bool IsAvx512OnlyInstruction(instruction ins);
static bool IsFMAInstruction(instruction ins);
static bool IsAVXVNNIInstruction(instruction ins);
static bool IsBMIInstruction(instruction ins);
static bool IsKInstruction(instruction ins);

static regNumber getBmiRegNumber(instruction ins);
static regNumber getSseShiftRegNumber(instruction ins);
Expand Down
48 changes: 48 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,54 @@ void CodeGen::genAvxFamilyIntrinsic(GenTreeHWIntrinsic* node)
break;
}

case NI_AVX512F_MoveMaskSpecial:
{
op1Reg = op1->GetRegNum();
regNumber maskReg = node->ExtractTempReg(RBM_ALLMASK);

instruction maskIns;
instruction kmovIns;

// TODO-XARCH-AVX512 note that this type/kmov combination assumes 512-bit vector types but would change
// if used for other vector lengths, i.e., TYPE_BYTE requires kmovq for for 512-bit vector, but kmovd
// for 256-bit vector.
switch (baseType)
{
case TYP_BYTE:
case TYP_UBYTE:
maskIns = INS_vpmovb2m;
kmovIns = INS_kmovq_gpr;
break;
case TYP_SHORT:
case TYP_USHORT:
maskIns = INS_vpmovw2m;
kmovIns = INS_kmovd_gpr;
break;
case TYP_INT:
case TYP_UINT:
case TYP_FLOAT:
maskIns = INS_vpmovd2m;
kmovIns = INS_kmovw_gpr;
break;
case TYP_DOUBLE:
case TYP_LONG:
case TYP_ULONG:
maskIns = INS_vpmovq2m;
kmovIns = INS_kmovb_gpr;
break;
default:
unreached();
}

// TODO-XARCH-AVX512 remove REG_K1 check when all K registers possible for
// allocation.
assert(emitter::isMaskReg(maskReg) && maskReg == REG_K1);

emit->emitIns_R_R(maskIns, attr, maskReg, op1Reg);
emit->emitIns_Mov(kmovIns, EA_8BYTE, targetReg, maskReg, INS_FLAGS_DONT_CARE);
break;
}

default:
unreached();
break;
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ HARDWARE_INTRINSIC(Vector512, StoreAligned,
HARDWARE_INTRINSIC(Vector512, StoreAlignedNonTemporal, 64, 2, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(Vector512, StoreUnsafe, 64, -1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)

HARDWARE_INTRINSIC(Vector512, ExtractMostSignificantBits, 64, 1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)

// ***************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
// ISA Function name SIMD size NumArg Instructions Category Flags
// {TYP_BYTE, TYP_UBYTE, TYP_SHORT, TYP_USHORT, TYP_INT, TYP_UINT, TYP_LONG, TYP_ULONG, TYP_FLOAT, TYP_DOUBLE}
Expand Down Expand Up @@ -884,6 +886,8 @@ HARDWARE_INTRINSIC(SSE2, UCOMISD,
HARDWARE_INTRINSIC(SSE41, PTEST, 16, 2, {INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_invalid, INS_invalid}, HW_Category_SimpleSIMD, HW_Flag_NoRMWSemantics|HW_Flag_NoEvexSemantics)
HARDWARE_INTRINSIC(AVX, PTEST, 0, 2, {INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_ptest, INS_vtestps, INS_vtestpd}, HW_Category_SimpleSIMD, HW_Flag_NoRMWSemantics|HW_Flag_NoEvexSemantics)

HARDWARE_INTRINSIC(AVX512F, MoveMaskSpecial, 64, 1, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_movd, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMDScalar, HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoContainment|HW_Flag_SpecialCodeGen|HW_Flag_NoRMWSemantics)

#endif // FEATURE_HW_INTRINSIC

#undef HARDWARE_INTRINSIC
Expand Down
14 changes: 14 additions & 0 deletions src/coreclr/jit/hwintrinsicxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,20 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Vector512_ExtractMostSignificantBits:
{
if (IsBaselineVector512IsaSupported())
{
var_types simdType = getSIMDTypeForSize(simdSize);

op1 = impSIMDPopStack(simdType);

retNode = gtNewSimdHWIntrinsicNode(retType, op1, NI_AVX512F_MoveMaskSpecial, simdBaseJitType, simdSize,
/* isSimdAsHWIntrinsic */ false);
}
break;
}

case NI_Vector128_ExtractMostSignificantBits:
case NI_Vector256_ExtractMostSignificantBits:
{
Expand Down
3 changes: 2 additions & 1 deletion src/coreclr/jit/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ const char* CodeGen::genInsDisplayName(emitter::instrDesc* id)
static char buf[4][TEMP_BUFFER_LEN];
const char* retbuf;

if (GetEmitter()->IsVexEncodedInstruction(ins) && !GetEmitter()->IsBMIInstruction(ins))
if (GetEmitter()->IsVexEncodedInstruction(ins) && !GetEmitter()->IsBMIInstruction(ins) &&
!GetEmitter()->IsKInstruction(ins))
{
sprintf_s(buf[curBuf], TEMP_BUFFER_LEN, "v%s", insName);
retbuf = buf[curBuf];
Expand Down
21 changes: 21 additions & 0 deletions src/coreclr/jit/instrsxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,22 @@ INST3(shrx, "shrx", IUM_WR, BAD_CODE, BAD_CODE,

INST3(LAST_BMI_INSTRUCTION, "LAST_BMI_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None)

INST3(FIRST_K_INSTRUCTION, "FIRST_K_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None )

INST3(kmovb_msk, "kmovb", IUM_WR, PACK3(0x66, 0x0F, 0x91), BAD_CODE, PACK3(0x66, 0x0F, 0x90), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovw_msk, "kmovw", IUM_WR, PACK2(0x0F, 0x91), BAD_CODE, PACK2(0x0F, 0x90), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovd_msk, "kmovd", IUM_WR, PACK3(0xF2, 0x0F, 0x91), BAD_CODE, PACK3(0xF2, 0x0F, 0x90), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovq_msk, "kmovq", IUM_WR, PACK3(0xF2, 0x0F, 0x91), BAD_CODE, PACK3(0xF2, 0x0F, 0x90), INS_TT_NONE, INS_FLAGS_None )


INST3(kmovb_gpr, "kmovb", IUM_WR, BAD_CODE, BAD_CODE, PACK3(0x66, 0x0F, 0x92), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovw_gpr, "kmovw", IUM_WR, BAD_CODE, BAD_CODE, PACK2(0x0F, 0x92), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovd_gpr, "kmovd", IUM_WR, BAD_CODE, BAD_CODE, PACK3(0xF2, 0x0F, 0x92), INS_TT_NONE, INS_FLAGS_None )
INST3(kmovq_gpr, "kmovq", IUM_WR, BAD_CODE, BAD_CODE, PACK3(0xF2, 0x0F, 0x92), INS_TT_NONE, INS_FLAGS_None )

INST3(LAST_K_INSTRUCTION, "LAST_K_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None )


INST3(LAST_AVX_INSTRUCTION, "LAST_AVX_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None)

INST3(FIRST_AVX512_INSTRUCTION, "FIRST_AVX512_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None)
Expand All @@ -650,6 +666,11 @@ INST3(vinsertf32x8, "insertf32x8", IUM_WR, BAD_CODE, BAD_CODE,
INST3(vinserti32x8, "inserti32x8", IUM_WR, BAD_CODE, BAD_CODE, SSE3A(0x3A), INS_TT_TUPLE8, Input_32Bit | INS_Flags_IsDstDstSrcAVXInstruction) // Insert 256-bit packed quadword integer values
INST3(LAST_AVX512DQ_INSTRUCTION, "LAST_AVX512DQ_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None)

INST3(vpmovb2m, "vpmovb2m", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF3, 0x0F, 0x38, 0x29), INS_TT_NONE, Input_8Bit)
INST3(vpmovw2m, "vpmovw2m", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF3, 0x0F, 0x38, 0x29), INS_TT_NONE, Input_16Bit)
INST3(vpmovd2m, "vpmovd2m", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF3, 0x0F, 0x38, 0x39), INS_TT_NONE, Input_32Bit)
INST3(vpmovq2m, "vpmovq2m", IUM_WR, BAD_CODE, BAD_CODE, PACK4(0xF3, 0x0F, 0x38, 0x39), INS_TT_NONE, Input_64Bit)

INST3(LAST_AVX512_INSTRUCTION, "LAST_AVX512_INSTRUCTION", IUM_WR, BAD_CODE, BAD_CODE, BAD_CODE, INS_TT_NONE, INS_FLAGS_None)

// Scalar instructions in SSE4.2
Expand Down
Loading

0 comments on commit 8599161

Please sign in to comment.