-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModule Tests.swift
111 lines (96 loc) · 3.1 KB
/
Module Tests.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import Dialects
import MLIR
import XCTest
/// Tests for parsing, printing, programmatically constructing, and
/// canonicalizing MLIR modules through the Swift bindings.
///
/// NOTE(review): the multi-line expected strings below have no leading
/// indentation on nested lines. If the printer indents region contents, these
/// literals need matching indentation — confirm against actual printer output.
final class ModuleTests: XCTestCase {
  /// Canonicalization should erase the `std.addi` whose result `%0` is never
  /// used, leaving only the `return`.
  func testCanonicalization() throws {
    let input = """
      module {
      func @swap(%arg0: i1, %arg1: i1) -> (i1, i1) {
      %0 = "std.addi"(%arg0, %arg0) : (i1, i1) -> i1
      return %arg1, %arg0 : i1, i1
      }
      }
      """
    let expected = """
      module {
      func @swap(%arg0: i1, %arg1: i1) -> (i1, i1) {
      return %arg1, %arg0 : i1, i1
      }
      }
      """
    try assertCanonicalizes(input, to: expected)
  }

  /// Canonicalizing IR that is already canonical must be a no-op: the printed
  /// module comes back exactly as it went in.
  ///
  /// (This previously duplicated `testCanonicalization` verbatim and so added
  /// no coverage.)
  func testCanonicalization2() throws {
    let canonical = """
      module {
      func @swap(%arg0: i1, %arg1: i1) -> (i1, i1) {
      return %arg1, %arg0 : i1, i1
      }
      }
      """
    try assertCanonicalizes(canonical, to: canonical)
  }

  /// Builds the same `swap` function two ways — via the builder API and by
  /// parsing text. It then checks that both are valid, have the expected
  /// structure, and print identically in both the custom and generic forms.
  func testModule() throws {
    let context = MLIR.OwnedContext(dialects: .std)
    // Expected custom (pretty-printed) form.
    let reference = """
      module {
      func @swap(%arg0: i1, %arg1: i1) -> (i1, i1) {
      return %arg1, %arg0 : i1, i1
      }
      }
      """
    // Expected generic form, i.e. printed with `alwaysPrintInGenericForm: true`.
    let generic = """
      "module"() ( {
      "func"() ( {
      ^bb0(%arg0: i1, %arg1: i1): // no predecessors
      "std.return"(%arg1, %arg0) : (i1, i1) -> ()
      }) {sym_name = "swap", type = (i1, i1) -> (i1, i1)} : () -> ()
      }) : () -> ()
      """

    // Construct `func @swap(i1, i1) -> (i1, i1)` that returns its arguments
    // in reverse order.
    let location: Location = .unknown(in: context)
    let constructed = Module(location: location)
    constructed.body.operations.append(
      .function(
        "swap",
        returnTypes: [IntegerType.integer(bitWidth: 1), .integer(bitWidth: 1)],
        blocks: [
          Block(IntegerType.integer(bitWidth: 1), IntegerType.integer(bitWidth: 1), in: context) {
            ops, arg0, arg1 in
            ops.append(.return(arg1, arg0, at: location.viaCallsite()))
          }
        ],
        at: location.viaCallsite()))
    XCTAssertTrue(constructed.body.operations.allSatisfy(\.isValid))
    XCTAssertTrue(constructed.operation.isValid)

    // Parse the textual form: one op (the func) in the single region/block of the module.
    let parsed: Module = try context.parse(reference)
    XCTAssertEqual(parsed.body.operations.count, 1)
    XCTAssertEqual(parsed.operation.regions.count, 1)
    XCTAssertEqual(parsed.operation.regions.first?.blocks.count, 1)

    // Both routes must print identically in generic and custom form.
    XCTAssertEqual(
      generic,
      "\(constructed.operation.withPrintingOptions(alwaysPrintInGenericForm: true))")
    XCTAssertEqual(
      generic,
      "\(parsed.operation.withPrintingOptions(alwaysPrintInGenericForm: true))")
    XCTAssertEqual(reference, "\(constructed.operation)")
    XCTAssertEqual(reference, "\(parsed.operation)")
  }

  /// Parses `input` in a fresh `std`-dialect context and runs the
  /// canonicalization pass over it. It then asserts that the printed module
  /// equals `expected`.
  ///
  /// - Parameters:
  ///   - input: Textual MLIR to parse.
  ///   - expected: Expected printed form after canonicalization.
  ///   - file: Call-site file, forwarded so failures are reported at the caller.
  ///   - line: Call-site line, forwarded so failures are reported at the caller.
  /// - Throws: Any error raised by `context.parse`.
  private func assertCanonicalizes(
    _ input: String,
    to expected: String,
    file: StaticString = #file,
    line: UInt = #line
  ) throws {
    let context = MLIR.OwnedContext(dialects: .std)
    let passManager = PassManager(context: context, passes: .canonicalization)
    let module: Module = try context.parse(input)
    passManager.runPasses(on: module)
    XCTAssertEqual("\(module.operation)", expected, file: file, line: line)
  }
}