From 302cb9dc44f2c35df4aa76c1acbce5fa2d03ad81 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Tue, 6 Feb 2024 18:06:50 +0000 Subject: [PATCH] address feedback: II --- stablehlo/dialect/VhloBytecode.cpp | 69 ++++++++---------- stablehlo/dialect/VhloDialect.td | 2 +- stablehlo/dialect/VhloTypes.td | 4 +- .../stablehlo_legalize_to_vhlo.0_17_0.mlir.bc | Bin 0 -> 17636 bytes .../stablehlo_legalize_to_vhlo.0_18_0.mlir | 6 +- .../stablehlo_legalize_to_vhlo.0_18_0.mlir.bc | Bin 17764 -> 17764 bytes ...o_to_version_downgrade_invalid.0_16_0.mlir | 42 ++++++----- 7 files changed, 62 insertions(+), 61 deletions(-) create mode 100644 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp index f730ada631c..aee3fb65063 100644 --- a/stablehlo/dialect/VhloBytecode.cpp +++ b/stablehlo/dialect/VhloBytecode.cpp @@ -1225,41 +1225,35 @@ VhloBytecodeInterface::readUniformQuantizedPerAxisV1Type( LOG_READ_CALL; uint64_t flags; Type storageType, expressedType; - FailureOr scale; + uint64_t quantizedDimension; + int64_t storageTypeMin, storageTypeMax; SmallVector scales; SmallVector zeroPoints; - int64_t quantizedDimension, numQuantizationParams, storageTypeMin, - storageTypeMax; - if (failed(reader.readVarInt(flags)) || - failed(reader.readType(storageType)) || - failed(reader.readType(expressedType)) || - failed(reader.readSignedVarInt(quantizedDimension)) || - failed(reader.readSignedVarInt(numQuantizationParams))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - - for (int64_t i = 0; i < numQuantizationParams; i++) { - if (failed(scale = reader.readAPFloatWithKnownSemantics( - llvm::APFloat::IEEEdouble()))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - scales.push_back(scale.value()); - } - - for (int64_t i = 0; i < numQuantizationParams; i++) { - if (failed(reader.readSignedVarInt(zeroPoints.emplace_back()))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); + auto readScales = [&]() -> FailureOr { + return reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble()); + }; + auto readZeroPoints = [&]() -> FailureOr { + int64_t temp; + if (succeeded(reader.readSignedVarInt(temp))) { + return temp; + } + return failure(); + }; + if (succeeded(reader.readVarInt(flags)) && + succeeded(reader.readType(storageType)) && + succeeded(reader.readType(expressedType)) && + succeeded(reader.readVarInt(quantizedDimension)) && + succeeded(reader.readSignedVarInt(storageTypeMin)) && + succeeded(reader.readSignedVarInt(storageTypeMax)) && + succeeded(reader.readList(scales, readScales)) && + succeeded(reader.readList(zeroPoints, readZeroPoints))) { + return UniformQuantizedPerAxisV1Type::get( + getContext(), flags, storageType, expressedType, quantizedDimension, + scales, zeroPoints, storageTypeMin, storageTypeMax); } - if (failed(reader.readSignedVarInt(storageTypeMin)) || - failed(reader.readSignedVarInt(storageTypeMax))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisV1Type(); - - return UniformQuantizedPerAxisV1Type::get( - getContext(), flags, storageType, expressedType, quantizedDimension, - scales, zeroPoints, storageTypeMin, storageTypeMax); + return reader.emitError("invalid UniformQuantizedPerAxisType"), + UniformQuantizedPerAxisV1Type(); } void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type, @@ -1268,15 +1262,14 @@ void VhloBytecodeInterface::write(UniformQuantizedPerAxisV1Type type, writer.writeVarInt(type.getFlags()); writer.writeType(type.getStorageType()); writer.writeType(type.getExpressedType()); - writer.writeSignedVarInt(type.getQuantizedDimension()); - int64_t numQuantizationParams = type.getScales().size(); - writer.writeSignedVarInt(numQuantizationParams); - for (auto scale : type.getScales()) - writer.writeAPFloatWithKnownSemantics(APFloat(scale)); - for (auto zeroPoint : type.getZeroPoints()) - writer.writeSignedVarInt(zeroPoint); + writer.writeVarInt(type.getQuantizedDimension()); writer.writeSignedVarInt(type.getStorageTypeMin()); writer.writeSignedVarInt(type.getStorageTypeMax()); + writer.writeList(type.getScales(), [&](const APFloat &type) { + writer.writeAPFloatWithKnownSemantics(type); + }); + writer.writeList(type.getZeroPoints(), + [&](int64_t type) { writer.writeSignedVarInt(type); }); } //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index cc2c2a13599..cdb2c18531d 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -35,7 +35,7 @@ def VHLO_Dialect : Dialect { 0.15.0: MLIR bytecode version 5 => 6, use properties in VHLO. 0.16.0: Introduce `collective_broadcast` operation. 0.17.0: Allow reduce operations to promote to higher bitwidth. - 0.18.0: Allow serialization of UniformQuantizedPerAxisType. + 0.18.0: Introduce `UniformQuantizedPerAxisType` type. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloTypes.td b/stablehlo/dialect/VhloTypes.td index be4406bfb9c..bdb6c1519a6 100644 --- a/stablehlo/dialect/VhloTypes.td +++ b/stablehlo/dialect/VhloTypes.td @@ -252,7 +252,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", " "unsigned":$flags, "::mlir::Type":$storageType, "::mlir::Type":$expressedType, - "int64_t":$quantizedDimension, + "int32_t":$quantizedDimension, VHLO_QuantizationScalesV1:$scales, ArrayRefParameter<"int64_t">:$zeroPoints, "int64_t":$storageTypeMin, @@ -263,7 +263,7 @@ def VHLO_UniformQuantizedPerAxisV1 : VHLO_TypeDef<"UniformQuantizedPerAxisV1", " LogicalResult UniformQuantizedPerAxisV1Type::verify( llvm::function_ref errFn, unsigned int, mlir::Type storageType, mlir::Type expressedType, - int64_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef, int64_t, int64_t) { + int32_t, ::llvm::ArrayRef<::llvm::APFloat>, ::llvm::ArrayRef, int64_t, int64_t) { if (!isFromVhlo(storageType) || !isFromVhlo(expressedType)) return errFn() << "expected VHLO type"; return success(); diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_17_0.mlir.bc new file mode 100644 index 0000000000000000000000000000000000000000..f6b993b17c6be8e5af166066f84452056d827a6b GIT binary patch literal 17636 zcmch9e~c8xw{Le<_ttGD!@OoPOqNM889K{m;4b@J_5v3On1#pifD3$quskp@ncbaT z26t!H*&oOn@QD#4MvNFSV#Fs#j2ba&#HbOYUNLIKE9Q$4BSwrEF=D&Gqc{IQ{$q7^_%AbB=FIqurk`V` zZvLS25dS_BjVF>$uN9KmGFS?{#cEo5&`!scZ(D%@(o@TgBSh7S_#*Y?$q0 zd)Pj9fE{8-*im+j9cL%lNp_l@W#`#Nc8OhKb$mRZ$S3oud#<@EHw7KaQ+xMu7yWLFJX860y)tg-O*`zgELmZGDcS ze#LM8&(dfd)AP?Z{{Px;^*yUsPiFUceMY0vWJEdFb-F5I1|!jkz`uAr+A5MF9&sYk zsMydLjm8~EBoYb8ffIPe5slGkYcd{ppf4MzHrZ$*0lr8kouq&$z(yw#O^8S|5r=Lx zN+v+4kVq!+frutfqPejVpGo4QeY_EH9O6+h9ctAW74alVpi`1x5l_G>^`EE`D!Ts? zi4a855LU#a?83h$IX6Exf8x6QIU3iUw@e<5o5_}x=Dl=}n!58wG#-|}ByX^N%=sl6 zx03g_`E8g7_|)e^q*822`9;5#&@DZ=g{OS2KJI+r`~XS68H)eSYXb9pu1~@8<$P+H z+EGyz`0ksq>lhrSp~ZweyYht@D#JjydC*Gl4minEZ#cmO1Ur>0r)oCVx%w zQ@<%2wG3Qr#naL|mGA`@5 z^Sd0!9frYjrw-%gaw3}J5w-J?o4HJ zIhU)rT+5wl7%Z2YxHALe<+6i2GcjZ?w{cnIvdrZ$mm^&6;&L~4W@DV(nTMfrXCX$* zokbWhcQP0=mj}5##N}Zwk8t@E=Qw7e%7;*Xp37rgzQW~k?ktBJxU&kL z;PO2#PjY#RJ8=7H?recKxID|{IWEt0d4bD|T>i-AB`z;>d4)@6$T~xgGn{Vt!H^RS zIni*oVcrcn*^pBVr--RH({P5FoNYKGOwKc$T}&=CWEjpxCNCS#B}9keTt9ljX$dnn{h5d!>uf^m8ZT?#JsSo)SQl?k-rNQ!7euZ`4t-gRy zuch!>{J_TVZ>T=?dx<7e(L|ba5l!fFxhYp+$4Vr-xz72Fx|pu(N!_CEqj8=RDKo+A zdA-if7hb@LYeslTgmlZYga`>s@U$+{yiKI_2d%W8XiDpI(`Fm$W?Kq?nNIV{E+J&v zYQsw;(q@W8cBZX#F2@*(@wC2&LD9}=+N=mv3L-5Q)lK0Y7F3d2D-I1PSoj>N0}Y2q z>e@soEmYsavKt^`#w-lehW6tut-pu&t7-i-$T_?(;Qb@sSMaXGD8}PG5%0-(PsMu% z-m~$}@U&^eSZ<}$W*Q33v@*ktr+LH9!06Ix%UZHiV&pcuG;P>wwal6|fr)0Ww+b$% zk7HLIn%Rc-B1Ik%De@Z1_~dp1Hwdqely+cry>7 z_MmO*2nic8quaV}^G4oiV~gvvc^~JtDQ(^g5RGi1>m@WYBM~04`7Qhwo7+_CR(f)FV1iDa&AyYNPY8ZJJ0q>`*=0*9&~QMO196DGrhgNN|{CS~Mcq zOayF8sY}{sI?AKA+1yN>!q!J^TyhlmF`Gwu%nrqLq?;(1oUAvw!!fVC&4B=qK@WeN z6qDPaIEfD|H}W5xNcNe{RJUTj5^81+===%t)RStUPuteekgl80JPX-|1~PcS9s6?_ z`*s?;A}`p8e6rIb`_Zb+HeY;+zl`3{waqa(I^O}xn{B?B>n;3M;cYj&|5w`;ep%VI z*fy6e;Y%=}rAzrz+g!GcFRQmr$Kj4`u2{iWz$q(N@|AG->eYNTm<@>o1JAj%699{lLTrk|E5}4Vl`O22CJ{Z5;a(@ zow5c?*2tx5FsBAHZc5y$2dIJ5=2^q$gChTCy2}q zBJ-=r8l5i?Mi?6f*L3K<&TF?)3up8;Y$}yr7YC6gL1d|im?X0d+fPMig@=4f$V#k( z3bNW8`P`Hi$>?i%Yhd=))nMyuunjfX#v1H;7lX%Q8GW-_=WrR#;$QV_5X;qKok1*L zi}eVjnOh-jN1tnnj9%cnK-khu2xz^b7RAVeXvwoxN-dhvZ{(4{7zb*wp&D#^4Ys2O zyIB}VVmzj`X7m^7fF&2Z!TXXYr^mFGjQ%p6!Ikav)rwqmMt`kV?)9o%8eomw z8(wgjTB==timMTDgjC2_S3N;Ks_vB6if2;wu`o8!0x-44cKvVe z7z5|fCjl*9%zi4226xHWG@xpje|8P4(aYzaOicR@GVZ*oefo>a0>H}rGPp9o3a-qr zsh?}8_gCgOL4j{^i}eNS>oxLX*pCQb;5)BxwUoc-fu;O|*Og^y4Uh~Tn}4rb@hzar z{(wRLB#i5jw}o1e%e83VWb~gYZ41$KTIf@R!V$FkMHq8<*llHMU!xVTue3q<{Z;>m z+d@E(_LZl_TQovg)Z$mqowoKR*-)K=-@GB2+7}u9_nO_Rnhj#NEM1KCMM2`hBV`aS{veiIagfxu=G%eL`~H z5Y?eE#n)+Te*=cZ*PpbhzRr|}&2$+J_l^Aw%3&x0XdfJ-2+YNYAjUq04>5*50S#RgqP`MtA>S%lM-K zjqSdiuDx3;w+TpsbL$;OO9=BqN81)K)x7H3+Znxs45;|3TVXIA!+Hz49t(7{ttyAy zsaPXd_T_N+yg@bu+U#RDK;=vKB-z7ij*`7VTZ3l_ zJX3R+?DHG=r^$X+dz>6V1Lcec?R7O(v8po;0-Hq(FReE;G-$73T(R19cSuQNXy&`x ztAK;Y++kqKa&w9Haz;PGcw4|JPbm}|vU%$g?InovU@#x`W#N(+Gy3zaCL=or)Qg}O z$k$biU#TL`)gs4%sKj?{3i@VM=-FDK6F^j6R6T6p151s{J#A02n&b8qQ02|E_Kfmo zU>TkUhFmEciab(_T&W@t*CGrD z`#N_&eW)GJX55q0o5p>L(Ce69U)+qp*pBf55Pt;`g^CmehCkWj#islZgExcSvIoNEDehL3L6wYxL=OgJNe&h~q>wY#7; z5JmI+9uX&Zk}(0Pg}_>PC}gE+x!wUZhNm8;Ab(o$*&<)-0_}DZtxj|Xh!;1=_m;zy zVBD+%dOdAmrmi)qnWJ7Q(!qlrB4`J7MyKCe)d1nHYqw?ecDkfh;%bwx@M~Y;6IF#* z%RNZp76^HH7jnxEFMO_P16^smeUY~*9o@$HlYg4w6n)daR0(OB^QWt(VXRlpNQWV` z6C0)V?YC%xJaz=%gFUN!PYF&qV%6*QuAh)nnXbQZ-9jdA(XoV2Si1{a{NdahbSbdi zz~~^m1qWFJxtI>JJsxvskhvF372u|JGn5Qey}l2~pXDt3(Z?+Sn~Vg8vr!1H6p zo}*L8bNkf-kMjLv5OePi->a*`EBycc?r?mpt{2FJKe~rp%dR(}OXY)fX52^~3Lc>+ z6iN#;rs10QUhVdC(nr&_RvMF{odT*N(YCcdD7FHd%;^Aj!zk=b04t2b&I0oypoeU& z#^t#xlCMS1D}C5QOszAcU%&&;s?Hmm8#U{S}A6DncXRNhH2Mh!5IAIfaEyfDL5W0X5)X?aRTV>G4{;c-fVSXaI6Huiz8}u>u{-Zn zN6C_a2N(L--FGW&X#iW~WB1;xuw?-(<70?m+HQjpyxhl@EK!&fkXz+r_usFu6#;Cm zk3IOH!d3>bb{~8AVTG*@V4Hkw*)oMS1h6eW_UNMuTNA)KeC+YZ6}B#bb^F+pPbzGE z0Ndtc&pe~B4FRm^W6wROu#Ew%>|-y!sIcn;*szbi{IbIE9JF%J9r3Z(UQ<{$fbH@z zY*9670-b3-HhP1jqs_+_EKqX6tLI)Hd*cm-N(ANg4!H&;Xb zfUocNu{ItArfvYS%Hy^}syQ~Ln6@mVA2x!|W{v<;w~W8jmg8nxjyU?;#tTTN)iPZX3d|oV9tVcV`D=s9PZn^a@~TtZFA<%SFZ-s;>ngow6UeRF`a1cn$W%N&9V2Oyvzlql^%zznnrcc% zBes`SM16Tbmj&J^DehqBq_8j z=eveF^V#i%VpnN9G5T^h75awy0$4zvO!Fkrl$29frLr<{An(>Ya|5nju4>gu=PzorR((@J@5AxT+x9-QK=Xx^r(>!Y%L&?%7vb-gW29v zPXS|5Is)umxiGl3FFy$1qZe18H}A&ZKz@)soG%WLCkMPiR&6Q|cm|Y9Lzw?!9?^_> z$q(mM@3q`q_Xzb+23q!mWMFHqUm1*bFj!FYRZt;P&i7+XxvXF7%aK1+9D6a?U(RK`oh6-IKZ|BhU_TyUEg-Q`!z zBN&vGeM4B>{k5Q+zk+{Vz^$JOU$to(cN2e2avz1AN?Q-T`vTyFtch!H3WSh4fS`?*eb5_ z`C}NESbwhkHl7dtsz02r!f>HWxvO(vu#7E(1nF1YG`h}mX<(qqo82aFJ#=E)V8meQ z1~+axOMV>FLe9G}j^Y5;fN}@Q*v#@pZ(3*{rA6)ASh19VUaTVO7!is!kk54$u=$PT zVTWS(=K9p+b><81*3sd1SKVbGT&+~M9PGMk6AGBvksHLqEz+idC5m}=x7h%jcD{1H z=q}a(zo+cRV-6#nNp|JBY1j7lqrm=EwFHge54aU4%?xbvWlifYm9)QTgFCRKTwl`G zmCtL0X5GW4Fo1>JRlwR&!JfbI55VFdN%s69IQR!@aUx#Zpx&{PYCKqq7}gj;g<%=w zt|2yvCDvOiW9PWKDD9K^uB&51*!2Qv;DFBc7qYa8(WZ!2{n&`Q3l%ljXjsB7Tqx#x zz0R)D3R60W_3EY!h@bzv*082(34#doKGW6ERp|6w)Qzp(bzApnP63$Pywa#VmQ|?o zs(Ub2L3g1i+m-LmVUKi`j6t{J9Qp$YM~912dXnqyfyZz%)#|K;2Dg^;`D}N2$JONe zaZs0Tsv=$3oXdp{H!JWC?`xp6g#DSSp3%h;Eo=9{b_FZdYXiBiZKXod8<_HY{wC~P z6gvai4mvgmG|y{lxx0Bs4a+_7YFKWrQeky^sZ5nICgZspRM|`dpvp)C^-3aDvS`gebw*)sYFW|XC7qQz_Pwfl%2@-5%7wjr zcQo5gkt^KMt;QzY)uu?0wNif;IlLlWU5fbIyQ);uO~2;NIEhu0?(J9&+5U2=4|^J& zrW8!B&PwvC!m7r4ymNw#6&9YkxGW{)m8RRWhC@2QP*KrbhCX`^ezEZ zcCQAiR8qsa+Es#J_<}y&J1`Z!xp8kKib0nYDvZ6u_*|D_2UD$Rzp9q7GkEu*0PmVx zhIfbZ_pyq;%JYIhlk_&ks-X7};Sc1zE1YYdf4iev@b2fLDuH*4^ytJbscWMuB0Z>i z=HW3Fskw;a$~l%vw^PNX%PSQEUhs-|Byx?t_9{C)X&C#!3nm#y%RM8z^#q#YXPxQgm93dot{T*4t@Kp-$ofqc3@sP6 zv}9HVSZH3YcXaz*RSSNX-F}b$4Yk`xDh$7cyZsOJWs!Q`fJ=t*L4`R~46;23ss|Zv z!@xbzbEA5;pyfK67UuBzuN|EzbZY;?Co%0`j^j_M@GW{s#Knu2IXz8ZgIGgE3^OqT(I=fX=KR$pwhOh{bun{umb%|fx{_iKk&R+6k-XSP{<53k z(6HE#<|szQ0VWPo`3bOl#3A%&yS`f{4!cT@keBcSuA{KNGZYceGjWVq`1RFsYO!6s zNo0pOfkDAVJH&e^%nFN}#YqCf;*?8H6Df!@K+!RN19grVJH>e-8^i_Zw=qL;ago&H zcS@It;$30zF%)hV2VDSU9joQLO7(@E4|G8a?G z-(fKoDw-l<8W%HM#!UPb+9Q+e5>x6Tk4~$5l)N^Z5X?G$av)n5c(YDC%sR)c69=%?xiwuf8My*zw?AIAUNfCA|$@Tr^o(a%vXo+rKdrNS#vGZAWT6~{^NHt{BrCE^6^ zOiZYYi1)ZSNp0~R|7lnC8JC=O$vGk|;ymn=H1-QrPl}5!`O$53iOB8ZvRk`CWSL+F z$Wl>fSUeFC;|ww0u*Axh(6UoZAS@{+8t8eFAtqDxHZg_B5-}C+LiLD%X@;0Vtgx7A zSk}0@@vs-)9M2U zo*_<>`zW+p#3=(ilv<{zsZCOxahsl{rsoWCo^Y(h3x>EztRD?=39KJ6`o-chajqDG znV@y17-w4MMKdDCn_>d7CYoZB39Yw_$wZckDK42x;51XrAUe|&vx&|##X^Eth(%P* zm|{86Ri;==wA~b&h;A`O2hna*Y$IAUMVaWZDMpCyGR1D9drYyH=sr{ICwjmX2Z OpS&|7&YR+bY5gC#b9EH} literal 0 HcmV?d00001 diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir index 5695d1a9fb1..70e8d5e80dc 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir @@ -1296,9 +1296,9 @@ func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tenso } // CHECK-LABEL: "op_dynamic_reshape" -func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { - // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 - %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor) -> tensor +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%arg0, %arg1) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.0_18_0.mlir.bc index ccd9c1d914de8ea11d8f02ba9d3be3c9903d4e73..b0fa65099658a3142f8f83a1df4a8fe15e5d93c8 100644 GIT binary patch delta 1238 zcmYk4UuYav6vp?=x3{;G=_Z^^)-gNQZL+%wV+;{Z4C%71>ym^JLmI;-q-hLEYfCHm zry+(k#5I-}Hbrc_MA}M?K2&IGX?2Mbi#~{@wEEw0Df0Wmzt?Ol`HW4?l0J z$~n3-Hl}IPA&mas$Wr-V%}s=pF;+u@>aZOzvWm*`$Cvh{VJjg+t>s<+NoZ`O^di8@pL sUKdOK7GQS>Gyh+icZ=ookzh>?73;+NC delta 1175 zcmYk4T}&KR7>4(ocMr3}bRpj^>j)EEU|Gmm5u-&y$+oVWg%BhZ$wH7)ve2Jc6DuW^ z&$g^53nF)0^~sfkfonRarf*L>n+Eg z#1`xIrwH|TruSJy4A<5$V7OF@#RpgM5rfFhS*sS9Q8QW@Y5867?)fx}?rwBPSoHLu zhedBMdaGH)Vu-Qm>qB3ZMSnl~>sbs8U_iZ$#}Q{SIEcY!7MTn(K^DWq7_MY7GJ=tI z7NesW?PM`FhOrJ7Y}qq$ah1po8?+a^Ch|{k3~5gzBkYU-4vMyO}-rN z(B>B8C*f9i`Zau&ZWOLmZBW27zKw;#k-MHC!!xVVkSKg+eL=g5cBeJvm#@Ar&|38S zvO3`!am^u>^_iNOl@cnJiY8RF%bMHfseK32uf{(l?J75$dT!>GjXg3`RJ T@|!2+9 diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir index 0ac5297f1c3..6d1a12ad99e 100644 --- a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.0_16_0.mlir @@ -103,21 +103,29 @@ func.func @select_and_scatter_with_promotable_types( %0 = stablehlo.constant dense<0.000000e+00> : tensor // expected-error @+1 {{failed to legalize operation 'vhlo.select_and_scatter_v1' that was explicitly marked illegal}} - %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ - ^bb0(%arg3: tensor, %arg4: tensor): - %2 = "stablehlo.compare"(%arg3, %arg4) { - comparison_direction = #stablehlo - } : (tensor, tensor) -> tensor - "stablehlo.return"(%2) : (tensor) -> () - }, { - ^bb0(%arg3: tensor, %arg4: tensor): - %2 = stablehlo.add %arg3, %arg4 : tensor - "stablehlo.return"(%2) : (tensor) -> () - }) { - window_dimensions = array, - window_strides = array, - padding = dense<0> : tensor<4x2xi64> - } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> - tensor<10x24x24x64xf64> - func.return + %1 = "stablehlo.select_and_scatter"(%arg0, %arg1, %0) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = "stablehlo.compare"(%arg3, %arg4) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + "stablehlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %2 = stablehlo.add %arg3, %arg4 : tensor + "stablehlo.return"(%2) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<0> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> + tensor<10x24x24x64xf64> + func.return +} + +// ----- + +// expected-error @+1 {{failed to legalize operation 'vhlo.func_v1' that was explicitly marked illegal}} +func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { + %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> + func.return %0 : tensor<2x!quant.uniform> }