-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathday16.py
executable file
·80 lines (60 loc) · 2.13 KB
/
day16.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
import sys
import re
import z3
def solve(ops):
    """Choose one distinct value per position from that position's candidates.

    ``ops[i]`` is an iterable of the values allowed at position ``i``. The
    returned tuple ``t`` satisfies ``t[i] in ops[i]`` for every ``i`` and has
    no repeated values. The constraints are solved with the z3 SMT solver.

    Raises:
        ValueError: if some position has no candidates, or z3 does not find a
            satisfying assignment (e.g. the samples were contradictory).
    """
    solver = z3.Solver()
    variables = []
    for i, possible in enumerate(ops):
        candidates = list(possible)
        if not candidates:
            # An empty Or() would make the problem trivially unsatisfiable;
            # report the offending position instead.
            raise ValueError(f'position {i} has no candidate values')
        x = z3.Int(i)
        # x must take one of the candidate values for position i.
        solver.add(z3.Or(*(x == v for v in candidates)))
        variables.append(x)
    # All positions must receive different values. Distinct states this
    # directly and avoids building an empty And() for the first position.
    if len(variables) > 1:
        solver.add(z3.Distinct(*variables))
    # model() is only meaningful after check() has reported sat.
    result = solver.check()
    if result != z3.sat:
        raise ValueError(f'no distinct assignment exists ({result})')
    m = solver.model()
    # model_completion guarantees a concrete numeral for every variable.
    return tuple(m.eval(x, model_completion=True).as_long() for x in variables)
# Puzzle input: the file named by the first command-line argument, or stdin
# when the script is run without arguments.
if len(sys.argv) > 1:
    fin = open(sys.argv[1])
else:
    fin = sys.stdin
# Dispatch table of instruction semantics, indexed by internal opcode id.
# Each entry is called as f(r, a, b) with r the register values, and returns
# the value that the caller stores into register C. Mnemonic suffixes follow
# the puzzle: an "r" operand is a register index, an "i" operand is used as-is.
opcodes = [
    # Arithmetic.
    lambda r, a, b: r[a] + r[b],  # 0 addr
    lambda r, a, b: r[a] + b,  # 1 addi
    lambda r, a, b: r[a] * r[b],  # 2 mulr
    lambda r, a, b: r[a] * b,  # 3 muli
    # Bitwise AND / OR.
    lambda r, a, b: r[a] & r[b],  # 4 banr
    lambda r, a, b: r[a] & b,  # 5 bani
    lambda r, a, b: r[a] | r[b],  # 6 borr
    lambda r, a, b: r[a] | b,  # 7 bori
    # Assignment (operand B is ignored).
    lambda r, a, b: r[a],  # 8 setr
    lambda r, a, b: a,  # 9 seti
    # Comparisons produce 1 when the test holds and 0 otherwise.
    lambda r, a, b: int(a > r[b]),  # 10 gtir
    lambda r, a, b: int(r[a] > b),  # 11 gtri
    lambda r, a, b: int(r[a] > r[b]),  # 12 gtrr
    lambda r, a, b: int(a == r[b]),  # 13 eqir
    lambda r, a, b: int(r[a] == b),  # 14 eqri
    lambda r, a, b: int(r[a] == r[b]),  # 15 eqrr
]
# Candidate implementations for each opcode number: start with all of them and
# strike out every implementation that some sample contradicts.
opmap = [set(range(len(opcodes))) for _ in range(len(opcodes))]
rexp = re.compile(r'-?\d+')

# Read the whole input up front and release the file handle; stdin is left
# open because this script did not open it.
try:
    raw = fin.read()
finally:
    if fin is not sys.stdin:
        fin.close()

# Normalise Windows line endings so the blank-line separators below match.
raw = raw.replace('\r\n', '\n').strip()
# Samples are separated by single blank lines; the test program follows after
# a longer run of blank lines (three in the puzzle input).
data = re.split(r'\n{3,}', raw, maxsplit=1)
if len(data) != 2:
    raise ValueError('input has no blank-line gap between samples and program')
samples = [l.split('\n') for l in data[0].split('\n\n')]
program = [tuple(map(int, l.split())) for l in data[1].split('\n') if l.strip()]

ans = 0
for sample in samples:
    before = tuple(map(int, rexp.findall(sample[0])))
    instr = tuple(map(int, rexp.findall(sample[1])))
    after = tuple(map(int, rexp.findall(sample[2])))
    count = 0
    for op_id, op in enumerate(opcodes):
        # Only register C (instr[3]) is written, so compare just that slot.
        if op(before, instr[1], instr[2]) == after[instr[3]]:
            count += 1
        else:
            # This implementation cannot be the one behind opcode instr[0].
            opmap[instr[0]].discard(op_id)
    if count >= 3:
        ans += 1
print('Part 1:', ans)
# Part 2: resolve each opcode number to exactly one implementation, then run
# the test program on four registers that all start at zero.
opmap = solve(opmap)
regs = [0, 0, 0, 0]
for instr in program:
    impl = opcodes[opmap[instr[0]]]
    regs[instr[3]] = impl(regs, instr[1], instr[2])
ans2 = regs[0]
print('Part 2:', ans2)