-
-
Notifications
You must be signed in to change notification settings - Fork 708
/
Copy pathtest_schema.py
235 lines (165 loc) · 6.06 KB
/
test_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from typing import TypeVar
from datetime import datetime, date, time
from instructor import openai_schema
from decimal import Decimal
from uuid import UUID
from typing import Annotated, Union, Optional, Literal, Any
from collections import OrderedDict
import pytest
import sys
from pydantic import BaseModel, Field
T = TypeVar("T")
def test_annotation_schema():
class User(BaseModel):
details: dict[
Annotated[str, Field(description="User name", min_length=1)],
Annotated[int, Field(description="User ID", gt=3)],
] = Field(max_length=1)
assert openai_schema(User).model_json_schema() == User.model_json_schema()
class User(BaseModel):
name: str
age: int
class AdminUser(BaseModel):
organization: str
name: str
email: str
def test_new_union_types():
import sys
if sys.version_info >= (3, 10):
class Users(BaseModel):
users: list[AdminUser | User]
assert openai_schema(Users).model_json_schema() == Users.model_json_schema()
def test_old_union_type():
class UsersOldUnion(BaseModel):
users: list[Union[AdminUser, User]]
assert (
openai_schema(UsersOldUnion).model_json_schema()
== UsersOldUnion.model_json_schema()
)
def test_tuple_with_multiple_args():
class TupleModel(BaseModel):
coordinates: tuple[int, int, int]
name_and_age: tuple[str, int]
assert (
openai_schema(TupleModel).model_json_schema() == TupleModel.model_json_schema()
)
def test_dict_with_multiple_value_types():
from collections import OrderedDict
class DictModel(BaseModel):
regular_dict: dict[str, Union[int, str]]
ordered_dict: OrderedDict[str, Union[float, bool]]
assert openai_schema(DictModel).model_json_schema() == DictModel.model_json_schema()
def test_nested_complex_types():
class ComplexModel(BaseModel):
nested_tuple_dict: dict[str, tuple[int, str, bool]]
list_of_dicts: list[dict[str, Union[int, str]]]
assert (
openai_schema(ComplexModel).model_json_schema()
== ComplexModel.model_json_schema()
)
def test_openai_schema_tuple_mapping():
class TestModel(BaseModel):
field: tuple[str, int, int]
assert openai_schema(TestModel).model_json_schema() == TestModel.model_json_schema()
def test_openai_schema_dict_mapping():
class TestModel(BaseModel):
field: dict[str, str]
assert openai_schema(TestModel).model_json_schema() == TestModel.model_json_schema()
def test_openai_schema_ordered_dict_mapping():
class TestModel(BaseModel):
field: OrderedDict[str, int]
assert openai_schema(TestModel).model_json_schema() == TestModel.model_json_schema()
@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_openai_schema_supports_optional_none_310():
class DummyWithOptionalNone(BaseModel):
"""
Class with a single attribute that can be a string or None.
Validates support of UnionType in schema generation.
"""
attr: str | None
assert (
openai_schema(DummyWithOptionalNone).model_json_schema()
== DummyWithOptionalNone.model_json_schema()
)
def test_openai_schema_supports_optional_none() -> None:
class DummyWithOptionalNone(BaseModel):
"""
Class with a single attribute that can be a string or None.
Validates support of UnionType in schema generation.
"""
attr: Optional[str] # In python 3.10+ this is written as `attr: str | None`
attr2: Union[str, None]
assert (
openai_schema(DummyWithOptionalNone).model_json_schema()
== DummyWithOptionalNone.model_json_schema()
)
def test_default_values_and_validators():
class UserWithDefaults(BaseModel):
name: str = "John Doe"
age: int = Field(default=30, ge=0)
assert (
openai_schema(UserWithDefaults).model_json_schema()
== UserWithDefaults.model_json_schema()
)
def test_inheritance():
class BaseUser(BaseModel):
name: str
class ExtendedUser(BaseUser):
age: int
assert (
openai_schema(ExtendedUser).model_json_schema()
== ExtendedUser.model_json_schema()
)
def test_alias_and_field_customization():
class AliasModel(BaseModel):
actual_name: str = Field(..., alias="name")
age: int = Field(..., title="User Age", description="The age of the user")
assert (
openai_schema(AliasModel).model_json_schema() == AliasModel.model_json_schema()
)
def test_standard_python_types():
class StandardTypesModel(BaseModel):
timestamp: datetime
date_field: date
time_field: time
price: Decimal
unique_id: UUID
assert (
openai_schema(StandardTypesModel).model_json_schema()
== StandardTypesModel.model_json_schema()
)
def test_any_type():
class AnyTypeModel(BaseModel):
any_field: Any
assert (
openai_schema(AnyTypeModel).model_json_schema()
== AnyTypeModel.model_json_schema()
)
def test_literal_type():
class LiteralTypeModel(BaseModel):
status: Literal["active", "inactive", "pending"]
assert (
openai_schema(LiteralTypeModel).model_json_schema()
== LiteralTypeModel.model_json_schema()
)
def test_str_any_dict():
import sys
if sys.version_info >= (3, 10):
class ChatResponse(BaseModel):
action_data: dict[str, Any] | None = Field(
default=None,
description="The required data for the action that will be performed.",
)
content: str = Field(
description="A contextual response to the user's message."
)
else:
class ChatResponse(BaseModel):
action_data: Union[dict[str, Any], None] = Field(
default=None,
description="The required data for the action that will be performed.",
)
assert (
openai_schema(ChatResponse).model_json_schema()
== ChatResponse.model_json_schema()
)