Skip to content

Commit

Permalink
feat(schema)!: add spinner schema validation
Browse files Browse the repository at this point in the history
  • Loading branch information
leiteg committed Sep 5, 2024
1 parent 8205d17 commit 8f46d9f
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dependencies = [
"ipykernel==6.29.4",
"tokenize-rt==5.2.0",
"seaborn==0.13.2",
"scipy==1.13.1"
"scipy==1.13.1",
"pydantic~=2.8",
]

[project.optional-dependencies]
Expand Down
201 changes: 201 additions & 0 deletions spinner/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from __future__ import annotations

import ast
from functools import cache, cached_property
from typing import Any, Literal, Self

from jinja2 import Environment, Template, meta
from pydantic import (
BaseModel,
Field,
PositiveFloat,
PositiveInt,
RootModel,
model_validator,
)


class SpinnerMetadata(BaseModel):
description: str
version: str = Field(pattern=r"v?\d+\.\d+(\.\d+)?")
runs: int = Field(gt=0)
timeout: PositiveFloat | None = Field(default=None, gt=0.0)
retry: bool = Field(default=False)
retry_limit: PositiveInt = Field(default=1, ge=0)


class SpinnerCommand(BaseModel):
template: str

def __hash__(self) -> int:
return hash(self.template)

@cache
def parse(self, env: Environment | None = None) -> Template:
if not env:
env = Environment()
return env.parse(self.template)


SpinnerOutputType = Literal["contains"]


class SpinnerLambda(BaseModel):
name: str
func: str = Field(alias="lambda")

@model_validator(mode="after")
def validate_lambda(self) -> Self:
try:
ast.parse(source=self.func)
except SyntaxError as e:
raise ValueError(f"syntax error: {e}") from e
return self


class SpinnerOutput(BaseModel):
type: SpinnerOutputType
pattern: str
to_float: SpinnerLambda


class SpinnerPlot(BaseModel):
title: str
x_axis: str
y_axis: str
group_by: str

@model_validator(mode="after")
def validate_x_axis(self) -> Self:
# TODO: Implement validation.
return self

@model_validator(mode="after")
def validate_y_axis(self) -> Self:
# TODO: Implement validation.
return self

@model_validator(mode="after")
def validate_group_by(self) -> Self:
# TODO: Implement validation.
return self


class SpinnerApplication(BaseModel):
command: SpinnerCommand
output: list[SpinnerOutput] = Field(default_factory=list)
plot: list[SpinnerPlot] = Field(default_factory=list)

@cached_property
def placeholders(self) -> set[str]:
return meta.find_undeclared_variables(self.command.parse())


class SpinnerApplications(RootModel):
root: dict[str, SpinnerApplication] = Field(default_factory=dict)

def items(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> SpinnerApplication | None:
return self.root.get(key)


class SpinnerBenchmark(RootModel):
root: dict[str, list[Any]] = Field(default_factory=dict)

def parameters(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> list[Any] | None:
return self.root.get(key)


class SpinnerBenchmarks(RootModel):
root: dict[str, SpinnerBenchmark] = Field(default_factory=dict)

def items(self):
return self.root.items()

def __iter__(self):
return iter(self.root)

def __getitem__(self, key) -> SpinnerApplication | None:
return self.root.get(key)


class SpinnerConfig(BaseModel):
metadata: SpinnerMetadata
applications: SpinnerApplications = Field(default_factory=dict)
benchmarks: SpinnerBenchmarks = Field(default_factory=dict)

@model_validator(mode="after")
def validate_benchmark_keys(self) -> Self:
for key in self.benchmarks:
if key not in self.applications:
raise ValueError(f"benchmark {key!r} is not an application.")
return self

@model_validator(mode="after")
def validate_benchmark_parameters(self) -> Self:
for name, parameters in self.benchmarks.items():
for parameter in parameters:
if parameter not in self.applications[name].placeholders:
raise ValueError(f"parameter {parameter!r} is not valid.")
return self


if __name__ == "__main__":
import yaml
from pydantic import ValidationError
from rich import print

raw_data = """
metadata:
description: Lorem ipsum
version: v1.0
runs: 10
applications:
example_1:
command:
template:
sleep {{nodes}}
example_2:
command:
template: >
sleep {{sleep_duration}}
output:
- type: contains
pattern: "Runtime: "
to_float:
name: runtime
lambda: >
print("Hello")
- type: contains
pattern: "Runtime: "
to_float:
name: runtime
lambda: >
print("Hello")
benchmarks:
example_1:
nodes: [1, 2, 3]
example_2:
sleep_duration: [1, 2, 3]
"""

data = yaml.safe_load(raw_data)

try:
model = SpinnerConfig(**data)
print(model)
except ValidationError as e:
for error in e.errors():
print("-", error["msg"])
print(" ", error["loc"])

0 comments on commit 8f46d9f

Please sign in to comment.