Add event bus publishing when inserting or updating objects in ArangoDB (#1144)
udgover authored Oct 19, 2024
1 parent 54e7fff commit 46ab0f9
Showing 25 changed files with 688 additions and 43 deletions.
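
At a high level, this commit wires event publishing into the ArangoDB persistence layer: whenever an object is saved, tagged, linked or deleted, a typed event from core.events.message is pushed through core.events.producer.producer, and any publishing failure is logged rather than allowed to break the database write. A minimal sketch of that pattern, using only names visible in the hunks below (the wrapper function itself is hypothetical):

import logging

from core.events import message
from core.events.producer import producer


def publish_object_event(yeti_object, created: bool) -> None:
    # Hypothetical helper illustrating the publish pattern added to database_arango.py.
    event_type = message.EventType.new if created else message.EventType.update
    try:
        event = message.ObjectEvent(type=event_type, yeti_object=yeti_object)
        producer.publish_event(event)
    except Exception:
        # Event delivery must never break the write path, so failures are only logged.
        logging.exception("Error while publishing event")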
88 changes: 85 additions & 3 deletions core/database_arango.py
@@ -5,6 +5,7 @@
import logging
import sys
import time
import traceback
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type, TypeVar

if TYPE_CHECKING:
@@ -22,6 +23,8 @@
from arango.exceptions import DocumentInsertError, GraphCreateError

from core.config.config import yeti_config
from core.events import message
from core.events.producer import producer

from .interfaces import AbstractYetiConnector

@@ -205,7 +208,6 @@ def _insert(self, document_json: str):
if not err.error_code == 1210: # Unique constraint violation
raise
return None

newdoc["__id"] = newdoc.pop("_key")
return newdoc

@@ -231,7 +233,6 @@ def _update(self, document_json):
msg = f"Update failed when adding {document_json}: {exception}"
logging.error(msg)
raise RuntimeError(msg)

newdoc["__id"] = newdoc.pop("_key")
return newdoc

@@ -255,16 +256,24 @@ def save(
if doc_dict.get("id") is not None:
exclude = ["tags"] + self._exclude_overwrite
result = self._update(self.model_dump_json(exclude=exclude))
event_type = message.EventType.update
else:
exclude = ["tags", "id"] + self._exclude_overwrite
result = self._insert(self.model_dump_json(exclude=exclude))
event_type = message.EventType.new
if not result:
exclude = exclude_overwrite + self._exclude_overwrite
result = self._update(self.model_dump_json(exclude=exclude))
event_type = message.EventType.update
yeti_object = self.__class__(**result)
# TODO: Override this if we decide to implement YetiTagModel
if hasattr(self, "tags"):
yeti_object.get_tags()
try:
event = message.ObjectEvent(type=event_type, yeti_object=yeti_object)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
return yeti_object

@classmethod
@@ -404,6 +413,13 @@ def link_to_tag(
edge = json.loads(tag_relationship.model_dump_json())
edge["_id"] = tag_relationship.id
graph.update_edge(edge)
try:
event = message.TagEvent(
type=message.EventType.update, tagged_object=self, tag_object=tag
)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
return tag_relationship

# Relationship doesn't exist, check if tag is already in the db
@@ -428,6 +444,13 @@ def link_to_tag(
return_new=True,
)["new"]
result["__id"] = result.pop("_key")
try:
event = message.TagEvent(
type=message.EventType.new, tagged_object=self, tag_object=tag_obj
)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
return TagRelationship.load(result)

def expire_tag(self, tag_name: str) -> "TagRelationship":
@@ -462,6 +485,18 @@ def clear_tags(self):
self.get_tags()
results = graph.edge_collection("tagged").edges(self.extended_id)
for edge in results["edges"]:
try:
tag_relationship = self._db.collection("tagged").get(edge["_id"])
tag_collection, tag_id = tag_relationship["target"].split("/")
tag_obj = self._db.collection(tag_collection).get(tag_id)
event = message.TagEvent(
type=message.EventType.delete,
tagged_object=self,
tag_object=tag_obj,
)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
graph.edge_collection("tagged").delete(edge["_id"])

def link_to(
@@ -501,6 +536,16 @@ def link_to(
edge = json.loads(relationship.model_dump_json())
edge["_id"] = neighbors[0]["_id"]
graph.update_edge(edge)
try:
event = message.LinkEvent(
type=message.EventType.update,
source_object=self,
target_object=target,
relationship=relationship,
)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
return relationship

relationship = Relationship(
@@ -519,7 +564,18 @@
return_new=True,
)["new"]
result["__id"] = result.pop("_key")
return Relationship.load(result)
relationship = Relationship.load(result)
try:
event = message.LinkEvent(
type=message.EventType.new,
source_object=self,
target_object=target,
relationship=relationship,
)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")
return relationship

def swap_link(self):
"""Swaps the source and target of a relationship."""
@@ -932,6 +988,32 @@ def delete(self, all_versions=True):
else:
col = self._db.collection(self._collection_name)
col.delete(self.id)
try:
event_type = message.EventType.delete
if self._collection_name == "tagged":
source_collection, source_id = self.source.split("/")
tag_collection, tag_id = self.target.split("/")
source_obj = self._db.collection(source_collection).get(source_id)
tag_obj = self._db.collection(tag_collection).get(tag_id)
event = message.TagEvent(
type=event_type, tagged_object=source_obj, tag_object=tag_obj
)
elif self._collection_name == "links":
source_collection, source_id = self.source.split("/")
target_collection, target_id = self.target.split("/")
source_obj = self._db.collection(source_collection).get(source_id)
target_obj = self._db.collection(target_collection).get(target_id)
event = message.LinkEvent(
type=event_type,
source_object=source_obj,
target_object=target_obj,
relationship=self,
)
else:
event = message.ObjectEvent(type=event_type, yeti_object=self)
producer.publish_event(event)
except Exception:
logging.exception("Error while publishing event")

@classmethod
def _get_collection(cls):
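
Note on the delete() hunk above: ArangoDB document IDs take the form collection/key, which is why the source and target of a "tagged" or "links" edge are recovered with a simple split before the TagEvent or LinkEvent is built. A small illustration (the collection and key values are made up):

# ArangoDB document IDs are "collection/key"; the values below are made up.
extended_id = "observables/2157748"
collection_name, key = extended_id.split("/")
assert collection_name == "observables"
assert key == "2157748"
# delete() applies the same split to both endpoints of the edge, fetches each
# document, and uses them to populate the TagEvent / LinkEvent payload.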
Empty file added core/events/__init__.py
Empty file.
165 changes: 165 additions & 0 deletions core/events/consumers.py
@@ -0,0 +1,165 @@
import argparse
import hashlib
import json
import logging
import multiprocessing
import os

from kombu import Connection, Exchange, Queue
from kombu.mixins import ConsumerMixin

from core.config.config import yeti_config
from core.events.message import EventMessage, LogMessage
from core.schemas.task import EventTask, LogTask, TaskType
from core.taskmanager import TaskManager
from core.taskscheduler import get_plugins_list

# Register root logger for tasks
logger = logging.getLogger("task")
logger.propagate = False
formatter = logging.Formatter(
"[%(asctime)s: %(levelname)s/%(processName)s] %(message)s"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)


class Consumer(ConsumerMixin):
    def __init__(self, task_class: type[EventTask | LogTask], stop_event, connection, queues):
self.task_class = task_class
self._stop_event = stop_event
self.connection = connection
self.queues = queues
self._logger = None
get_plugins_list(task_class)

@property
def should_stop(self):
return self._stop_event.is_set()

@property
def logger(self):
if self._logger is None:
name = self.task_class.__name__.lower().replace("task", "")
self._logger = logging.getLogger(f"task.{name}")
self._logger.propagate = False
formatter = logging.Formatter(
"[%(asctime)s: %(levelname)s/%(processName)s] %(message)s"
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
self._logger.addHandler(handler)
return self._logger

def get_consumers(self, consumer, channel):
return [
consumer(queues=self.queues, callbacks=[self.on_message], accept=["json"])
]


class EventConsumer(Consumer):
def __init__(self, stop_event, connection, queues):
super().__init__(EventTask, stop_event, connection, queues)

def on_message(self, body, received_message):
try:
message = EventMessage(**json.loads(body))
message_digest = hashlib.sha256(body.encode()).hexdigest()
ts = int(message.timestamp.timestamp())
self.logger.debug(f"Message received at {ts} - digest: {message_digest}")
for task in TaskManager.tasks():
if task.enabled is False or task.type != TaskType.event:
continue
if message.event.match(task.compiled_acts_on):
self.logger.info(f"Running task {task.name}")
task.run(message)
except Exception:
self.logger.exception(
f"[PID:{os.getpid()}] - Error processing message in events queue with {body}"
)
received_message.ack()


class LogConsumer(Consumer):
def __init__(self, stop_event, connection, queues):
super().__init__(LogTask, stop_event, connection, queues)

def on_message(self, body, received_message):
try:
message = LogMessage(**json.loads(body))
for task in TaskManager.tasks():
if task.enabled is False or task.type != TaskType.log:
continue
task.run(message)
except Exception:
self.logger.exception(f"Error processing message in logs queue with {body}")
received_message.ack()


class Worker(multiprocessing.Process):
def __init__(self, queue, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stop_event = multiprocessing.Event()
exchange = Exchange(queue, type="direct")
queues = [Queue(queue, exchange, routing_key=queue)]
broker = f"redis://{yeti_config.get('redis', 'host')}/"
self._connection = Connection(broker, heartbeat=4)
self._connection.connect()
        # Choose the consumer implementation matching the queue being consumed
        consumer_class = EventConsumer if queue == "events" else LogConsumer
        self._worker = consumer_class(self.stop_event, self._connection, queues)

def run(self):
logger.info(f"Worker {self.name} started")
while not self.stop_event.is_set():
try:
self._worker.run()
except Exception:
logger.exception("Consumer failed, restarting")
except KeyboardInterrupt:
logger.info(f"Worker {self.name} exiting...")
self._connection.release()
return


if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="yeti-consumer", description="Consume events and logs from the event bus"
)
parser.add_argument(
"--concurrency", type=int, default=None, help="Number of consumers to start"
)
parser.add_argument(
"type", choices=["events", "logs"], help="Type of consumer to start"
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
args = parser.parse_args()
logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
if not args.concurrency:
concurrency = multiprocessing.cpu_count()
else:
concurrency = args.concurrency
logger.info(f"Starting {concurrency} {args.type} workers")
processes = []
stop_event = multiprocessing.Event()
for i in range(concurrency):
name = f"{args.type}-worker-{i+1}"
p = Worker(queue=args.type, name=name)
p.start()
logger.info(f"Starting {p.name} pid={p.pid}")
processes.append(p)
try:
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Shutdown requested, exiting gracefully...")
try:
logger.info(f"Terminating worker {p.name} pid={p.pid}")
for p in processes:
p.stop_event.set()
p.join()
logger.info(f"Worker {p.name} pid={p.pid} exited")
except KeyboardInterrupt:
logger.info("Forcefully killing remaining workers")
for p in processes:
p.kill()
logger.info(f"Worker {p.name} pid={p.pid} killed")