Skip to content

Commit

Permalink
Add LIT test.
Browse files Browse the repository at this point in the history
  • Loading branch information
sahas3 committed Jan 9, 2025
1 parent c6a5ee5 commit 6c3d058
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,29 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten
return %0 : !torch.vtensor<[1,128],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
// CHECK: %[[TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32>
// CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOOR:.*]] = tosa.floor %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[CEIL:.*]] = tosa.ceil %[[TENSOR]] : (tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[F0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[IS_NEG:.*]] = tosa.greater %[[F0]], %[[TENSOR]] : (tensor<f32>, tensor<3x5xf32>) -> tensor<3x5xi1>
// CHECK: %[[SELECT:.*]] = tosa.select %[[IS_NEG]], %[[CEIL]], %[[FLOOR]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32>
// CHECK: %[[CAST:.*]] = tosa.cast %[[SELECT]] : (tensor<3x5xf32>) -> tensor<3x5xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[3,5],si64>
func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> {
%int4 = torch.constant.int 4
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64>
return %0 : !torch.vtensor<[3,5],si64>
}

// -----
// CHECK-LABEL: func.func @torch.aten.gather(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>,
Expand Down

0 comments on commit 6c3d058

Please sign in to comment.