Skip to content

Commit

Permalink
fix: create batch files to pass IDs to save Redis RAM
Browse files Browse the repository at this point in the history
  • Loading branch information
mickol34 committed Nov 6, 2024
1 parent a7b16a0 commit e81a04a
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 28 deletions.
28 changes: 22 additions & 6 deletions src/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .models.job import Job
from .models.jobagent import JobAgent
from .models.match import Match
from .models.queryresult import QueryResult
from .models.jobfile import JobFile
from .schema import MatchesSchema, ConfigSchema
from .config import app_config

Expand Down Expand Up @@ -112,16 +112,32 @@ def add_match(self, job: JobId, match: Match) -> None:
session.add(match)
session.commit()

def add_queryresult(self, job_id: int | None, files: List[str]) -> None:
def __get_jobfile(self, session: Session, jobfile_id: str) -> JobFile:
    """Fetch the JobFile with the given primary key using *session*.

    Uses ``.one()``, so it raises if no row (or more than one) matches.
    """
    statement = select(JobFile).where(JobFile.id == jobfile_id)
    return session.exec(statement).one()

def get_jobfile(self, jobfile_id: str) -> JobFile:
    """Open a short-lived session and return the JobFile with *jobfile_id*."""
    with self.session() as db_session:
        return self.__get_jobfile(db_session, jobfile_id)

def get_jobfiles_ids_by_job_id(self, job_id: int | None) -> List[int | None]:
    """Return the ids of every JobFile row belonging to the job *job_id*."""
    with self.session() as db_session:
        rows = db_session.exec(
            select(JobFile).where(JobFile.job_id == job_id)
        ).all()
        ids: List[int | None] = []
        for row in rows:
            ids.append(row.id)
        return ids

def add_jobfile(self, job_id: int | None, files: List[str]) -> None:
    """Create a new JobFile for *job_id* holding *files* and persist it.

    Returns None (the stored row's id can be retrieved afterwards via
    ``get_jobfiles_ids_by_job_id``).
    """
    with self.session() as session:
        jobfile = JobFile(job_id=job_id, files=files)
        session.add(jobfile)
        session.commit()

def remove_jobfile(self, jobfile: JobFile) -> None:
    """Delete the single JobFile row whose id matches *jobfile*'s id."""
    with self.session() as session:
        session.query(JobFile).where(JobFile.id == jobfile.id).delete()
        session.commit()

Expand Down
5 changes: 3 additions & 2 deletions src/lib/ursadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def query(
command += f"with taints {taints_whole_str} "
if dataset:
command += f'with datasets ["{dataset}"] '
command += f"{query};"
command += f"into iterator {query};"

start = time.perf_counter()
res = self.__execute(command, recv_timeout=-1)
Expand All @@ -75,7 +75,8 @@ def query(

return {
"time": (end - start),
"files": res["result"]["files"],
"iterator": res["result"]["iterator"],
"file_count": res["result"]["file_count"],
}

def pop(self, iterator: str, count: int) -> PopResult:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""create Queryresult model
"""create Jobfile model
Revision ID: 4e4c88411541
Revises: dbb81bd4d47f
Create Date: 2024-10-17 14:31:49.278443
Expand All @@ -16,7 +16,7 @@

def upgrade() -> None:
op.create_table(
"queryresult",
"jobfile",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("job_id", sa.Integer(), nullable=False),
sa.Column("files", sa.ARRAY(sa.String()), nullable=True),
Expand All @@ -29,4 +29,4 @@ def upgrade() -> None:


def downgrade() -> None:
    """Revert the migration by dropping the ``jobfile`` table."""
    op.drop_table("jobfile")
2 changes: 1 addition & 1 deletion src/models/queryresult.py → src/models/jobfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Union


class JobFile(SQLModel, table=True):
    """One batch of file paths produced by a query, stored per job.

    Persisting the file lists in the database means they need not travel
    through the task queue (commit intent: save Redis RAM).
    """

    # Auto-assigned primary key; None until the row is flushed.
    id: Union[int, None] = Field(default=None, primary_key=True)
    # Owning job (foreign key to job.internal_id).
    job_id: Union[int, None] = Field(foreign_key="job.internal_id")
    # File paths in this batch, stored as an ARRAY(String) column.
    files: List[str] = Field(sa_column=Column(ARRAY(String)))
30 changes: 14 additions & 16 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,9 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
if "error" in result:
raise RuntimeError(result["error"])

files = result["files"]
agent.db.add_queryresult(job.internal_id, files)

file_count = len(files)
file_count = result["file_count"]
iterator = result["iterator"]
logging.info(f"Iterator {iterator} contains {file_count} files")

total_files = agent.db.update_job_files(job_id, file_count)
if job.files_limit and total_files > job.files_limit:
Expand All @@ -257,32 +256,31 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
# add len(batch_sizes) new tasks, -1 to account for this task
agent.add_tasks_in_progress(job, len(batch_sizes) - 1)

batched_files = (
files[batch_end - batch_size : batch_end]
for batch_end, batch_size in zip(
accumulate(batch_sizes), batch_sizes
)
)
for batch_size in batch_sizes:
pop_result = agent.ursa.pop(iterator, batch_size)
agent.db.add_jobfile(job.internal_id, pop_result.files)

for batch_files in batched_files:
jobfile_ids = agent.db.get_jobfiles_ids_by_job_id(job.internal_id)
logging.critical(f'Jobfile_ids: {jobfile_ids}')
for jobfile_id in jobfile_ids:
agent.queue.enqueue(
run_yara_batch,
job_id,
batch_files,
jobfile_id,
job_timeout=app_config.rq.job_timeout,
)

agent.db.dataset_query_done(job_id)
agent.db.remove_queryresult(job.internal_id)


def run_yara_batch(job_id: JobId, jobfile_id: str) -> None:
    """Scan one stored batch of files with Yara and record the results.

    Loads the JobFile identified by *jobfile_id*, runs Yara over its
    file list, decrements the in-progress task counter and deletes the
    batch row so it does not linger in the database.
    """
    with job_context(job_id) as agent:
        job = agent.db.get_job(job_id)
        if job.status == "cancelled":
            # NOTE(review): this early return skips add_tasks_in_progress(-1)
            # and remove_jobfile — confirm cancelled jobs are cleaned up
            # elsewhere, or the counter and JobFile rows will leak.
            logging.info("Job was cancelled, returning...")
            return

        jobfile = agent.db.get_jobfile(jobfile_id)
        agent.execute_yara(job, jobfile.files)
        agent.add_tasks_in_progress(job, -1)
        agent.db.remove_jobfile(jobfile)

0 comments on commit e81a04a

Please sign in to comment.