Arango concurrency #1180

Merged: 12 commits, Nov 29, 2024

179 changes: 136 additions & 43 deletions in core/database_arango.py
@@ -35,6 +35,8 @@

 TESTING = "unittest" in sys.modules.keys()

+ASYNC_JOB_WAIT_TIME = 0.01
+
 TYetiObject = TypeVar("TYetiObject", bound="ArangoYetiConnector")

@@ -228,35 +230,55 @@ def clear(self, truncate=True):
     def collection(self, name):
         if self.db is None:
             self.connect()

         if name not in self.collections:
-            if self.db.has_collection(name):
-                self.collections[name] = self.db.collection(name)
+            async_db = self.db.begin_async_execution(return_result=True)
+            job = async_db.has_collection(name)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            data = job.result()
+            if data:
+                self.collections[name] = async_db.collection(name)
             else:
-                self.collections[name] = self.db.create_collection(name)
+                job = async_db.create_collection(name)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                self.collections[name] = job.result()
         return self.collections[name]

     def graph(self, name):
         if self.db is None:
             self.connect()
-        try:
-            return self.db.create_graph(name)
-        except GraphCreateError as err:
-            if err.error_code in [1207, 1925]:
-                return self.db.graph(name)
-            raise
+        if name not in self.graphs:
+            async_db = self.db.begin_async_execution(return_result=True)
+            job = async_db.has_graph(name)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            data = job.result()
+            if data:
+                graph = async_db.graph(name)
+                self.graphs[name] = graph
+            else:
+                job = async_db.create_graph(name)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                self.graphs[name] = job.result()
+        return self.graphs[name]
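The removed graph() body handled the create/already-exists race through error codes 1207/1925; the new has_graph check is check-then-act, so two workers can still race between the check and the create. A hedged sketch of keeping the old guard in the async flow (error codes are carried over from the removed code; propagation of GraphCreateError through job.result() is an assumption):

    from arango.exceptions import GraphCreateError

    def get_or_create_graph(async_db, name):
        job = async_db.has_graph(name)
        while job.status() != "done":
            time.sleep(ASYNC_JOB_WAIT_TIME)
        if job.result():
            return async_db.graph(name)
        try:
            job = async_db.create_graph(name)
            while job.status() != "done":
                time.sleep(ASYNC_JOB_WAIT_TIME)
            return job.result()
        except GraphCreateError as err:
            if err.error_code in [1207, 1925]:  # duplicate name / graph exists
                return async_db.graph(name)
            raise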

+    # graph is in async context
     def create_edge_definition(self, graph, definition):
         if self.db is None:
             self.connect()

         if not self.db.has_collection(definition["edge_collection"]):
-            collection = graph.create_edge_definition(**definition)
+            job = graph.create_edge_definition(**definition)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            collection = job.result()
         else:
-            collection = graph.replace_edge_definition(**definition)
+            job = graph.replace_edge_definition(**definition)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            collection = job.result()
         self.collections[definition["edge_collection"]] = collection
         return collection
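Note that the branch condition still goes through the synchronous self.db.has_collection() even though graph here is an async-execution wrapper (per the comment above the method). A compact sketch of the same logic using the hypothetical wait_for_job helper from above:

    def create_edge_definition(self, graph, definition):
        # self.db stays synchronous; only the graph operations are async jobs.
        if not self.db.has_collection(definition["edge_collection"]):
            job = graph.create_edge_definition(**definition)
        else:
            job = graph.replace_edge_definition(**definition)
        collection = wait_for_job(job)  # poll until done, return the result
        self.collections[definition["edge_collection"]] = collection
        return collection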

@@ -282,38 +304,60 @@ def extended_id(self):
         return self._collection_name + "/" + self.id

     def _insert(self, document_json: str):
-        document: dict = json.loads(document_json)
+        newdoc = None
         try:
-            newdoc = self._get_collection().insert(document, return_new=True)["new"]
+            async_col = self._db.collection(self._collection_name)
+            job = async_col.insert(json.loads(document_json), return_new=True)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            newdoc = job.result()
+            newdoc = newdoc["new"]
         except DocumentInsertError as err:
             if not err.error_code == 1210:  # Unique constraint violation
                 raise
             return None
+        if not newdoc:
+            return None
         newdoc["__id"] = newdoc.pop("_key")
         return newdoc
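For comparison, the same _insert logic expressed with the hypothetical wait_for_job helper sketched earlier; this assumes job.result() re-raises DocumentInsertError, which the except clause above already relies on:

    def _insert(self, document_json: str):
        try:
            job = self._db.collection(self._collection_name).insert(
                json.loads(document_json), return_new=True
            )
            newdoc = wait_for_job(job)["new"]
        except DocumentInsertError as err:
            if err.error_code != 1210:  # 1210: unique constraint violation
                raise
            return None
        newdoc["__id"] = newdoc.pop("_key")
        return newdoc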

     def _update(self, document_json):
         # document = self._db.update(self._collection_name, document_json)
         document = json.loads(document_json)
         doc_id = document.pop("id")
+        async_col = self._db.collection(self._collection_name)
+        newdoc = None
         if doc_id:
             document["_key"] = doc_id
-            newdoc = self._get_collection().update(document, return_new=True)["new"]
+            job = async_col.update(document, return_new=True)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
+            newdoc = job.result()
+            newdoc = newdoc["new"]
         else:
             if "value" in document:
                 filters = {"value": document["value"]}
             else:
                 filters = {"name": document["name"]}
             if "type" in document:
                 filters["type"] = document["type"]
-            self._get_collection().update_match(filters, document)
+            logging.debug(f"filters: {filters}")
+            job = async_col.update_match(filters, document)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
             try:
-                newdoc = list(self._get_collection().find(filters, limit=1))[0]
+                job = async_col.find(filters, limit=1)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                result = job.result()
+                newdoc = list(result)[0]
             except IndexError as exception:
                 msg = f"Update failed when adding {document_json}: {exception}"
                 logging.error(msg)
                 raise RuntimeError(msg)

+        if not newdoc:
+            return None
         newdoc["__id"] = newdoc.pop("_key")
         return newdoc

@@ -391,7 +435,11 @@ def get(cls: Type[TYetiObject], id: str) -> TYetiObject | None:

         Returns:
           A Yeti object."""
-        document = cls._get_collection().get(id)
+        async_col = cls._db.collection(cls._collection_name)
+        job = async_col.get(id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        document = job.result()
         if not document:
             return None
         document["__id"] = document.pop("_key")
@@ -410,10 +458,14 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None:
         if "type" not in kwargs and getattr(cls, "_type_filter", None):
             kwargs["type"] = cls._type_filter

-        documents = list(cls._get_collection().find(kwargs, limit=1))
+        async_col = cls._db.collection(cls._collection_name)
+        job = async_col.find(kwargs, limit=1)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        documents = job.result()
         if not documents:
             return None
-        document = documents[0]
+        document = documents.pop()
         document["__id"] = document.pop("_key")
         return cls.load(document)
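The switch from documents[0] to documents.pop() suggests job.result() returns a cursor-like object rather than a list here; with limit=1 both forms pick the single match. A sketch that makes the materialization explicit (the cursor behavior is an assumption):

    documents = list(job.result())  # drain the result batch into a list
    if not documents:
        return None
    document = documents[0]  # same element pop() returns when limit=1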

@@ -522,12 +574,15 @@ def link_to_tag(
             fresh=True,
         )

-        result = graph.edge_collection("tagged").link(
+        job = graph.edge_collection("tagged").link(
             self.extended_id,
             tag_obj.extended_id,
             data=json.loads(tag_relationship.model_dump_json()),
             return_new=True,
-        )["new"]
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        result = job.result()["new"]
         result["__id"] = result.pop("_key")
         if self._collection_name != "auditlog":
             try:
@@ -569,13 +624,22 @@ def clear_tags(self):
         graph = self._db.graph("tags")

         self.get_tags()
-        results = graph.edge_collection("tagged").edges(self.extended_id)
+        job = graph.edge_collection("tagged").edges(self.extended_id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        results = job.result()
         for edge in results["edges"]:
             if self._collection_name != "auditlog":
                 try:
-                    tag_relationship = self._db.collection("tagged").get(edge["_id"])
+                    job = self._db.collection("tagged").get(edge["_id"])
+                    while job.status() != "done":
+                        time.sleep(ASYNC_JOB_WAIT_TIME)
+                    tag_relationship = job.result()
                     tag_collection, tag_id = tag_relationship["target"].split("/")
-                    tag_obj = self._db.collection(tag_collection).get(tag_id)
+                    job = self._db.collection(tag_collection).get(tag_id)
+                    while job.status() != "done":
+                        time.sleep(ASYNC_JOB_WAIT_TIME)
+                    tag_obj = job.result()
                     event = message.TagEvent(
                         type=message.EventType.delete,
                         tagged_object=self,
@@ -584,7 +648,9 @@
                     producer.publish_event(event)
                 except Exception:
                     logging.exception("Error while publishing event")
-            graph.edge_collection("tagged").delete(edge["_id"])
+            job = graph.edge_collection("tagged").delete(edge["_id"])
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)

     def link_to(
         self, target, relationship_type: str, description: str
@@ -599,7 +665,7 @@ def link_to(
         # Avoid circular dependency
         from core.schemas.graph import Relationship

-        graph = self._db.graph("threat_graph")
+        async_graph = self._db.graph("threat_graph")

         # Check if a relationship with the same link_type already exists
         aql = """
@@ -624,7 +690,9 @@
             relationship.count += 1
             edge = json.loads(relationship.model_dump_json())
             edge["_id"] = neighbors[0]["_id"]
-            graph.update_edge(edge)
+            job = async_graph.update_edge(edge)
+            while job.status() != "done":
+                time.sleep(ASYNC_JOB_WAIT_TIME)
             if self._collection_name != "auditlog":
                 try:
                     event = message.LinkEvent(
@@ -647,12 +715,16 @@
             created=datetime.datetime.now(datetime.timezone.utc),
             modified=datetime.datetime.now(datetime.timezone.utc),
         )
-        result = graph.edge_collection("links").link(
-            relationship.source,
-            relationship.target,
+        col = async_graph.edge_collection("links")
+        job = col.link(
+            self.extended_id,
+            target.extended_id,
             data=json.loads(relationship.model_dump_json()),
             return_new=True,
-        )["new"]
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        result = job.result()["new"]
         result["__id"] = result.pop("_key")
         relationship = Relationship.load(result)
         if self._collection_name != "auditlog":
@@ -677,7 +749,9 @@ def swap_link(self):
         edge["_to"] = self.target
         edge["_id"] = f"links/{self.id}"
         graph = self._db.graph("threat_graph")
-        graph.update_edge(edge)
+        job = graph.update_edge(edge)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
         self.save()

     # TODO: Consider extracting this to its own class, given it's only meant
@@ -1129,30 +1203,49 @@ def fulltext_filter(cls, keywords):

     def delete(self, all_versions=True):
         """Deletes an object from the database."""
-        if self._db.graph("threat_graph").has_vertex_collection(self._collection_name):
+        job = self._db.graph("threat_graph").has_vertex_collection(
+            self._collection_name
+        )
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
+        if job.result():
             col = self._db.graph("threat_graph").vertex_collection(
                 self._collection_name
             )
         else:
             col = self._db.collection(self._collection_name)
-        col.delete(self.id)
+        job = col.delete(self.id)
+        while job.status() != "done":
+            time.sleep(ASYNC_JOB_WAIT_TIME)
         if self._collection_name == "auditlog":
             return
         try:
             event_type = message.EventType.delete
             if self._collection_name == "tagged":
                 source_collection, source_id = self.source.split("/")
                 tag_collection, tag_id = self.target.split("/")
-                source_obj = self._db.collection(source_collection).get(source_id)
-                tag_obj = self._db.collection(tag_collection).get(tag_id)
+                job = self._db.collection(source_collection).get(source_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                source_obj = job.result()
+                job = self._db.collection(tag_collection).get(tag_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                tag_obj = job.result()
                 event = message.TagEvent(
                     type=event_type, tagged_object=source_obj, tag_object=tag_obj
                 )
             elif self._collection_name == "links":
                 source_collection, source_id = self.source.split("/")
                 target_collection, target_id = self.target.split("/")
-                source_obj = self._db.collection(source_collection).get(source_id)
-                target_obj = self._db.collection(target_collection).get(target_id)
+                job = self._db.collection(source_collection).get(source_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                source_obj = job.result()
+                job = self._db.collection(target_collection).get(target_id)
+                while job.status() != "done":
+                    time.sleep(ASYNC_JOB_WAIT_TIME)
+                target_obj = job.result()
                 event = message.LinkEvent(
                     type=event_type,
                     source_object=source_obj,
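A closing observation that applies to every loop in this diff: each while job.status() != "done" poll spins forever if a job never completes, and every status() call is itself a round-trip to the server. A hedged sketch of a bounded variant of the earlier helper, assuming the same AsyncJob interface (the timeout value and helper name are illustrative):

    import time

    def wait_for_job_bounded(job, timeout: float = 30.0, interval: float = 0.01):
        # Poll an AsyncJob until done, failing loudly after `timeout` seconds
        # instead of hanging the worker indefinitely.
        deadline = time.monotonic() + timeout
        while job.status() != "done":
            if time.monotonic() > deadline:
                raise TimeoutError(
                    f"ArangoDB async job {job.id} did not finish in {timeout}s"
                )
            time.sleep(interval)
        return job.result()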