diff --git a/core/database_arango.py b/core/database_arango.py
index b816e133a..30f9a54d9 100644
--- a/core/database_arango.py
+++ b/core/database_arango.py
@@ -35,6 +35,8 @@
 TESTING = "unittest" in sys.modules.keys()
 
+ASYNC_JOB_WAIT_TIME = 0.01
+
 TYetiObject = TypeVar("TYetiObject", bound="ArangoYetiConnector")
@@ -228,35 +230,55 @@ def clear(self, truncate=True):
     def collection(self, name):
         if self.db is None:
             self.connect()
 
-        if name not in self.collections:
-            if self.db.has_collection(name):
-                self.collections[name] = self.db.collection(name)
+        async_db = self.db.begin_async_execution(return_result=True)
+        job = async_db.has_collection(name)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        data = job.result()
+        if data:
+            self.collections[name] = async_db.collection(name)
         else:
-            self.collections[name] = self.db.create_collection(name)
-
+            job = async_db.create_collection(name)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            self.collections[name] = job.result()
         return self.collections[name]
 
     def graph(self, name):
         if self.db is None:
             self.connect()
+        if name not in self.graphs:
+            async_db = self.db.begin_async_execution(return_result=True)
+            job = async_db.has_graph(name)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            data = job.result()
+            if data:
+                graph = async_db.graph(name)
+                self.graphs[name] = graph
+            else:
+                job = async_db.create_graph(name)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                self.graphs[name] = job.result()
+        return self.graphs[name]
 
-        try:
-            return self.db.create_graph(name)
-        except GraphCreateError as err:
-            if err.error_code in [1207, 1925]:
-                return self.db.graph(name)
-            raise
-
+    # graph is in async context
     def create_edge_definition(self, graph, definition):
         if self.db is None:
             self.connect()
         if not self.db.has_collection(definition["edge_collection"]):
-            collection = graph.create_edge_definition(**definition)
+            job = graph.create_edge_definition(**definition)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            collection = job.result()
         else:
-            collection = graph.replace_edge_definition(**definition)
-
+            job = graph.replace_edge_definition(**definition)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            collection = job.result()
         self.collections[definition["edge_collection"]] = collection
         return collection
@@ -282,22 +304,36 @@ def extended_id(self):
         return self._collection_name + "/" + self.id
 
     def _insert(self, document_json: str):
-        document: dict = json.loads(document_json)
+        newdoc = None
         try:
-            newdoc = self._get_collection().insert(document, return_new=True)["new"]
+            async_col = self._db.collection(self._collection_name)
+            job = async_col.insert(json.loads(document_json), return_new=True)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            newdoc = job.result()
+            newdoc = newdoc["new"]
         except DocumentInsertError as err:
             if not err.error_code == 1210:  # Unique constraint violation
                 raise
             return None
+        if not newdoc:
+            return None
         newdoc["__id"] = newdoc.pop("_key")
         return newdoc
 
     def _update(self, document_json):
+        # document = self._db.update(self._collection_name, document_json)
         document = json.loads(document_json)
         doc_id = document.pop("id")
+        async_col = self._db.collection(self._collection_name)
+        newdoc = None
         if doc_id:
             document["_key"] = doc_id
-            newdoc = self._get_collection().update(document, return_new=True)["new"]
+            job = async_col.update(document, return_new=True)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            newdoc = job.result()
+            newdoc = newdoc["new"]
         else:
             if "value" in document:
                 filters = {"value": document["value"]}
@@ -305,15 +341,23 @@ def _update(self, document_json):
                 filters = {"name": document["name"]}
             if "type" in document:
                 filters["type"] = document["type"]
-            self._get_collection().update_match(filters, document)
-            logging.debug(f"filters: {filters}")
+            job = async_col.update_match(filters, document)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
         try:
-            newdoc = list(self._get_collection().find(filters, limit=1))[0]
+            job = async_col.find(filters, limit=1)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            result = job.result()
+            newdoc = list(result)[0]
         except IndexError as exception:
             msg = f"Update failed when adding {document_json}: {exception}"
             logging.error(msg)
             raise RuntimeError(msg)
