Skip to content

Commit

Permalink
add code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianting He committed Jul 29, 2022
1 parent f75d7f5 commit 7eb4eac
Show file tree
Hide file tree
Showing 9 changed files with 1,690 additions and 1 deletion.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
# RNVulDet
# RNVulDet

Ethereum smart contract random number vulnerability detector.

## Usage
```
python3.10 main.py BYTECODE_FILE [-o OUTPUT_FILE]
```
75 changes: 75 additions & 0 deletions disassembler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import opcodes
import instruction


class Disassembler:
def __init__(self, bytecode: bytes):
self.bytecode = bytecode
self.instructions: dict[int, instruction.Instruction] = {}
self.instructions_list: list[instruction.Instruction] = []
self.jumpdests: set[int] | None = None
self.invalid_jumpdests: set[int] | None = None
self.opcodes: set[int] = set()

def disassemble(self):
offset, pc, bytecode = 0, 0, self.bytecode
end = len(bytecode)
dead = False

while offset < end:
opcode = bytecode[offset]

if opcode == opcodes.JUMPDEST:
dead = False

if opcode in opcodes.opcodes:
push_data_size = opcodes.opcodes[opcode][1]
else:
push_data_size = 0

push_data = self.get_push_data(offset + 1, push_data_size, end)

inst = instruction.Instruction(offset, pc, opcode, push_data)
self.add_instruction(inst, dead=dead)

if inst.is_halt_or_unconditional_jump_op():
dead = True

offset += 1 + push_data_size
pc += 1

if not dead:
# append STOP instruction
inst = instruction.Instruction(offset, pc, opcodes.STOP, None)
self.add_instruction(inst, dead=dead)

self.jumpdests = {
offset
for offset, inst in self.instructions.items()
if inst.opcode == opcodes.JUMPDEST
}
self.invalid_jumpdests = {0, 2, 7} - self.jumpdests

def add_instruction(self, inst: instruction.Instruction, dead: bool):
self.instructions[inst.offset] = inst
self.instructions_list.append(inst)
if not dead:
self.opcodes.add(inst.opcode)

def at(self, pc=None, offset=None) -> instruction.Instruction:
if pc is not None:
return self.instructions_list[pc]
else:
return self.instructions[offset]

def get_push_data(self, offset: int, push_data_size, end):
if not push_data_size:
return None
data_end = offset + push_data_size
if data_end <= end:
data_bytes = self.bytecode[offset:data_end]
else:
# append 0s
data_bytes = self.bytecode[offset:end] + bytes(data_end - end)
push_data = int.from_bytes(data_bytes, "big")
return push_data
173 changes: 173 additions & 0 deletions engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from __future__ import annotations
import logging

import disassembler
import tracker
import opcodes
from structures import PathItem, StoItem

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from instruction_instance import InstructionInstance


class Engine:
def __init__(self, bytecode: bytes):
self.bytecode = bytecode
self.conditions = []
self.call_values = []
self.to_addresses = []
self.todo_keys = []
self.disasm = disassembler.Disassembler(self.bytecode)
self.step: int = 0

def run(self) -> bool:
self.disasm.disassemble()
if not (opcodes.special_op & self.disasm.opcodes) and (
not (opcodes.time_op & self.disasm.opcodes)
or not (opcodes.mod_op & self.disasm.opcodes)
):
return False
if opcodes.CALL not in self.disasm.opcodes:
return False
logging.info("== first step ==")
self.step = 1
self.tracker = tracker.Tracker(self.bytecode, self.disasm, step=self.step)
self.dfs(start_offset=0, depth=0, step=self.step)
for attr_name in ("conditions", "to_addresses", "call_values", "todo_keys"):
attr = getattr(self, attr_name)
logging.info("%s %s", len(attr), attr_name)
if (
not self.conditions
and not self.call_values
and not self.to_addresses
and self.todo_keys
):
logging.info("== second step ==")
self.step = 2
self.tracker = tracker.Tracker(
self.bytecode, self.disasm, step=self.step, todo_keys=self.todo_keys
)
self.dfs(start_offset=0, depth=0, step=self.step)
for attr_name in ("conditions", "to_addresses", "call_values"):
attr = getattr(self, attr_name)
logging.info("%s %s", len(attr), attr_name)

