Skip to content

Commit

Permalink
refactored regex to split into seperate border and row
Browse files Browse the repository at this point in the history
  • Loading branch information
nnshah1 committed May 17, 2024
1 parent 8fe0a55 commit 5fd696c
Showing 1 changed file with 69 additions and 41 deletions.
110 changes: 69 additions & 41 deletions qa/L0_logging/log_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,18 @@ def parse_timestamp(timestamp):
os.makedirs(test_logs_directory)

# Regular expression pattern to capture the headers and rows
table_regex = re.compile(
r'\+[-+]+\+\n' # Match the top border
r'\| (?P<header>.*?) \|\n' # Capture the header
r'\+[-+]+\+\n' # Match the header border
r'(?P<rows>(?:\| .*? \|\n)*)' # Capture the rows
r'\+[-+]+\+', # Match the bottom border
re.DOTALL # Enable dot to match newlines
)
# table is
# border
# header
# border
# row *
# border

table_border_regex = re.compile(r'^\+[-+]+\+$')
table_row_regex = re.compile(r'^\| (?P<row>.*?) \|$')


# Regular expression pattern
# Regular expression pattern for default log record
default_pattern = r'(?P<level>\w)(?P<month>\d{2})(?P<day>\d{2}) (?P<timestamp>\d{2}:\d{2}:\d{2}\.\d{6}) (?P<pid>\d+) (?P<file>[\w\.]+):(?P<line>\d+)] (?P<message>.*)'

# Compile the regex pattern
Expand All @@ -84,7 +86,7 @@ def parse_timestamp(timestamp):
FORMATS = [
("default", default_regex),
("ISO8601", ""),
("default_unescaped", ""),
("default_unescaped", default_regex),
("ISO8601_unescaped", ""),
]

Expand All @@ -97,68 +99,89 @@ def validator(func):
return func

@validator
def validate_level(level):
def validate_level(level, _):
assert level in LEVELS

@validator
def validate_month(month):
def validate_month(month, _):
assert month.isdigit()
month = int(month)
assert month >= 1 and month <= 12

@validator
def validate_day(day):
def validate_day(day, _):
assert day.isdigit()
day = int(day)
assert day >= 1 and day <= 31

@validator
def validate_timestamp(timestamp):
def validate_timestamp(timestamp, _):
parse_timestamp(timestamp)

@validator
def validate_pid(pid):
def validate_pid(pid, _):
assert pid.isdigit()

@validator
def validate_file(file_):
def validate_file(file_, _):
assert Path(file_).name is not None

@validator
def validate_line(line):
def validate_line(line, _):
assert line.isdigit()

def validate_table(table):
header = table.group("header").strip().split('|')
rows = table.group("rows").strip().split('\n')
def _split_row(row):
return [r.strip() for r in row.group("row").strip().split('|')]

def validate_table(table_rows):
index = 0
top_border = table_border_regex.search(table_rows[index])
assert top_border

index += 1
header = table_row_regex.search(table_rows[index])
assert header
header = _split_row(header)

index += 1
middle_border = table_border_regex.search(table_rows[index])
assert(middle_border)

# Process each row
index+=1
parsed_rows = []
for row in rows:
if row:
row_data = [r.strip() for r in row.split('|')[1:-1]]
row=""
for index, row in enumerate(table_rows[index:]):
matched = table_row_regex.search(row)
if matched:
row_data = _split_row(matched)
parsed_rows.append(row_data)


end_border = table_border_regex.search(row)
assert end_border

for row in parsed_rows:
assert len(row)==len(header)

@validator
def validate_message(message):
def validate_message(message, escaped):
heading, obj = message.split('\n',1)
if heading:
if heading and escaped:
try:
json.loads(heading)
except json.JSONDecodeError as e:
raise Exception(f"{e} First line of message in log record is not a valid JSON string")
except Exception as e:
raise type(e)(f"{e} First line of message in log record is not a valid JSON string")
if len(obj):
obj = obj.strip()
match = table_regex.search(obj)
if match:
validate_table(match)
else:
google.protobuf.text_format.Parse(obj,grpcclient.model_config_pb2.ModelConfig())
obj = obj.strip().split('\n')
if obj:
match = table_border_regex.search(obj[0])
if match:
validate_table(obj)
else:
google.protobuf.text_format.ParseLines(obj,grpcclient.model_config_pb2.ModelConfig())


class TestLogFormat:
@pytest.fixture(autouse=True)
Expand All @@ -178,18 +201,20 @@ def setup(self, request):
test_logs_directory, test_case_name + ".server.log"
)

def _launch_server(self, unescaped=None):
def _launch_server(self, escaped=None):
cmd = ["tritonserver"]

for key, value in self._server_options.items():
cmd.append(f"--{key}={value}")

env = os.environ.copy()

if unescaped:
if escaped is not None and not escaped:
env["TRITON_SERVER_ESCAPE_LOG_MESSSAGES"] = "FALSE"
elif unescaped is not None:
elif escaped is not None and escaped:
env["TRITON_SERVER_ESCAPE_LOG_MESSSAGES"] = "TRUE"
else:
del env["TRITON_SERVER_ESCAPE_LOG_MESSSAGES"]

self._server_process = subprocess.Popen(
cmd,
Expand All @@ -208,20 +233,20 @@ def _launch_server(self, unescaped=None):
if not os.path.exists(self._server_options["log-file"]):
raise Exception("Log not found")

def validate_log_record(self, record, format_regex):
def validate_log_record(self, record, format_regex, escaped):
match = format_regex.search(record)
if match:
for field, value in match.groupdict().items():
if field in validators:
try:
validators[field](value)
validators[field](value,escaped)
except Exception as e:
raise type(e)(f"{e}\nInvalid {field}: '{match.group(field)}' in log record '{record}'")

else:
raise Exception("Invalid log line")

def verify_log_format(self, file_path, format_regex):
def verify_log_format(self, file_path, format_regex, escaped):
log_records = []
with open(file_path, "rt") as file_:
current_log_record = []
Expand All @@ -236,7 +261,7 @@ def verify_log_format(self, file_path, format_regex):
log_records.append(current_log_record)
log_records = ["".join(log_record_lines) for log_record_lines in log_records]
for log_record in log_records:
self.validate_log_record(log_record, format_regex)
self.validate_log_record(log_record, format_regex, escaped)

@pytest.mark.parametrize(
"log_format,format_regex",
Expand All @@ -245,13 +270,16 @@ def verify_log_format(self, file_path, format_regex):
)
def test_log_format(self, log_format, format_regex):
self._server_options["log-format"] = log_format.replace("_unescaped", "")
self._launch_server(unescaped=True if "_unescaped" in log_format else False)

escaped = "_unescaped" not in log_format

self._launch_server(escaped)
time.sleep(1)
self._server_process.kill()
return_code = self._server_process.wait()

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable return_code is not used.
if isinstance(format_regex, str):
return
self.verify_log_format(self._server_options["log-file"], format_regex)
self.verify_log_format(self._server_options["log-file"], format_regex, escaped)

def foo_test_injection(self):
try:
Expand Down

0 comments on commit 5fd696c

Please sign in to comment.