+
+        if not newdoc:
+            return None
         newdoc["__id"] = newdoc.pop("_key")
         return newdoc
@@ -391,7 +435,11 @@ def get(cls: Type[TYetiObject], id: str) -> TYetiObject | None:
         Returns:
           A Yeti object."""
-        document = cls._get_collection().get(id)
+        async_col = cls._db.collection(cls._collection_name)
+        job = async_col.get(id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        document = job.result()
         if not document:
             return None
         document["__id"] = document.pop("_key")
@@ -410,10 +458,14 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None:
         if "type" not in kwargs and getattr(cls, "_type_filter", None):
             kwargs["type"] = cls._type_filter
 
-        documents = list(cls._get_collection().find(kwargs, limit=1))
+        async_col = cls._db.collection(cls._collection_name)
+        job = async_col.find(kwargs, limit=1)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        documents = job.result()
         if not documents:
             return None
-        document = documents[0]
+        document = documents.pop()
         document["__id"] = document.pop("_key")
         return cls.load(document)
@@ -522,12 +574,15 @@ def link_to_tag(
                 fresh=True,
             )
 
-        result = graph.edge_collection("tagged").link(
+        job = graph.edge_collection("tagged").link(
             self.extended_id,
             tag_obj.extended_id,
             data=json.loads(tag_relationship.model_dump_json()),
             return_new=True,
-        )["new"]
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        result = job.result()["new"]
         result["__id"] = result.pop("_key")
         if self._collection_name != "auditlog":
             try:
@@ -569,13 +624,22 @@ def clear_tags(self):
         graph = self._db.graph("tags")
         self.get_tags()
-        results = graph.edge_collection("tagged").edges(self.extended_id)
+        job = graph.edge_collection("tagged").edges(self.extended_id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        results = job.result()
         for edge in results["edges"]:
             if self._collection_name != "auditlog":
                 try:
-                    tag_relationship = self._db.collection("tagged").get(edge["_id"])
+                    job = self._db.collection("tagged").get(edge["_id"])
+                    while job.status() != "done":
+                        time.sleep(ASYNC_JOB_WAIT_TIME)
+                    tag_relationship = job.result()
                     tag_collection, tag_id = tag_relationship["target"].split("/")
-                    tag_obj = self._db.collection(tag_collection).get(tag_id)
+                    job = self._db.collection(tag_collection).get(tag_id)
+                    while job.status() != "done":
+                        time.sleep(ASYNC_JOB_WAIT_TIME)
+                    tag_obj = job.result()
                     event = message.TagEvent(
                         type=message.EventType.delete,
                         tagged_object=self,
@@ -584,7 +648,9 @@ def clear_tags(self):
                     producer.publish_event(event)
                 except Exception:
                     logging.exception("Error while publishing event")
-            graph.edge_collection("tagged").delete(edge["_id"])
+            job = graph.edge_collection("tagged").delete(edge["_id"])
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
 
     def link_to(
         self, target, relationship_type: str, description: str
@@ -599,7 +665,7 @@ def link_to(
         # Avoid circular dependency
         from core.schemas.graph import Relationship
 
-        graph = self._db.graph("threat_graph")
+        async_graph = self._db.graph("threat_graph")
 
         # Check if a relationship with the same link_type already exists
         aql = """