return bool(self.conditions or self.call_values or self.to_addresses)

def taint_sink(self, step: int, inst_instance: InstructionInstance):
inst = inst_instance.inst

if (
inst.opcode == opcodes.CALL
and inst_instance.operands[tracker.STK][1].value not in range(1, 10)
and inst_instance.operands[tracker.STK][2].value != 0
):
to_address = inst_instance.operands[tracker.STK][1].get_origin()
if to_address.taint_inst & {
opcodes.CALLER,
opcodes.ORIGIN,
opcodes.CALLDATALOAD,
opcodes.CALLDATACOPY,
}:
for item in self.tracker.state.path[:-1]:
condition = item.condition
if condition is not None and condition.use_special_inst():
item = (f"step{step}", condition, inst_instance)
self.conditions.append(item)
call_value = inst_instance.operands[tracker.STK][2]
if call_value.use_special_inst():
self.call_values.append((f"step{step}", inst_instance))

if to_address.use_special_inst():
item = (f"step{step}", inst_instance)
self.to_addresses.append(item)
elif step == 1 and inst.opcode == opcodes.SSTORE:
key = inst_instance.operands[tracker.STK][0].get_origin()

flag = False
if inst_instance.use_special_inst():
flag = True
for item in self.tracker.state.path[:-1]:
condition = item.condition
if condition is not None and condition.use_special_inst():
inst_instance.taint_inst.update(condition.taint_inst)
flag = True
if flag:
key_poly = key.get_polynomial()
for item in reversed(self.todo_keys):
if item.key.get_polynomial().eq(key_poly, silence=True):
break
else:
self.todo_keys.append(StoItem(key=key, inst_instance=inst_instance))

def dfs(self, start_offset, depth, step, is_jumpi_true_branch=None):
if depth > 800:
logging.warning(
f"call stack too deep, start_offset={start_offset}, depth={depth}"
)
return

if not self.tracker.update_images(start_offset):
logging.debug(f"image same, start_offset={start_offset:05x}")
return

self.tracker.state.path.append(
PathItem(start_offset, None, is_jumpi_true_branch)
)

pc = self.disasm.at(offset=start_offset).pc
while True:
inst = self.disasm.at(pc=pc)
if inst.opcode not in opcodes.opcodes:
logging.warning(f"Unknown opcode: {inst.opcode:#02x}")
break
pc += 1

inst_instance = self.tracker.update(inst)
if inst_instance is None:
break

self.taint_sink(step, inst_instance)

if inst.opcode == opcodes.JUMP:
target_offset = inst_instance.operands[tracker.STK][0].value
if target_offset not in self.disasm.invalid_jumpdests:
if target_offset in self.disasm.jumpdests:
self.dfs(target_offset, depth + 1, step)
else:
if target_offset is not None:
logging.warning(f"Bad jumpdest: {target_offset:#02x}")
else:
logging.warning("Bad jumpdest: None")

break
elif inst.opcode == opcodes.JUMPI:
target_offset = inst_instance.operands[tracker.STK][0].value
condition = inst_instance.operands[tracker.STK][1].get_origin()
self.tracker.state.path[-1].condition = condition
if target_offset not in self.disasm.invalid_jumpdests:
if target_offset in self.disasm.jumpdests:
state_cpy = self.tracker.state.copy()
self.dfs(target_offset, depth + 1, step, True)
self.tracker.state = state_cpy
del state_cpy
else:
if target_offset is not None:
logging.warning(f"Bad jumpdest: {target_offset:#02x}")
else:
logging.warning("Bad jumpdest: None")
next_offset = self.disasm.at(pc=pc).offset
assert next_offset == inst.offset + 1
self.dfs(next_offset, depth + 1, step, False)
break
elif inst.is_halt_op():
break

