Skip to content

Commit

Permalink
Fix find queries to discriminate on type (#1090)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomchop authored Jun 14, 2024
1 parent f0c2c1e commit 2bcff06
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 34 deletions.
3 changes: 3 additions & 0 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,9 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None:
Returns:
A Yeti object.
"""
if "type" not in kwargs and getattr(cls, "_type_filter", None):
kwargs["type"] = cls._type_filter

documents = list(cls._get_collection().find(kwargs, limit=1))
if not documents:
return None
Expand Down
20 changes: 13 additions & 7 deletions core/schemas/dfiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import re
from enum import Enum
from typing import Any, ClassVar, Literal, Type
from typing import Annotated, Any, ClassVar, Literal, Type, Union

import yaml
from pydantic import BaseModel, Field, computed_field
Expand Down Expand Up @@ -198,8 +198,9 @@ def update_parents(self) -> None:


class DFIQScenario(DFIQBase):
description: str
_type_filter: ClassVar[str] = DFIQType.scenario

description: str
type: Literal[DFIQType.scenario] = DFIQType.scenario

@classmethod
Expand All @@ -225,10 +226,10 @@ def from_yaml(cls: Type["DFIQScenario"], yaml_string: str) -> "DFIQScenario":


class DFIQFacet(DFIQBase):
description: str | None
_type_filter: ClassVar[str] = DFIQType.facet

description: str | None
parent_ids: list[str]

type: Literal[DFIQType.facet] = DFIQType.facet

@classmethod
Expand All @@ -255,9 +256,10 @@ def from_yaml(cls: Type["DFIQFacet"], yaml_string: str) -> "DFIQFacet":


class DFIQQuestion(DFIQBase):
_type_filter: ClassVar[str] = DFIQType.question

description: str | None
parent_ids: list[str]

type: Literal[DFIQType.question] = DFIQType.question

@classmethod
Expand Down Expand Up @@ -329,9 +331,10 @@ class DFIQApproachView(BaseModel):


class DFIQApproach(DFIQBase):
_type_filter: ClassVar[str] = DFIQType.approach

description: DFIQApproachDescription
view: DFIQApproachView

type: Literal[DFIQType.approach] = DFIQType.approach

@classmethod
Expand Down Expand Up @@ -375,7 +378,10 @@ def from_yaml(cls: Type["DFIQApproach"], yaml_string: str) -> "DFIQApproach":
}


DFIQTypes = DFIQScenario | DFIQFacet | DFIQQuestion | DFIQApproach
DFIQTypes = Annotated[
Union[DFIQScenario, DFIQFacet, DFIQQuestion, DFIQApproach],
Field(discriminator="type"),
]
DFIQClasses = (
Type[DFIQScenario] | Type[DFIQFacet] | Type[DFIQQuestion] | Type[DFIQApproach]
)
45 changes: 26 additions & 19 deletions core/schemas/entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
import re
from enum import Enum
from typing import ClassVar, Literal, Type
from typing import Annotated, ClassVar, Literal, Type, Union

from pydantic import Field, computed_field

Expand Down Expand Up @@ -44,9 +44,13 @@ def root_type(self):

@classmethod
def load(cls, object: dict) -> "EntityTypes":
if object["type"] in TYPE_MAPPING:
return TYPE_MAPPING[object["type"]](**object)
raise ValueError("Attempted to instantiate an undefined entity type.")
if cls._type_filter:
loader = TYPE_MAPPING[cls._type_filter]
elif object["type"] in TYPE_MAPPING:
loader = TYPE_MAPPING[object["type"]]
else:
raise ValueError("Attempted to instantiate an undefined entity type.")
return loader(**object)

@classmethod
def is_valid(cls, object: dict) -> bool:
Expand Down Expand Up @@ -211,21 +215,24 @@ def validate_entity(ent: Entity) -> bool:
return True


EntityTypes = (
AttackPattern
| Campaign
| Company
| CourseOfAction
| Identity
| IntrusionSet
| Investigation
| Malware
| Note
| Phone
| ThreatActor
| Tool
| Vulnerability
)
EntityTypes = Annotated[
Union[
AttackPattern,
Campaign,
Company,
CourseOfAction,
Identity,
IntrusionSet,
Investigation,
Malware,
Note,
Phone,
ThreatActor,
Tool,
Vulnerability,
],
Field(discriminator="type"),
]


EntityClasses = (
Expand Down
18 changes: 12 additions & 6 deletions core/schemas/indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
from enum import Enum
from typing import ClassVar, Literal, Type
from typing import Annotated, ClassVar, Literal, Type, Union

import yaml
from artifacts import definitions, reader, writer
Expand Down Expand Up @@ -69,10 +69,14 @@ def root_type(self):
return self._root_type

@classmethod
def load(cls, object: dict):
if object["type"] in TYPE_MAPPING:
return TYPE_MAPPING[object["type"]](**object)
return cls(**object)
def load(cls, object: dict) -> "IndicatorTypes":
if cls._type_filter:
loader = TYPE_MAPPING[cls._type_filter]
elif object["type"] in TYPE_MAPPING:
loader = TYPE_MAPPING[object["type"]]
else:
raise ValueError("Attempted to instantiate an undefined indicator type.")
return loader(**object)

def match(self, value: str) -> IndicatorMatch | None:
raise NotImplementedError
Expand Down Expand Up @@ -331,7 +335,9 @@ def save_indicators(self, create_links: bool = False):
"indicators": Indicator,
}

IndicatorTypes = Regex | Yara | Sigma | Query | ForensicArtifact
IndicatorTypes = Annotated[
Union[Regex, Yara, Sigma, Query, ForensicArtifact], Field(discriminator="type")
]
IndicatorClasses = (
Type[Regex] | Type[Yara] | Type[Sigma] | Type[Query] | Type[ForensicArtifact]
)
9 changes: 9 additions & 0 deletions tests/schemas/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ def test_entity_get_correct_type(self) -> None:
self.assertIsInstance(result, ThreatActor)
self.assertEqual(result.type, "threat-actor")

def test_entity_dupe_name_type(self) -> None:
oldm = Malware(name="APT123").save()
ta = ThreatActor.find(name="APT123")
m = Malware.find(name="APT123")
self.assertEqual(ta.id, self.ta1.id)
self.assertEqual(m.id, oldm.id)
self.assertIsInstance(m, Malware)
self.assertIsInstance(ta, ThreatActor)

def test_list_entities(self) -> None:
all_entities = list(Entity.list())
threat_actor_entities = list(ThreatActor.list())
Expand Down
29 changes: 27 additions & 2 deletions tests/schemas/indicator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import unittest

from core import database_arango
from core.schemas.indicator import DiamondModel, ForensicArtifact, Indicator, Regex
from core.schemas.indicator import (
DiamondModel,
ForensicArtifact,
Indicator,
Query,
Regex,
)


class IndicatorTest(unittest.TestCase):
Expand All @@ -12,7 +18,7 @@ def setUp(self) -> None:
def tearDown(self) -> None:
database_arango.db.clear()

def test_create_entity(self) -> None:
def test_create_indicator(self) -> None:
result = Regex(
name="regex1",
pattern="asd",
Expand All @@ -39,6 +45,25 @@ def test_filter_entities_different_types(self) -> None:
self.assertEqual(len(regex_entities), 1)
self.assertEqual(regex_entities[0].model_dump_json(), regex.model_dump_json())

def test_create_indicator_same_name_diff_types(self) -> None:
regex = Regex(
name="persistence1",
pattern="asd",
location="any",
diamond=DiamondModel.capability,
).save()
regex2 = Query(
name="persistence1",
pattern="asd",
location="any",
query_type="query",
diamond=DiamondModel.capability,
).save()
self.assertNotEqual(regex.id, regex2.id)
r = Regex.find(name="persistence1")
q = Query.find(name="persistence1")
self.assertNotEqual(r.id, q.id)

def test_regex_match(self) -> None:
regex = Regex(
name="regex1",
Expand Down

0 comments on commit 2bcff06

Please sign in to comment.