From a812eeae9c677593684443c1c50ede078f78c7fd Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:44:52 +0100 Subject: [PATCH 01/12] Improve concurrency with custom http client and lock --- core/database_arango.py | 254 +++++++++++++++++++++++++++++++++------- 1 file changed, 209 insertions(+), 45 deletions(-) diff --git a/core/database_arango.py b/core/database_arango.py index b816e133a..f52737e35 100644 --- a/core/database_arango.py +++ b/core/database_arango.py @@ -37,6 +37,86 @@ TYetiObject = TypeVar("TYetiObject", bound="ArangoYetiConnector") +from arango.http import HTTPClient +from arango.response import Response +from requests import Session +from requests.adapters import HTTPAdapter +from urllib3.util import Retry + + +class LogAndRetryHTTPClient(HTTPClient): + def __init__(self, retries=0): + self._logger = logging.getLogger() + self._lock = None + self._retries = retries + + @property + def lock(self): + if not self._lock: + self._lock = MockLock() + return self._lock + + def set_lock(self, lock): + self._lock = lock + + def create_session(self, host): + session = Session() + if self._retries: + retry_strategy = Retry( + total=3, + backoff_factor=1, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["HEAD", "GET", "OPTIONS", "POST"], + ) + http_adapter = HTTPAdapter(max_retries=retry_strategy) + session.mount("https://", http_adapter) + session.mount("http://", http_adapter) + return session + + def send_request( + self, session, method, url, params=None, data=None, headers=None, auth=None + ): + # Acquire multiprocessing lock + # Mandatory with events consumers + self.lock.acquire() + response = None + try: + response = session.request( + method=method, + url=url, + params=params, + data=data, + headers=headers, + auth=auth, + verify=False, # Disable SSL verification + timeout=5, # Use timeout of 5 seconds + ) + # Return an instance of arango.response.Response. + response = Response( + method=response.request.method, + url=response.url, + headers=response.headers, + status_code=response.status_code, + status_text=response.reason, + raw_body=response.text, + ) + except Exception: + self._logger.error(f"Error while sending request to {url}") + finally: + self.lock.release() + return response + + +class MockLock: + def __init__(self): + pass + + def acquire(self): + pass + + def release(self): + pass + class ArangoDatabase: """Class that contains the base class for the database. 
@@ -48,6 +128,7 @@ def __init__(self): self.db = None self.collections = dict() self.graphs = dict() + self._http_client = LogAndRetryHTTPClient() def connect( self, @@ -67,7 +148,9 @@ def connect( database = "yeti_test" host_string = f"http://{host}:{port}" - client = ArangoClient(hosts=host_string, request_timeout=None) + client = ArangoClient( + hosts=host_string, request_timeout=None, http_client=self._http_client + ) sys_db = client.db("_system", username=username, password=password) for _ in range(0, 4): @@ -229,14 +312,99 @@ def collection(self, name): if self.db is None: self.connect() + async_db = self.db.begin_async_execution(return_result=True) if name not in self.collections: - if self.db.has_collection(name): - self.collections[name] = self.db.collection(name) + job = async_db.has_collection(name) + while job.status() != "done": + time.sleep(0.1) + data = job.result() + if data: + col = async_db.collection(name) + self.collections[name] = col else: - self.collections[name] = self.db.create_collection(name) - + job = async_db.create_collection(name) + while job.status() != "done": + time.sleep(0.1) + col = job.result() + self.collections[name] = col return self.collections[name] + def insert(self, collection, document_json): + if self.db is None: + self.connect() + document: dict = json.loads(document_json) + try: + async_db = self.db.begin_async_execution(return_result=True) + async_col = async_db.collection(collection) + job = async_col.insert(document, return_new=True) + while job.status() != "done": + time.sleep(0.1) + newdoc = job.result() + except DocumentInsertError as err: + if not err.error_code == 1210: # Unique constraint violation + raise + return None + return newdoc + + def update(self, collection, document_json): + document = json.loads(document_json) + doc_id = document.pop("id") + async_db = self.db.begin_async_execution(return_result=True) + async_col = async_db.collection(collection) + if doc_id: + document["_key"] = doc_id + job = async_col.update(document, return_new=True) + while job.status() != "done": + time.sleep(0.1) + newdoc = job.result() + else: + if "value" in document: + filters = {"value": document["value"]} + else: + filters = {"name": document["name"]} + if "type" in document: + filters["type"] = document["type"] + logging.debug(f"filters: {filters}") + job = async_col.update_match(filters, document) + while job.status() != "done": + time.sleep(0.1) + newdoc = job.result() + if newdoc != 1: + return None + try: + job = async_col.find(filters, limit=1) + while job.status() != "done": + time.sleep(0.1) + result = job.result() + newdoc = list(result)[0] + except IndexError as exception: + msg = f"Update failed when adding {document_json}: {exception}" + logging.error(msg) + raise RuntimeError(msg) + return newdoc + + def get(self, collection, id): + if self.db is None: + self.connect() + async_db = self.db.begin_async_execution(return_result=True) + async_col = async_db.collection(collection) + job = async_col.get(id) + while job.status() != "done": + time.sleep(0.1) + result = job.result() + return result + + def find(self, collection, kwargs, limit): + if self.db is None: + self.connect() + async_db = self.db.begin_async_execution(return_result=True) + async_col = async_db.collection(collection) + job = async_col.find(kwargs, limit=limit) + while job.status() != "done": + time.sleep(0.1) + result = job.result() + return result + def graph(self, name): if self.db is None: self.connect() @@ -265,6 +433,9 @@ def __getattr__(self, key): 
self.connect() return getattr(self.db, key) + def set_lock(self, lock): + self._http_client.set_lock(lock) + db = ArangoDatabase() @@ -282,40 +453,24 @@ def extended_id(self): return self._collection_name + "/" + self.id def _insert(self, document_json: str): - document: dict = json.loads(document_json) - try: - newdoc = self._get_collection().insert(document, return_new=True)["new"] - except DocumentInsertError as err: - if not err.error_code == 1210: # Unique constraint violation - raise + document = self._db.insert(self._collection_name, document_json) + if not document: return None - newdoc["__id"] = newdoc.pop("_key") - return newdoc + document = document["new"] + document["__id"] = document.pop("_key") + return document def _update(self, document_json): - document = json.loads(document_json) - doc_id = document.pop("id") - if doc_id: - document["_key"] = doc_id - newdoc = self._get_collection().update(document, return_new=True)["new"] - else: - if "value" in document: - filters = {"value": document["value"]} - else: - filters = {"name": document["name"]} - if "type" in document: - filters["type"] = document["type"] - self._get_collection().update_match(filters, document) - - logging.debug(f"filters: {filters}") - try: - newdoc = list(self._get_collection().find(filters, limit=1))[0] - except IndexError as exception: - msg = f"Update failed when adding {document_json}: {exception}" - logging.error(msg) - raise RuntimeError(msg) - newdoc["__id"] = newdoc.pop("_key") - return newdoc + document = self._db.update(self._collection_name, document_json) + if not document: + return None + try: + if "new" in document: + document = document["new"] + document["__id"] = document.pop("_key") + return document + except Exception: + return None def save( self: TYetiObject, exclude_overwrite: list[str] = ["created", "tags", "context"] @@ -391,7 +546,7 @@ def get(cls: Type[TYetiObject], id: str) -> TYetiObject | None: Returns: A Yeti object.""" - document = cls._get_collection().get(id) + document = cls._db.get(cls._collection_name, id) if not document: return None document["__id"] = document.pop("_key") @@ -410,10 +565,11 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None: if "type" not in kwargs and getattr(cls, "_type_filter", None): kwargs["type"] = cls._type_filter - documents = list(cls._get_collection().find(kwargs, limit=1)) + documents = cls._db.find(cls._collection_name, kwargs, limit=1) + # documents = list(cls._get_collection().find(kwargs, limit=1)) if not documents: return None - document = documents[0] + document = documents.pop() document["__id"] = document.pop("_key") return cls.load(document) @@ -599,7 +755,8 @@ def link_to( # Avoid circular dependency from core.schemas.graph import Relationship - graph = self._db.graph("threat_graph") + async_db = self._db.begin_async_execution(return_result=True) + async_graph = async_db.graph("threat_graph") # Check if a relationship with the same link_type already exists aql = """ @@ -624,7 +781,10 @@ def link_to( relationship.count += 1 edge = json.loads(relationship.model_dump_json()) edge["_id"] = neighbors[0]["_id"] - graph.update_edge(edge) + job = async_graph.update_edge(edge) + while job.status() != "done": + time.sleep(0.1) + # graph.update_edge(edge) if self._collection_name != "auditlog": try: event = message.LinkEvent( @@ -647,12 +807,16 @@ def link_to( created=datetime.datetime.now(datetime.timezone.utc), modified=datetime.datetime.now(datetime.timezone.utc), ) - result = graph.edge_collection("links").link( - 
relationship.source, - relationship.target, + col = async_graph.edge_collection("links") + job = col.link( + self.extended_id, + target.extended_id, data=json.loads(relationship.model_dump_json()), return_new=True, - )["new"] + ) + while job.status() != "done": + time.sleep(0.1) + result = job.result()["new"] result["__id"] = result.pop("_key") relationship = Relationship.load(result) if self._collection_name != "auditlog": From afdb372147503b9e676ae14355e6ac7c349917b3 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:45:42 +0100 Subject: [PATCH 02/12] Add multiprocessing lock for events --- core/events/consumers.py | 62 +++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/core/events/consumers.py b/core/events/consumers.py index 9061712d4..eaeb712aa 100644 --- a/core/events/consumers.py +++ b/core/events/consumers.py @@ -9,7 +9,13 @@ from kombu.mixins import ConsumerMixin from core.config.config import yeti_config -from core.events.message import EventMessage, LogMessage +from core.events.message import ( + EventMessage, + LinkEvent, + LogMessage, + ObjectEvent, + TagEvent, +) from core.schemas.task import EventTask, LogTask, TaskType from core.taskmanager import TaskManager from core.taskscheduler import get_plugins_list @@ -62,22 +68,40 @@ class EventConsumer(Consumer): def __init__(self, stop_event, connection, queues): super().__init__(EventTask, stop_event, connection, queues) + def debug(self, message, body): + message_digest = hashlib.sha256(body.encode()).hexdigest() + ts = int(message.timestamp.timestamp()) + if isinstance(message.event, ObjectEvent): + self.logger.debug( + f"Message received at {ts} - digest: {message_digest} | {message.event.event_message}" + ) + if isinstance(message.event, LinkEvent): + source = message.event.link_source_event + target = message.event.link_target_event + self.logger.debug( + f"Message received at {ts} - digest: {message_digest} | {source} --> {target}" + ) + if isinstance(message.event, TagEvent): + self.logger.debug( + f"Message received at {ts} - digest: {message_digest} | {message.event.tag_message}" + ) + def on_message(self, body, received_message): - try: - message = EventMessage(**json.loads(body)) - message_digest = hashlib.sha256(body.encode()).hexdigest() - ts = int(message.timestamp.timestamp()) - self.logger.debug(f"Message received at {ts} - digest: {message_digest}") - for task in TaskManager.tasks(): - if task.enabled is False or task.type != TaskType.event: - continue - if message.event.match(task.compiled_acts_on): - self.logger.info(f"Running task {task.name}") + message = EventMessage(**json.loads(body)) + self.debug(message, body) + for task in TaskManager.tasks(): + if task.enabled is False or task.type != TaskType.event: + continue + if message.event.match(task.compiled_acts_on): + self.logger.info( + f"Running task {task.name} on {message.event.event_message}" + ) + try: task.run(message) - except Exception: - self.logger.exception( - f"[PID:{os.getpid()}] - Error processing message in events queue with {body}" - ) + except Exception: + self.logger.exception( + f"[PID:{os.getpid()}] - Error processing message in events queue with {body}" + ) received_message.ack() @@ -91,10 +115,12 @@ def on_message(self, body, received_message): for task in TaskManager.tasks(): if task.enabled is False or task.type != TaskType.log: continue + self.logger.info(f"Running task {task.name} on {message}") task.run(message) except Exception: self.logger.exception(f"Error 
processing message in logs queue with {body}") - received_message.ack() + finally: + received_message.ack() class Worker(multiprocessing.Process): @@ -110,6 +136,9 @@ def __init__(self, queue, *args, **kwargs): def run(self): logger.info(f"Worker {self.name} started") + from core.database_arango import db + + db.set_lock(lock) while not self.stop_event.is_set(): try: self._worker.run() @@ -122,6 +151,7 @@ def run(self): if __name__ == "__main__": + lock = multiprocessing.Lock() parser = argparse.ArgumentParser( prog="yeti-consumer", description="Consume events and logs from the event bus" ) From 920350be1170c601c8d5c37eeda128e089d0119c Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:47:02 +0100 Subject: [PATCH 03/12] Handle overwrite option in save --- core/schemas/observable.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/core/schemas/observable.py b/core/schemas/observable.py index e97275ee9..6c5836828 100644 --- a/core/schemas/observable.py +++ b/core/schemas/observable.py @@ -148,7 +148,12 @@ def create(*, value: str, type: str | None = None, **kwargs) -> ObservableTypes: def save( - *, value: str, type: str | None = None, tags: List[str] = None, **kwargs + *, + value: str, + type: str | None = None, + tags: List[str] = None, + overwrite=False, + **kwargs, ) -> ObservableTypes: """ Save an observable object. If the object is already in the database, it will be updated. @@ -160,14 +165,26 @@ def save( tags is an optional list of tags to add to the observable. """ - observable_obj = create(value=value, type=type, **kwargs).save() + observable_obj = create(value=value, type=type, **kwargs) + db_obs = find(value=observable_obj.value, type=observable_obj.type, **kwargs) + if db_obs: + if overwrite: + observable_obj = observable_obj.save() + else: + observable_obj = db_obs + else: + observable_obj = observable_obj.save() if tags: observable_obj.tag(tags) return observable_obj def find(*, value, **kwargs) -> ObservableTypes: - return Observable.find(value=refang(value), **kwargs) + if "type" in kwargs: + obs = Observable.find(value=refang(value), type=kwargs["type"]) + else: + obs = Observable.find(value=refang(value)) + return obs def create_from_text(text: str) -> Tuple[List["ObservableTypes"], List[str]]: From 8989b5af52fdacad01634ca621ca8099489b8e01 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:47:43 +0100 Subject: [PATCH 04/12] Fixes date_joined removal --- plugins/analytics/public/dockerhub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/analytics/public/dockerhub.py b/plugins/analytics/public/dockerhub.py index ddeb0cd8c..5b7e584be 100644 --- a/plugins/analytics/public/dockerhub.py +++ b/plugins/analytics/public/dockerhub.py @@ -234,7 +234,7 @@ def each(self, user_obj: Observable): ): user_obj.created = date_joined user_obj = user_obj.save() - del metadata["date_joined"] + metadata.pop("date_joined", None) user_obj.add_context("hub.docker.com", metadata) for image in DockerHubApi.user_images(user_obj.value): image_name = f'{image["namespace"]}/{image["name"]}' From 757789803a2c38abd2e46f7d378209c5e91b1b06 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:48:50 +0100 Subject: [PATCH 05/12] Add more logging --- plugins/events/public/dockerhub.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/events/public/dockerhub.py b/plugins/events/public/dockerhub.py index 32ad07509..4f2b4df6a 100644 --- a/plugins/events/public/dockerhub.py +++ 
b/plugins/events/public/dockerhub.py @@ -43,10 +43,12 @@ def run(self, message: EventMessage) -> None: f"Skipping {container_image.type} {container_image.value} not from docker.io" ) return + self.logger.info(f"Fetching metadata for {container_image.value}") metadata = DockerHubApi.image_full_details(container_image.value) if not metadata: self.logger.info(f"Image metadata for {container_image.value} not found") return + self.logger.info(f"Adding context for {container_image.value}") context = get_image_context(metadata) container_image.add_context("hub.docker.com", context) make_relationships(container_image, metadata) From 311fa706b9099f902f0bbc17cd199967545fdccb Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Wed, 27 Nov 2024 18:55:22 +0100 Subject: [PATCH 06/12] Update python arango deps --- poetry.lock | 10 +++++----- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1917e3186..c68d26ab6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1985,13 +1985,13 @@ test = ["flaky", "pretend", "pytest (>=3.0.1)"] [[package]] name = "python-arango" -version = "7.9.1" +version = "8.1.2" description = "Python Driver for ArangoDB" optional = false python-versions = ">=3.8" files = [ - {file = "python-arango-7.9.1.tar.gz", hash = "sha256:18f7d365fb6cf45778fa73b559e3865d0a1c00081de65ef00ba238db52e374ab"}, - {file = "python_arango-7.9.1-py3-none-any.whl", hash = "sha256:23ec7b3aad774db5f99df20f6a1036385c85eb5c9864e47628bc622ea812f2f8"}, + {file = "python_arango-8.1.2-py3-none-any.whl", hash = "sha256:2b9f604b0f4eaf5209893cdb7a2f96448aa27d540300939b0a854ded75c031cb"}, + {file = "python_arango-8.1.2.tar.gz", hash = "sha256:4a39525ed426b23d7ae031e071f786ac35e6aa571d158ec54c59b74d6ae7a27f"}, ] [package.dependencies] @@ -2004,7 +2004,7 @@ setuptools = ">=42" urllib3 = ">=1.26.0" [package.extras] -dev = ["black (>=22.3.0)", "flake8 (>=4.0.1)", "isort (>=5.10.1)", "mock", "mypy (>=0.942)", "pre-commit (>=2.17.0)", "pytest (>=7.1.1)", "pytest-cov (>=3.0.0)", "sphinx", "sphinx-rtd-theme", "types-pkg-resources", "types-requests", "types-setuptools"] +dev = ["black (>=22.3.0)", "flake8 (>=4.0.1)", "isort (>=5.10.1)", "mock", "mypy (>=0.942)", "pre-commit (>=2.17.0)", "pytest (>=7.1.1)", "pytest-cov (>=3.0.0)", "sphinx", "sphinx-rtd-theme", "types-requests", "types-setuptools"] [[package]] name = "python-dateutil" @@ -2883,4 +2883,4 @@ s3 = ["boto3"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "136fff7a12b53a9850aea529fccacb9d45e9d576e4c5620df7632bd25e54bfec" +content-hash = "248a01002d6f5561f9502092e66dfec66fcd3fa27d7430c53b1d62e856682383" diff --git a/pyproject.toml b/pyproject.toml index f469e0111..3b2c1139d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" python = ">=3.10,<3.12" uvicorn = "^0.23.2" fastapi = "^0.109.0" -python-arango = "^7.9.1" +python-arango = "^8.1.2" celery = "^5.3.4" validators = "^0.34.0" python-jose = {extras = ["cryptography"], version = "^3.3.0"} From 3a52bf18196b1615b3338d36cb5253f648f7ef7f Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Thu, 28 Nov 2024 12:05:49 +0100 Subject: [PATCH 07/12] Fully switch to async database for collections and graphs --- core/database_arango.py | 267 +++++++++++++++++++++------------------- 1 file changed, 141 insertions(+), 126 deletions(-) diff --git a/core/database_arango.py b/core/database_arango.py index f52737e35..8cbf48bcc 100644 --- a/core/database_arango.py +++ b/core/database_arango.py @@ -35,6 +35,8 
@@ TESTING = "unittest" in sys.modules.keys() +ASYNC_JOB_WAIT_TIME = 0.01 + TYetiObject = TypeVar("TYetiObject", bound="ArangoYetiConnector") from arango.http import HTTPClient @@ -311,120 +313,55 @@ def clear(self, truncate=True): def collection(self, name): if self.db is None: self.connect() - - async_db = self.db.begin_async_execution(return_result=True) if name not in self.collections: + async_db = self.db.begin_async_execution(return_result=True) job = async_db.has_collection(name) while job.status() != "done": - time.sleep(0.1) + time.sleep(ASYNC_JOB_WAIT_TIME) data = job.result() if data: - col = async_db.collection(name) - self.collections[name] = col + self.collections[name] = async_db.collection(name) else: job = async_db.create_collection(name) while job.status() != "done": - time.sleep(0.1) - col = job.result() - self.collections[name] = col + time.sleep(ASYNC_JOB_WAIT_TIME) + self.collections[name] = job.result() return self.collections[name] - def insert(self, collection, document_json): + def graph(self, name): if self.db is None: self.connect() - document: dict = json.loads(document_json) - try: + if name not in self.graphs: async_db = self.db.begin_async_execution(return_result=True) - async_col = async_db.collection(collection) - job = async_col.insert(document, return_new=True) - while job.status() != "done": - time.sleep(0.1) - newdoc = job.result() - except DocumentInsertError as err: - if not err.error_code == 1210: # Unique constraint violation - raise - return None - return newdoc - - def update(self, collection, document_json): - document = json.loads(document_json) - doc_id = document.pop("id") - async_db = self.db.begin_async_execution(return_result=True) - async_col = async_db.collection(collection) - if doc_id: - document["_key"] = doc_id - job = async_col.update(document, return_new=True) + job = async_db.has_graph(name) while job.status() != "done": - time.sleep(0.1) - newdoc = job.result() - else: - if "value" in document: - filters = {"value": document["value"]} + time.sleep(ASYNC_JOB_WAIT_TIME) + data = job.result() + if data: + graph = async_db.graph(name) + self.graphs[name] = graph else: - filters = {"name": document["name"]} - if "type" in document: - filters["type"] = document["type"] - logging.debug(f"filters: {filters}") - job = async_col.update_match(filters, document) - while job.status() != "done": - time.sleep(0.1) - newdoc = job.result() - if newdoc != 1: - return None - try: - job = async_col.find(filters, limit=1) + job = async_db.create_graph(name) while job.status() != "done": - time.sleep(0.1) - result = job.result() - newdoc = list(result)[0] - except IndexError as exception: - msg = f"Update failed when adding {document_json}: {exception}" - logging.error(msg) - raise RuntimeError(msg) - return newdoc - - def get(self, collection, id): - if self.db is None: - self.connect() - async_db = self.db.begin_async_execution(return_result=True) - async_col = async_db.collection(collection) - job = async_col.get(id) - while job.status() != "done": - time.sleep(0.1) - result = job.result() - return result - - def find(self, collection, kwargs, limit): - if self.db is None: - self.connect() - async_db = self.db.begin_async_execution(return_result=True) - async_col = async_db.collection(collection) - job = async_col.find(kwargs, limit=limit) - while job.status() != "done": - time.sleep(0.1) - result = job.result() - return result - - def graph(self, name): - if self.db is None: - self.connect() - - try: - return self.db.create_graph(name) - except 
GraphCreateError as err: - if err.error_code in [1207, 1925]: - return self.db.graph(name) - raise + time.sleep(ASYNC_JOB_WAIT_TIME) + self.graphs[name] = job.result() + return self.graphs[name] + # graph is in async context def create_edge_definition(self, graph, definition): if self.db is None: self.connect() if not self.db.has_collection(definition["edge_collection"]): - collection = graph.create_edge_definition(**definition) + job = graph.create_edge_definition(**definition) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + collection = job.result() else: - collection = graph.replace_edge_definition(**definition) - + job = graph.replace_edge_definition(**definition) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + collection = job.result() self.collections[definition["edge_collection"]] = collection return collection @@ -453,24 +390,62 @@ def extended_id(self): return self._collection_name + "/" + self.id def _insert(self, document_json: str): - document = self._db.insert(self._collection_name, document_json) - if not document: + newdoc = None + try: + async_col = self._db.collection(self._collection_name) + job = async_col.insert(json.loads(document_json), return_new=True) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + newdoc = job.result() + newdoc = newdoc["new"] + except DocumentInsertError as err: + if not err.error_code == 1210: # Unique constraint violation + raise return None - document = document["new"] - document["__id"] = document.pop("_key") - return document + if not newdoc: + return None + newdoc["__id"] = newdoc.pop("_key") + return newdoc def _update(self, document_json): - document = self._db.update(self._collection_name, document_json) - if not document: - return None - try: - if "new" in document: - document = document["new"] - document["__id"] = document.pop("_key") - return document - except Exception: + # document = self._db.update(self._collection_name, document_json) + document = json.loads(document_json) + doc_id = document.pop("id") + async_col = self._db.collection(self._collection_name) + newdoc = None + if doc_id: + document["_key"] = doc_id + job = async_col.update(document, return_new=True) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + newdoc = job.result() + newdoc = newdoc["new"] + else: + if "value" in document: + filters = {"value": document["value"]} + else: + filters = {"name": document["name"]} + if "type" in document: + filters["type"] = document["type"] + logging.debug(f"filters: {filters}") + job = async_col.update_match(filters, document) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + try: + job = async_col.find(filters, limit=1) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + result = job.result() + newdoc = list(result)[0] + except IndexError as exception: + msg = f"Update failed when adding {document_json}: {exception}" + logging.error(msg) + raise RuntimeError(msg) + + if not newdoc: return None + newdoc["__id"] = newdoc.pop("_key") + return newdoc def save( self: TYetiObject, exclude_overwrite: list[str] = ["created", "tags", "context"] @@ -546,7 +521,11 @@ def get(cls: Type[TYetiObject], id: str) -> TYetiObject | None: Returns: A Yeti object.""" - document = cls._db.get(cls._collection_name, id) + async_col = cls._db.collection(cls._collection_name) + job = async_col.get(id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + document = job.result() if not document: return None document["__id"] = 
document.pop("_key") @@ -565,8 +544,11 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None: if "type" not in kwargs and getattr(cls, "_type_filter", None): kwargs["type"] = cls._type_filter - documents = cls._db.find(cls._collection_name, kwargs, limit=1) - # documents = list(cls._get_collection().find(kwargs, limit=1)) + async_col = cls._db.collection(cls._collection_name) + job = async_col.find(kwargs, limit=1) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + documents = job.result() if not documents: return None document = documents.pop() @@ -678,12 +660,15 @@ def link_to_tag( fresh=True, ) - result = graph.edge_collection("tagged").link( + job = graph.edge_collection("tagged").link( self.extended_id, tag_obj.extended_id, data=json.loads(tag_relationship.model_dump_json()), return_new=True, - )["new"] + ) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + result = job.result()["new"] result["__id"] = result.pop("_key") if self._collection_name != "auditlog": try: @@ -725,13 +710,22 @@ def clear_tags(self): graph = self._db.graph("tags") self.get_tags() - results = graph.edge_collection("tagged").edges(self.extended_id) + job = graph.edge_collection("tagged").edges(self.extended_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + results = job.result() for edge in results["edges"]: if self._collection_name != "auditlog": try: - tag_relationship = self._db.collection("tagged").get(edge["_id"]) + job = self._db.collection("tagged").get(edge["_id"]) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + tag_relationship = job.result() tag_collection, tag_id = tag_relationship["target"].split("/") - tag_obj = self._db.collection(tag_collection).get(tag_id) + job = self._db.collection(tag_collection).get(tag_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + tag_obj = job.result() event = message.TagEvent( type=message.EventType.delete, tagged_object=self, @@ -740,7 +734,9 @@ def clear_tags(self): producer.publish_event(event) except Exception: logging.exception("Error while publishing event") - graph.edge_collection("tagged").delete(edge["_id"]) + job = graph.edge_collection("tagged").delete(edge["_id"]) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) def link_to( self, target, relationship_type: str, description: str @@ -755,8 +751,7 @@ def link_to( # Avoid circular dependency from core.schemas.graph import Relationship - async_db = self._db.begin_async_execution(return_result=True) - async_graph = async_db.graph("threat_graph") + async_graph = self._db.graph("threat_graph") # Check if a relationship with the same link_type already exists aql = """ @@ -783,8 +778,7 @@ def link_to( edge["_id"] = neighbors[0]["_id"] job = async_graph.update_edge(edge) while job.status() != "done": - time.sleep(0.1) - # graph.update_edge(edge) + time.sleep(ASYNC_JOB_WAIT_TIME) if self._collection_name != "auditlog": try: event = message.LinkEvent( @@ -815,7 +809,7 @@ def link_to( return_new=True, ) while job.status() != "done": - time.sleep(0.1) + time.sleep(ASYNC_JOB_WAIT_TIME) result = job.result()["new"] result["__id"] = result.pop("_key") relationship = Relationship.load(result) @@ -841,7 +835,9 @@ def swap_link(self): edge["_to"] = self.target edge["_id"] = f"links/{self.id}" graph = self._db.graph("threat_graph") - graph.update_edge(edge) + job = graph.update_edge(edge) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) self.save() # TODO: Consider extracting this to 
its own class, given it's only meant @@ -1293,13 +1289,20 @@ def fulltext_filter(cls, keywords): def delete(self, all_versions=True): """Deletes an object from the database.""" - if self._db.graph("threat_graph").has_vertex_collection(self._collection_name): + job = self._db.graph("threat_graph").has_vertex_collection( + self._collection_name + ) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + if job.result(): col = self._db.graph("threat_graph").vertex_collection( self._collection_name ) else: col = self._db.collection(self._collection_name) - col.delete(self.id) + job = col.delete(self.id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) if self._collection_name == "auditlog": return try: @@ -1307,16 +1310,28 @@ def delete(self, all_versions=True): if self._collection_name == "tagged": source_collection, source_id = self.source.split("/") tag_collection, tag_id = self.target.split("/") - source_obj = self._db.collection(source_collection).get(source_id) - tag_obj = self._db.collection(tag_collection).get(tag_id) + job = self._db.collection(source_collection).get(source_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + source_obj = job.result() + job = self._db.collection(tag_collection).get(tag_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + tag_obj = job.result() event = message.TagEvent( type=event_type, tagged_object=source_obj, tag_object=tag_obj ) elif self._collection_name == "links": source_collection, source_id = self.source.split("/") target_collection, target_id = self.target.split("/") - source_obj = self._db.collection(source_collection).get(source_id) - target_obj = self._db.collection(target_collection).get(target_id) + job = self._db.collection(source_collection).get(source_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + source_obj = job.result() + job = self._db.collection(target_collection).get(target_id) + while job.status() != "done": + time.sleep(ASYNC_JOB_WAIT_TIME) + target_obj = job.result() event = message.LinkEvent( type=event_type, source_object=source_obj, From 489043391b44c54845daacf19c69e1ddae650665 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Fri, 29 Nov 2024 09:20:14 +0100 Subject: [PATCH 08/12] Remove custom http client and lock --- core/database_arango.py | 88 +---------------------------------------- 1 file changed, 1 insertion(+), 87 deletions(-) diff --git a/core/database_arango.py b/core/database_arango.py index 8cbf48bcc..30f9a54d9 100644 --- a/core/database_arango.py +++ b/core/database_arango.py @@ -39,86 +39,6 @@ TYetiObject = TypeVar("TYetiObject", bound="ArangoYetiConnector") -from arango.http import HTTPClient -from arango.response import Response -from requests import Session -from requests.adapters import HTTPAdapter -from urllib3.util import Retry - - -class LogAndRetryHTTPClient(HTTPClient): - def __init__(self, retries=0): - self._logger = logging.getLogger() - self._lock = None - self._retries = retries - - @property - def lock(self): - if not self._lock: - self._lock = MockLock() - return self._lock - - def set_lock(self, lock): - self._lock = lock - - def create_session(self, host): - session = Session() - if self._retries: - retry_strategy = Retry( - total=3, - backoff_factor=1, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["HEAD", "GET", "OPTIONS", "POST"], - ) - http_adapter = HTTPAdapter(max_retries=retry_strategy) - session.mount("https://", http_adapter) - session.mount("http://", http_adapter) - 
return session - - def send_request( - self, session, method, url, params=None, data=None, headers=None, auth=None - ): - # Acquire multiprocessing lock - # Mandatory with events consumers - self.lock.acquire() - response = None - try: - response = session.request( - method=method, - url=url, - params=params, - data=data, - headers=headers, - auth=auth, - verify=False, # Disable SSL verification - timeout=5, # Use timeout of 5 seconds - ) - # Return an instance of arango.response.Response. - response = Response( - method=response.request.method, - url=response.url, - headers=response.headers, - status_code=response.status_code, - status_text=response.reason, - raw_body=response.text, - ) - except Exception: - self._logger.error(f"Error while sending request to {url}") - finally: - self.lock.release() - return response - - -class MockLock: - def __init__(self): - pass - - def acquire(self): - pass - - def release(self): - pass - class ArangoDatabase: """Class that contains the base class for the database. @@ -130,7 +50,6 @@ def __init__(self): self.db = None self.collections = dict() self.graphs = dict() - self._http_client = LogAndRetryHTTPClient() def connect( self, @@ -150,9 +69,7 @@ def connect( database = "yeti_test" host_string = f"http://{host}:{port}" - client = ArangoClient( - hosts=host_string, request_timeout=None, http_client=self._http_client - ) + client = ArangoClient(hosts=host_string, request_timeout=None) sys_db = client.db("_system", username=username, password=password) for _ in range(0, 4): @@ -370,9 +287,6 @@ def __getattr__(self, key): self.connect() return getattr(self.db, key) - def set_lock(self, lock): - self._http_client.set_lock(lock) - db = ArangoDatabase() From de218460c780a5c44e035a620e0798b16cc19540 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Fri, 29 Nov 2024 09:20:46 +0100 Subject: [PATCH 09/12] Spawn processes instead of fork --- core/events/consumers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/events/consumers.py b/core/events/consumers.py index eaeb712aa..a77eb1078 100644 --- a/core/events/consumers.py +++ b/core/events/consumers.py @@ -138,7 +138,6 @@ def run(self): logger.info(f"Worker {self.name} started") from core.database_arango import db - db.set_lock(lock) while not self.stop_event.is_set(): try: self._worker.run() @@ -151,7 +150,7 @@ def run(self): if __name__ == "__main__": - lock = multiprocessing.Lock() + multiprocessing.set_start_method("spawn") parser = argparse.ArgumentParser( prog="yeti-consumer", description="Consume events and logs from the event bus" ) From 6bfd7d2efd894549cde740203efb9f7afe69c5d5 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Fri, 29 Nov 2024 09:53:45 +0100 Subject: [PATCH 10/12] Update logging and better exception handling --- core/events/consumers.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/events/consumers.py b/core/events/consumers.py index a77eb1078..2dd1a5d4f 100644 --- a/core/events/consumers.py +++ b/core/events/consumers.py @@ -87,15 +87,20 @@ def debug(self, message, body): ) def on_message(self, body, received_message): - message = EventMessage(**json.loads(body)) + try: + message = EventMessage(**json.loads(body)) + except Exception: + self.logger.exception( + "Error parsing message in events queue. 
Discarding message" + ) + received_message.ack() + return self.debug(message, body) for task in TaskManager.tasks(): if task.enabled is False or task.type != TaskType.event: continue if message.event.match(task.compiled_acts_on): - self.logger.info( - f"Running task {task.name} on {message.event.event_message}" - ) + self.logger.info(f"Running task {task.name}") try: task.run(message) except Exception: From 332510b75414087bf7cf52b09ddb5e4221660714 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Fri, 29 Nov 2024 12:05:26 +0100 Subject: [PATCH 11/12] Update find prototype --- core/schemas/observable.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/schemas/observable.py b/core/schemas/observable.py index 6c5836828..e4f65a864 100644 --- a/core/schemas/observable.py +++ b/core/schemas/observable.py @@ -179,9 +179,9 @@ def save( return observable_obj -def find(*, value, **kwargs) -> ObservableTypes: - if "type" in kwargs: - obs = Observable.find(value=refang(value), type=kwargs["type"]) +def find(value: str, type: str = None) -> ObservableTypes: + if type: + obs = Observable.find(value=refang(value), type=type) else: obs = Observable.find(value=refang(value)) return obs From 03f0a8413b6e2b448cca948eff1cd9403b99b595 Mon Sep 17 00:00:00 2001 From: Fred Baguelin Date: Fri, 29 Nov 2024 12:35:26 +0100 Subject: [PATCH 12/12] Update find call in save function --- core/schemas/observable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/schemas/observable.py b/core/schemas/observable.py index e4f65a864..64552d8f4 100644 --- a/core/schemas/observable.py +++ b/core/schemas/observable.py @@ -166,7 +166,7 @@ def save( tags is an optional list of tags to add to the observable. """ observable_obj = create(value=value, type=type, **kwargs) - db_obs = find(value=observable_obj.value, type=observable_obj.type, **kwargs) + db_obs = find(value=observable_obj.value, type=observable_obj.type) if db_obs: if overwrite: observable_obj = observable_obj.save()
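
A note on the pattern these patches introduce: every database call now goes through
python-arango's async execution API. begin_async_execution(return_result=True) returns an
async database handle, each operation on it returns an AsyncJob, and the caller polls
job.status() until it reports "done" before reading job.result(). The sketch below shows
that polling step factored into a single helper; it is a minimal illustration only — the
helper name wait_for_job and the timeout handling are assumptions, not code from this
series.

    import time

    from arango.job import AsyncJob

    # Matches the ASYNC_JOB_WAIT_TIME constant added in core/database_arango.py.
    ASYNC_JOB_WAIT_TIME = 0.01

    def wait_for_job(job: AsyncJob, timeout: float = 30.0):
        """Poll an arango AsyncJob until it completes, then return its result."""
        deadline = time.monotonic() + timeout
        while job.status() != "done":
            if time.monotonic() > deadline:
                raise TimeoutError("ArangoDB async job did not complete in time")
            time.sleep(ASYNC_JOB_WAIT_TIME)
        return job.result()

    # Example usage, mirroring the calls made throughout core/database_arango.py:
    # async_db = db.begin_async_execution(return_result=True)
    # job = async_db.collection("observables").get(doc_id)
    # document = wait_for_job(job)

Wrapping each of the repeated `while job.status() != "done"` loops in a helper like this
would also give a single place to add timeout or error handling later.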