# pc already added, this is next instruction
if self.disasm.at(pc=pc).opcode == opcodes.JUMPDEST:
next_offset = self.disasm.at(pc=pc).offset
assert next_offset == inst.offset + 1 + (inst.get_push_arg() or 0)
self.dfs(next_offset, depth + 1, step)
break
129 changes: 129 additions & 0 deletions instruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import opcodes


class Instruction:
def __init__(self, offset, pc, opcode, push_data=None):
self.offset = offset
self.pc = pc
assert isinstance(opcode, int)
self.opcode = opcode
self.push_data = push_data

def is_halt_op(self):
return self.opcode not in opcodes.opcodes or self.opcode in opcodes.halt_op

def is_push_op(self):
return self.opcode in opcodes.push_op

def is_dup_op(self):
return self.opcode in opcodes.dup_op

def is_swap_op(self):
return self.opcode in opcodes.swap_op

def is_halt_or_unconditional_jump_op(self):
return self.is_halt_op() or self.opcode == opcodes.JUMP

def is_arithmetic_op(self):
return self.opcode in opcodes.arithmetic_op

def is_mem_read_op(self):
return self.opcode in opcodes.mem_read_op

def is_mem_write_op(self):
return self.opcode in opcodes.mem_write_op

def is_mem_access_op(self):
return self.opcode in opcodes.mem_access_op

def is_mem_rw_op(self):
return self.opcode in opcodes.mem_rw_op

def is_call_op(self):
return self.opcode in opcodes.call_op

def is_commutative_op(self):
return self.opcode in opcodes.commutative_op

def is_taint_op(self):
return self.opcode in opcodes.taint_op

def n_pops(self):
if self.opcode in opcodes.opcodes:
return opcodes.opcodes[self.opcode][2]
else:
return 0

def n_pushes(self):
if self.opcode in opcodes.opcodes:
return opcodes.opcodes[self.opcode][3]
else:
return 0

def get_push_arg(self):
if self.opcode in opcodes.push_op:
return opcodes.push_op[self.opcode]
else:
return None

def get_dup_arg(self):
if self.opcode in opcodes.dup_op:
return opcodes.dup_op[self.opcode]
else:
return None

def get_swap_arg(self):
if self.opcode in opcodes.swap_op:
return opcodes.swap_op[self.opcode]
else:
return None

def get_op_tuple(self, isRead):
if isRead:
return opcodes.mem_read_op[self.opcode]
else:
return opcodes.mem_write_op[self.opcode]

def get_mem_start_idx(self, isRead):
return self.get_op_tuple(isRead)[0]

def get_mem_len_idx(self, isRead):
return self.get_op_tuple(isRead)[1]

@property
def name(self):
if self.opcode in opcodes.opcodes:
return opcodes.opcodes[self.opcode][0]
elif self.opcode == 0x100:
return "VALUE"
elif self.opcode == 0x101:
return "UNKNOWN"
elif self.opcode == 0x102:
return "POSITION"
else:
return "GARBAGE %#02x" % self.opcode

def __eq__(self, _) -> bool:
raise NotImplementedError

def __hash__(self) -> int:
raise NotImplementedError

def __str__(self):
if self.push_data is not None:
return " ".join(["%05x" % self.offset, self.name, hex(self.push_data)])
else:
return " ".join(["%05x" % self.offset, self.name])

to_json = __str__

def __repr__(self):
return self.__str__()

@classmethod
def get_special_value(cls):
try:
return cls.special_value
except AttributeError:
cls.special_value = Instruction(0xFFFFE, 0xFFFFE, opcodes.SPECIAL_VALUE)
return cls.special_value
Loading

0 comments on commit 7eb4eac

Please sign in to comment.