build: manually update PyTorch version #3977

Merged · 1 commit · Jan 24, 2025
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -13826,6 +13826,34 @@ def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
}];
}

def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()`";
let arguments = (ins
AnyTorchTensorType:$a,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalIntType:$layout
);
let results = (outs
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_AssertTensorMetadataOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 0);
}
void Aten_AssertTensorMetadataOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 0);
}
}];
let hasFolder = 1;
}
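
For context, these assert ops originate in the PyTorch export pipeline: recent nightlies guard dtype/device conversions with aten::_assert_tensor_metadata. A minimal sketch of how one can surface (the ToFloat module is illustrative; whether the assert appears depends on the PyTorch version):

import torch

class ToFloat(torch.nn.Module):
    def forward(self, x):
        # Recent torch.export versions guard this conversion with
        # aten._assert_tensor_metadata on the input tensor.
        return x.to(torch.float32)

ep = torch.export.export(ToFloat(), (torch.randn(2, 3, dtype=torch.float64),))
# Search the exported graph for the metadata assert (may be absent on
# older nightlies).
print(any("_assert_tensor_metadata" in str(n.target) for n in ep.graph.nodes))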

def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
AllowsTypeRefinement,
ReadOnly
48 changes: 48 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5378,6 +5378,54 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
return getA();
}

//===----------------------------------------------------------------------===//
// Aten_AssertTensorMetadataOp
//===----------------------------------------------------------------------===//

LogicalResult Aten_AssertTensorMetadataOp::fold(
FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
Value input = getA();
auto inputType = cast<BaseTensorType>(input.getType());
if (!inputType.hasDtype() || !inputType.hasSizes())
return failure();

  // TODO: Also check stride, device, and layout once that information can be
  // extracted from the torch tensor. For now only the shape and dtype are
  // available, so only those are checked here.

  // Convert the size operand to a list of integers.
  SmallVector<int64_t> size;
  if (!isa<Torch::NoneType>(getSize().getType())) {
    if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
      return emitOpError("expected size to be a list of constant ints");
    }
    if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
                      [](const auto &pair) {
                        return std::get<0>(pair) == std::get<1>(pair);
                      }))
      return emitOpError("failed to fold the _assert_tensor_metadata op since "
                         "the sizes do not match");
  }

  // Convert the dtype operand to an integer.
  int64_t dtype;
  if (!isa<Torch::NoneType>(getDtype().getType())) {
    if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
      return emitOpError("expected dtype to be a constant int");
    }
    FailureOr<Type> inputDtype =
        getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
    if (failed(inputDtype))
      return failure();
    if (inputType.getDtype() != *inputDtype)
      return emitOpError("failed to fold the _assert_tensor_metadata op since "
                         "the dtype does not match");
  }
}

getOperation()->erase();
return success();
}
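
A minimal end-to-end sketch of the folder's effect (assumes a torch-mlir build that includes this change; torch_mlir.fx.export_and_import is the FX-based import path, and ToFloat is illustrative):

import torch
from torch_mlir import fx

class ToFloat(torch.nn.Module):
    def forward(self, x):
        return x.to(torch.float32)

# Any aten._assert_tensor_metadata guarding the conversion has static
# shape and dtype here, so the folder can verify the metadata and erase it.
m = fx.export_and_import(ToFloat(), torch.randn(2, 3, dtype=torch.float64),
                         func_name="main")
# The assert folds away once canonicalization runs (some import paths run
# it automatically; otherwise pipe through `torch-mlir-opt --canonicalize`).
print(m)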

//===----------------------------------------------------------------------===//
// AtenMaxPoolWithIndicesOp
//===----------------------------------------------------------------------===//
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -11916,7 +11916,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
@@ -11928,11 +11938,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<tuple<int, int>>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.optional<tuple<int, int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n"
9 changes: 6 additions & 3 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2901,9 +2901,10 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2]))
@@ -2916,14 +2917,16 @@ def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2], error_types={torch.uint8}))
def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2]))
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2], error_types={torch.uint8}))
def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.uint8
return self_dtype
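
The new error_types={torch.uint8} entries and assertions mirror eager PyTorch, which rejects uint8 inputs to the avg_pool ops. A quick sanity check (assumes a recent PyTorch build; the exact error text may vary):

import torch
import torch.nn.functional as F

x = torch.zeros(2, 3, 7, dtype=torch.uint8)
try:
    F.avg_pool1d(x, kernel_size=2)
except RuntimeError as e:
    print(e)  # e.g. "avg_pool1d" not implemented for 'Byte'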

# @check_dtype_function(_check_tensors_with_the_same_dtype(
4 changes: 4 additions & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1010,6 +1010,10 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
emit(
"aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
has_folder=True,
)
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
2 changes: 1 addition & 1 deletion pytorch-hash.txt
@@ -1 +1 @@
3f159d635772fa2a8fd352d96b95100d885f8169
37626ee0e6ff5dc1d38664690bd2ff6c790aab0c
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch/
--pre
torch==2.6.0.dev20241216
torch==2.7.0.dev20250120
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torchvision/
--pre
torchvision==0.22.0.dev20241216
torchvision==0.22.0.dev20250120
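
A quick check that an environment matches the pins above (version suffixes such as +cpu may vary by platform):

import torch
import torchvision

print(torch.__version__)        # expect 2.7.0.dev20250120 (possibly +cpu)
print(torchvision.__version__)  # expect 0.22.0.dev20250120 (possibly +cpu)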