@@ -624,7 +690,9 @@ def link_to(
             relationship.count += 1
             edge = json.loads(relationship.model_dump_json())
             edge["_id"] = neighbors[0]["_id"]
-            graph.update_edge(edge)
+            job = async_graph.update_edge(edge)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
             if self._collection_name != "auditlog":
                 try:
                     event = message.LinkEvent(
@@ -647,12 +715,16 @@ def link_to(
             created=datetime.datetime.now(datetime.timezone.utc),
             modified=datetime.datetime.now(datetime.timezone.utc),
         )
-        result = graph.edge_collection("links").link(
-            relationship.source,
-            relationship.target,
+        col = async_graph.edge_collection("links")
+        job = col.link(
+            self.extended_id,
+            target.extended_id,
             data=json.loads(relationship.model_dump_json()),
             return_new=True,
-        )["new"]
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        result = job.result()["new"]
         result["__id"] = result.pop("_key")
         relationship = Relationship.load(result)
         if self._collection_name != "auditlog":
@@ -677,7 +749,9 @@ def swap_link(self):
         edge["_to"] = self.target
         edge["_id"] = f"links/{self.id}"
         graph = self._db.graph("threat_graph")
-        graph.update_edge(edge)
+        job = graph.update_edge(edge)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
         self.save()
 
     # TODO: Consider extracting this to its own class, given it's only meant
@@ -1129,13 +1203,20 @@ def fulltext_filter(cls, keywords):
 
     def delete(self, all_versions=True):
         """Deletes an object from the database."""
-        if self._db.graph("threat_graph").has_vertex_collection(self._collection_name):
+        job = self._db.graph("threat_graph").has_vertex_collection(
+            self._collection_name
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        if job.result():
             col = self._db.graph("threat_graph").vertex_collection(
                 self._collection_name
             )
         else:
             col = self._db.collection(self._collection_name)
-        col.delete(self.id)
+        job = col.delete(self.id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
         if self._collection_name == "auditlog":
             return
         try:
@@ -1143,16 +1224,28 @@ def delete(self, all_versions=True):
             if self._collection_name == "tagged":
                 source_collection, source_id = self.source.split("/")
                 tag_collection, tag_id = self.target.split("/")
-                source_obj = self._db.collection(source_collection).get(source_id)
-                tag_obj = self._db.collection(tag_collection).get(tag_id)
+                job = self._db.collection(source_collection).get(source_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                source_obj = job.result()
+                job = self._db.collection(tag_collection).get(tag_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                tag_obj = job.result()
                 event = message.TagEvent(
                     type=event_type, tagged_object=source_obj, tag_object=tag_obj
                 )
             elif self._collection_name == "links":
                 source_collection, source_id = self.source.split("/")
                 target_collection, target_id = self.target.split("/")
-                source_obj = self._db.collection(source_collection).get(source_id)
-                target_obj = self._db.collection(target_collection).get(target_id)
+                job = self._db.collection(source_collection).get(source_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                source_obj = job.result()
+                job = self._db.collection(target_collection).get(target_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                target_obj = job.result()
                 event = message.LinkEvent(
                     type=event_type,
                     source_object=source_obj,
diff --git a/core/events/consumers.py b/core/events/consumers.py
index 9061712d4..2dd1a5d4f 100644
--- a/core/events/consumers.py
+++ b/core/events/consumers.py
@@ -9,7 +9,13 @@
 from kombu.mixins import ConsumerMixin
 
 from core.config.config import yeti_config
-from core.events.message import EventMessage, LogMessage
+from core.events.message import (
+    EventMessage,
+    LinkEvent,
+    LogMessage,
+    ObjectEvent,
+    TagEvent,
+)
 from core.schemas.task import EventTask, LogTask, TaskType
 from core.taskmanager import TaskManager
 from core.taskscheduler import get_plugins_list
@@ -62,22 +68,45 @@ class EventConsumer(Consumer):
     def __init__(self, stop_event, connection, queues):
         super().__init__(EventTask, stop_event, connection, queues)
 
+    def debug(self, message, body):
+        message_digest = hashlib.sha256(body.encode()).hexdigest()
+        ts = int(message.timestamp.timestamp())
+        if isinstance(message.event, ObjectEvent):
+            self.logger.debug(
+                f"Message received at {ts} - digest: {message_digest} | {message.event.event_message}"
+            )
+        if isinstance(message.event, LinkEvent):
+            source = message.event.link_source_event
+            target = message.event.link_target_event
+            self.logger.debug(
+                f"Message received at {ts} - digest: {message_digest} | {source} --> {target}"
+            )
+        if isinstance(message.event, TagEvent):
+            self.logger.debug(
+                f"Message received at {ts} - digest: {message_digest} | {message.event.tag_message}"
+            )
+
     def on_message(self, body, received_message):
         try:
             message = EventMessage(**json.loads(body))
-            message_digest = hashlib.sha256(body.encode()).hexdigest()
-            ts = int(message.timestamp.timestamp())
-            self.logger.debug(f"Message received at {ts} - digest: {message_digest}")
-            for task in TaskManager.tasks():
-                if task.enabled is False or task.type != TaskType.event:
-                    continue
-                if message.event.match(task.compiled_acts_on):
-                    self.logger.info(f"Running task {task.name}")
-                    task.run(message)
         except Exception:
             self.logger.exception(
-                f"[PID:{os.getpid()}] - Error processing message in events queue with {body}"
+                "Error parsing message in events queue. Discarding message"
             )
+            received_message.ack()
+            return
+        self.debug(message, body)
+        for task in TaskManager.tasks():
+            if task.enabled is False or task.type != TaskType.event:
+                continue
+            if message.event.match(task.compiled_acts_on):
+                self.logger.info(f"Running task {task.name}")
+                try:
+                    task.run(message)
+                except Exception:
+                    self.logger.exception(
+                        f"[PID:{os.getpid()}] - Error processing message in events queue with {body}"
+                    )
         received_message.ack()
@@ -91,10 +120,12 @@ def on_message(self, body, received_message):
             for task in TaskManager.tasks():
                 if task.enabled is False or task.type != TaskType.log:
                     continue
+                self.logger.info(f"Running task {task.name} on {message}")
                 task.run(message)
         except Exception:
             self.logger.exception(f"Error processing message in logs queue with {body}")
-        received_message.ack()
+        finally:
+            received_message.ack()
 
 
 class Worker(multiprocessing.Process):
@@ -110,6 +141,8 @@ def __init__(self, queue, *args, **kwargs):
 
     def run(self):
         logger.info(f"Worker {self.name} started")
+        from core.database_arango import db
+
         while not self.stop_event.is_set():
             try:
                 self._worker.run()
@@ -122,6 +155,7 @@
 
 
 if __name__ == "__main__":
+    multiprocessing.set_start_method("spawn")
     parser = argparse.ArgumentParser(
         prog="yeti-consumer", description="Consume events and logs from the event bus"
     )
diff --git a/core/schemas/observable.py b/core/schemas/observable.py
index e97275ee9..64552d8f4 100644
--- a/core/schemas/observable.py
+++ b/core/schemas/observable.py
@@ -148,7 +148,12 @@ def create(*, value: str, type: str | None = None, **kwargs) -> ObservableTypes:
 
 
 def save(
-    *, value: str, type: str | None = None, tags: List[str] = None, **kwargs
+    *,
+    value: str,
+    type: str | None = None,
+    tags: List[str] = None,
+    overwrite=False,
+    **kwargs,
 ) -> ObservableTypes:
     """
     Save an observable object. If the object is already in the database, it will be updated.
@@ -160,14 +165,26 @@ def save(
     tags is an optional list of tags to add to the observable.
     """
-    observable_obj = create(value=value, type=type, **kwargs).save()
+    observable_obj = create(value=value, type=type, **kwargs)
+    db_obs = find(value=observable_obj.value, type=observable_obj.type)
+    if db_obs:
+        if overwrite:
+            observable_obj = observable_obj.save()
+        else:
+            observable_obj = db_obs
+    else:
+        observable_obj = observable_obj.save()
     if tags:
         observable_obj.tag(tags)
     return observable_obj
 
 
-def find(*, value, **kwargs) -> ObservableTypes:
-    return Observable.find(value=refang(value), **kwargs)
+def find(value: str, type: str = None) -> ObservableTypes:
+    if type:
+        obs = Observable.find(value=refang(value), type=type)
+    else:
+        obs = Observable.find(value=refang(value))
+    return obs
 
 
 def create_from_text(text: str) -> Tuple[List["ObservableTypes"], List[str]]:
diff --git a/plugins/analytics/public/dockerhub.py b/plugins/analytics/public/dockerhub.py
index ddeb0cd8c..5b7e584be 100644
--- a/plugins/analytics/public/dockerhub.py
+++ b/plugins/analytics/public/dockerhub.py
@@ -234,7 +234,7 @@ def each(self, user_obj: Observable):
         ):
             user_obj.created = date_joined
             user_obj = user_obj.save()
-            del metadata["date_joined"]
+            metadata.pop("date_joined", None)
         user_obj.add_context("hub.docker.com", metadata)
         for image in DockerHubApi.user_images(user_obj.value):
             image_name = f'{image["namespace"]}/{image["name"]}'
diff --git a/plugins/events/public/dockerhub.py b/plugins/events/public/dockerhub.py
index 32ad07509..4f2b4df6a 100644
--- a/plugins/events/public/dockerhub.py
+++ b/plugins/events/public/dockerhub.py
@@ -43,10 +43,12 @@ def run(self, message: EventMessage) -> None:
                 f"Skipping {container_image.type} {container_image.value} not from docker.io"
             )
             return
+        self.logger.info(f"Fetching metadata for {container_image.value}")
         metadata = DockerHubApi.image_full_details(container_image.value)
         if not metadata:
             self.logger.info(f"Image metadata for {container_image.value} not found")
             return
+        self.logger.info(f"Adding context for {container_image.value}")
         context = get_image_context(metadata)
         container_image.add_context("hub.docker.com", context)
         make_relationships(container_image, metadata)
diff --git a/poetry.lock b/poetry.lock
index 1917e3186..c68d26ab6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1985,13 +1985,13 @@ test = ["flaky", "pretend", "pytest (>=3.0.1)"]
 
 [[package]]
 name = "python-arango"
-version = "7.9.1"
+version = "8.1.2"
 description = "Python Driver for ArangoDB"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "python-arango-7.9.1.tar.gz", hash = "sha256:18f7d365fb6cf45778fa73b559e3865d0a1c00081de65ef00ba238db52e374ab"},
-    {file = "python_arango-7.9.1-py3-none-any.whl", hash = "sha256:23ec7b3aad774db5f99df20f6a1036385c85eb5c9864e47628bc622ea812f2f8"},
+    {file = "python_arango-8.1.2-py3-none-any.whl", hash = "sha256:2b9f604b0f4eaf5209893cdb7a2f96448aa27d540300939b0a854ded75c031cb"},
+    {file = "python_arango-8.1.2.tar.gz", hash = "sha256:4a39525ed426b23d7ae031e071f786ac35e6aa571d158ec54c59b74d6ae7a27f"},
 ]
 
 [package.dependencies]
@@ -2004,7 +2004,7 @@ setuptools = ">=42"
 urllib3 = ">=1.26.0"
 
 [package.extras]
-dev = ["black (>=22.3.0)", "flake8 (>=4.0.1)", "isort (>=5.10.1)", "mock", "mypy (>=0.942)", "pre-commit (>=2.17.0)", "pytest (>=7.1.1)", "pytest-cov (>=3.0.0)", "sphinx", "sphinx-rtd-theme", "types-pkg-resources", "types-requests", "types-setuptools"]
+dev = ["black (>=22.3.0)", "flake8 (>=4.0.1)", "isort (>=5.10.1)", "mock", "mypy (>=0.942)", "pre-commit (>=2.17.0)", "pytest (>=7.1.1)", "pytest-cov (>=3.0.0)", "sphinx", "sphinx-rtd-theme", "types-requests", "types-setuptools"]
 
 [[package]]
 name = "python-dateutil"
@@ -2883,4 +2883,4 @@ s3 = ["boto3"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.12"
-content-hash = "136fff7a12b53a9850aea529fccacb9d45e9d576e4c5620df7632bd25e54bfec"
+content-hash = "248a01002d6f5561f9502092e66dfec66fcd3fa27d7430c53b1d62e856682383"
diff --git a/pyproject.toml b/pyproject.toml
index f469e0111..3b2c1139d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ readme = "README.md"
 python = ">=3.10,<3.12"
 uvicorn = "^0.23.2"
 fastapi = "^0.109.0"
-python-arango = "^7.9.1"
+python-arango = "^8.1.2"
 celery = "^5.3.4"
 validators = "^0.34.0"
 python-jose = {extras = ["cryptography"], version = "^3.3.0"}