Skip to content

Commit

Permalink
Addressing changes requested
Browse files Browse the repository at this point in the history
  • Loading branch information
javiber committed May 17, 2024
1 parent 08d0dc0 commit 25f7938
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 32 deletions.
2 changes: 1 addition & 1 deletion temporian/core/event_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3254,7 +3254,7 @@ def moving_quantile(
... features={"value": [np.nan, 1, 5, 10, 15, 20]},
... )
>>> a.moving_quantile(tp.duration.seconds(4), quantile=0.5)
>>> a.moving_quantile(4, quantile=0.5)
indexes: ...
(6 events):
timestamps: [0. 1. 2. 5. 6. 7.]
Expand Down
9 changes: 6 additions & 3 deletions temporian/core/operators/window/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
class BaseWindowOperator(Operator, ABC):
"""Interface definition and common logic for window operators."""

extra_attribute_def: List[Mapping[str, Any]] = []

def __init__(
self,
input: EventSetNode,
Expand Down Expand Up @@ -125,10 +123,15 @@ def has_variable_winlen(self) -> bool:
def add_extra_attributes(self):
pass

@classmethod
def extra_attribute_def(cls) -> List[Mapping[str, Any]]:
return []

@classmethod
def build_op_definition(cls) -> pb.OperatorDef:
extra_attr_def = [
pb.OperatorDef.Attribute(**attr) for attr in cls.extra_attribute_def
pb.OperatorDef.Attribute(**attr)
for attr in cls.extra_attribute_def()
]
return pb.OperatorDef(
key=cls.operator_def_key(),
Expand Down
26 changes: 16 additions & 10 deletions temporian/core/operators/window/moving_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Moving count operator class and public API function definition."""

from typing import Optional
from typing import List, Mapping, Optional, Any

from temporian.core import operator_lib
from temporian.core.compilation import compile
Expand All @@ -27,14 +27,6 @@


class MovingQuantileOperator(BaseWindowOperator):
extra_attribute_def = [
{
"key": "quantile",
"is_optional": True,
"type": pb.OperatorDef.Attribute.Type.FLOAT_64,
}
]

def __init__(
self,
input: EventSetNode,
Expand All @@ -47,9 +39,13 @@ def __init__(
"`quantile` must be a float between 0 and 1. "
f"Received {quantile}"
)
self.quantile = quantile
self._quantile = quantile
super().__init__(input, window_length, sampling)

@property
def quantile(self) -> float:
return self._quantile

def add_extra_attributes(self):
self.add_attribute("quantile", self.quantile)

Expand All @@ -68,6 +64,16 @@ def get_feature_dtype(self, feature: FeatureSchema) -> DType:
return DType.FLOAT32
return feature.dtype

@classmethod
def extra_attribute_def(cls) -> List[Mapping[str, Any]]:
return [
{
"key": "quantile",
"is_optional": True,
"type": pb.OperatorDef.Attribute.Type.FLOAT_64,
}
]


operator_lib.register_operator(MovingQuantileOperator)

Expand Down
26 changes: 8 additions & 18 deletions temporian/implementation/numpy_cc/operators/custom_heap.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ class CustomHeap {
void push(T value) {
heap.push_back(value);
auto it = std::prev(heap.end());
// Notice that this breaks if a value repeats, not a problem in our case
// since we are using the Heap to store the indices
val_to_node[value] = it;
// TODO: better sorting?
// TODO: there is no better way to insert in order with a list
// but exploring with trees could make this better
while (it != heap.begin()) {
auto parent = std::prev(it);
if (!compare(*parent, *it)) {
break;
}
// TODO: check that this swap is doing what I want
std::swap(*parent, *it);
val_to_node[*it] = it;
val_to_node[*parent] = parent;
Expand All @@ -35,6 +37,8 @@ class CustomHeap {
auto value = heap.back();
heap.pop_back();
auto it = val_to_node.find(value);
// all other pointers in val_to_node are still valid because
// heap is a double linked list
val_to_node.erase(it);
return value;
}
Expand All @@ -52,25 +56,11 @@ class CustomHeap {
auto it = val_to_node.find(value);
if (it != val_to_node.end()) {
heap.erase(it->second);
// all other pointers in val_to_node are still valid because
// heap is a double linked list
val_to_node.erase(it);
} else {
// TODO: exception meant for debugging, remove it
throw std::invalid_argument("removing a value that doesn't exists");
}
}
int size() { return heap.size(); }
int empty() { return heap.empty(); }

void print() {
std::cout << "my_heap{" << std::endl << " [";
std::for_each(heap.begin(), heap.end(),
[](const int n) { std::cout << n << ' '; });
std::cout << "]" << std::endl;

// std::cout << " {" << std::endl;
// for (const auto& pair : val_to_node) {
// std::cout << " " << pair.first << ": " << *(pair.second) << std::endl;
// }
// std::cout << " }" << std::endl;
}
};

0 comments on commit 25f7938

Please sign in to comment.