Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/releases/2.3.0' into docs/custom…
Browse files Browse the repository at this point in the history
…-field

# Conflicts:
#	argilla-frontend/CHANGELOG.md
#	argilla-server/src/argilla_server/search_engine/commons.py
  • Loading branch information
davidberenstein1957 committed Oct 3, 2024
2 parents 2c4dd9b + 7549a45 commit c746743
Show file tree
Hide file tree
Showing 16 changed files with 214 additions and 232 deletions.
2 changes: 2 additions & 0 deletions argilla-frontend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ These are the section headers that we use:

## [Unreleased]()

## [2.3.0](https://github.com/argilla-io/argilla/compare/v2.2.0...v2.3.0)

### Added

- Added new field `CustomField` [#5462](https://github.com/argilla-io/argilla/pull/5462)
Expand Down
2 changes: 1 addition & 1 deletion argilla-frontend/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "argilla",
"version": "2.3.0dev0",
"version": "2.3.0",
"private": true,
"scripts": {
"dev": "nuxt",
Expand Down
7 changes: 7 additions & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ These are the section headers that we use:

## [Unreleased]()

## [2.3.0](https://github.com/argilla-io/argilla/compare/v2.2.0...v2.3.0)

### Added

- Added support for `CustomField`. ([#5422](https://github.com/argilla-io/argilla/pull/5422))
- Added helm chart for argilla. ([#5512](https://github.com/argilla-io/argilla/pull/5512))

### Fixed

- Fixed error when creating default user with existing default workspace. ([#5558](https://github.com/argilla-io/argilla/pull/5558))

## [2.2.0](https://github.com/argilla-io/argilla/compare/v2.1.0...v2.2.0)

- Added filtering by `name`, and `status` support to endpoint `GET /api/v1/me/datasets`. ([#5374](https://github.com/argilla-io/argilla/pull/5374))
Expand Down
2 changes: 1 addition & 1 deletion argilla-server/src/argilla_server/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# coding: utf-8
#

__version__ = "2.3.0dev0"
__version__ = "2.3.0"
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from argilla_server.database import AsyncSessionLocal
from argilla_server.models import User, UserRole, Workspace

from .utils import get_or_new_workspace


async def _create_default(api_key: str, password: str, quiet: bool):
"""Creates a user with default credentials on database suitable to start experimenting with argilla."""
Expand All @@ -37,7 +39,7 @@ async def _create_default(api_key: str, password: str, quiet: bool):
role=UserRole.owner,
api_key=api_key,
password_hash=accounts.hash_password(password),
workspaces=[Workspace(name=DEFAULT_USERNAME)],
workspaces=[await get_or_new_workspace(session, DEFAULT_USERNAME)],
)

if not quiet:
Expand Down
22 changes: 17 additions & 5 deletions argilla-server/src/argilla_server/search_engine/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,7 @@ def es_mapping_for_field(field: Field) -> dict:
elif field.is_custom:
return {
es_field_for_record_field(field.name): {
"type": "object",
"dynamic": True,
"properties": {},
"type": "text",
}
}
elif field.is_image:
Expand Down Expand Up @@ -532,17 +530,19 @@ def _inverse_vector(vector_value: List[float]) -> List[float]:
return [vector_value[i] * -1 for i in range(0, len(vector_value))]

def _map_record_to_es_document(self, record: Record) -> Dict[str, Any]:
dataset = record.dataset

document = {
"id": str(record.id),
"external_id": record.external_id,
"fields": record.fields,
"fields": self._map_record_fields_to_es(record.fields, dataset.fields),
"status": record.status,
"inserted_at": record.inserted_at,
"updated_at": record.updated_at,
}

if record.metadata_:
document["metadata"] = self._map_record_metadata_to_es(record.metadata_, record.dataset.metadata_properties)
document["metadata"] = self._map_record_metadata_to_es(record.metadata_, dataset.metadata_properties)
if record.responses:
document["responses"] = self._map_record_responses_to_es(record.responses)
if record.suggestions:
Expand Down Expand Up @@ -833,6 +833,18 @@ def _map_record_response_to_es(response: Response) -> Dict[str, Any]:
},
}

@classmethod
def _map_record_fields_to_es(cls, fields: dict, dataset_fields: List[Field]) -> dict:
    """Build the ``fields`` payload of the search-engine document for a record.

    The values are normalized per dataset field:

    * image fields are stored as ``None``;
    * custom fields are stored as ``str(value)`` (missing values become ``""``);
    * every other field keeps its value, defaulting to ``""`` when missing.

    Keys present in ``fields`` that do not match any dataset field are passed
    through unchanged.

    Args:
        fields: The record field values, keyed by field name.
        dataset_fields: The field definitions of the record's dataset.

    Returns:
        A new dict with the normalized values. ``fields`` itself is left
        untouched.
    """
    # Work on a shallow copy: the incoming dict is the record's own `fields`
    # value, and writing into it would blank out image values and stringify
    # custom values on the in-memory record as a side effect of indexing.
    es_fields = dict(fields)

    for field in dataset_fields:
        if field.is_image:
            es_fields[field.name] = None
        elif field.is_custom:
            es_fields[field.name] = str(fields.get(field.name, ""))
        else:
            es_fields[field.name] = fields.get(field.name, "")

    return es_fields

async def __terms_aggregation(self, index_name: str, field_name: str, query: dict, size: int) -> List[dict]:
aggregation_name = "terms_agg"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

from argilla_server.constants import DEFAULT_API_KEY, DEFAULT_PASSWORD, DEFAULT_USERNAME
from argilla_server.contexts import accounts
from argilla_server.models import User, UserRole
from argilla_server.models import User, UserRole, Workspace
from tests.factories import WorkspaceSyncFactory

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -87,3 +88,13 @@ def test_create_default_with_existent_default_user_and_quiet(sync_db: "Session",
assert result.exit_code == 0
assert result.output == ""
assert sync_db.query(User).count() == 1


def test_create_default_with_existent_default_workspace(sync_db: "Session", cli_runner: "CliRunner", cli: "Typer"):
    """`create_default` succeeds when a workspace named like the default user already exists."""
    # Pre-create the workspace the command would otherwise try to create itself.
    WorkspaceSyncFactory.create(name=DEFAULT_USERNAME)

    outcome = cli_runner.invoke(cli, "database users create_default")

    assert outcome.exit_code == 0
    assert outcome.output != ""

    stored_users = sync_db.query(User).count()
    assert stored_users == 1
42 changes: 38 additions & 4 deletions argilla-server/tests/unit/search_engine/test_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
VectorSettingsFactory,
ImageFieldFactory,
ChatFieldFactory,
CustomFieldFactory,
)


Expand Down Expand Up @@ -623,6 +624,34 @@ async def test_search_for_chat_field(self, search_engine: BaseElasticAndOpenSear
assert len(result.items) == 2
assert result.total == 2

async def test_search_for_custom_field(self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch):
    """Text search on a custom field matches values nested inside its dict payload."""
    custom_field = await CustomFieldFactory.create(name="field")
    dataset = await DatasetFactory.create(fields=[custom_field])

    # Every record carries the same nested payload: one string and one number.
    custom_value = {
        "a": "This is a value",
        "b": 100,
    }
    records = await RecordFactory.create_batch(
        size=2,
        dataset=dataset,
        fields={custom_field.name: custom_value},
    )

    await refresh_dataset(dataset)
    await refresh_records(records)

    await search_engine.create_index(dataset)
    await search_engine.index_records(dataset, records)

    # Both the string token and the numeric value must hit both records.
    for query in ("value", 100):
        text_query = TextQuery(q=query, field=custom_field.name)
        result = await search_engine.search(dataset, query=text_query)

        assert len(result.items) == 2
        assert result.total == 2

@pytest.mark.parametrize(
"statuses, expected_items",
[
Expand Down Expand Up @@ -1064,11 +1093,16 @@ async def test_index_records_with_metadata(
async def test_index_records_with_vectors(
self, search_engine: BaseElasticAndOpenSearchEngine, opensearch: OpenSearch
):
dataset = await DatasetFactory.create()
text_fields = await TextFieldFactory.create_batch(size=5, dataset=dataset)
vectors_settings = await VectorSettingsFactory.create_batch(size=5, dataset=dataset, dimensions=5)
text_fields = await TextFieldFactory.create_batch(size=5)
vectors_settings = await VectorSettingsFactory.create_batch(size=5, dimensions=5)

dataset = await DatasetFactory.create(fields=text_fields, vectors_settings=vectors_settings, questions=[])

records = await RecordFactory.create_batch(
size=5, fields={field.name: f"This is the value for {field.name}" for field in text_fields}, responses=[]
size=5,
fields={field.name: f"This is the value for {field.name}" for field in text_fields},
dataset=dataset,
responses=[],
)

for record in records:
Expand Down
6 changes: 6 additions & 0 deletions argilla/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ These are the section headers that we use:

## [Unreleased]()

## [2.3.0](https://github.com/argilla-io/argilla/compare/v2.2.2...v2.3.0)

### Added

- Added support for `CustomField`. ([#5422](https://github.com/argilla-io/argilla/pull/5422))
- Added `inserted_at` and `updated_at` to `Resource` model as properties. ([#5540](https://github.com/argilla-io/argilla/pull/5540))
- Added `limit` argument when fetching records. ([#5525](https://github.com/argilla-io/argilla/pull/5525))
- Added similarity search support. ([#5546](https://github.com/argilla-io/argilla/pull/5546))
Expand All @@ -26,9 +29,12 @@ These are the section headers that we use:
### Changed

- Changed the `__repr__` method for `SettingsProperties` to display the details of all the properties in `Setting` object. ([#5380](https://github.com/argilla-io/argilla/issues/5380))
- Changed error messages when creating datasets with insufficient permissions. ([#5554](https://github.com/argilla-io/argilla/pull/5554))

### Fixed

- Fixed the deployment yaml used to create a new Argilla server in K8s. Added `USERNAME` and `PASSWORD` to the environment variables of pod template. ([#5434](https://github.com/argilla-io/argilla/issues/5434))
- Fixed serialization of `ChatField` when collecting records from the hub and exporting to `datasets`. ([#5553](https://github.com/argilla-io/argilla/pull/5553))

## [2.2.2](https://github.com/argilla-io/argilla/compare/v2.2.1...v2.2.2)

Expand Down
2 changes: 1 addition & 1 deletion argilla/src/argilla/_exceptions/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class BadRequestError(ArgillaAPIError):


class ForbiddenError(ArgillaAPIError):
message = "Forbidden request to the server"
message = "User role is forbidden from performing this action by server"


class NotFoundError(ArgillaAPIError):
Expand Down
2 changes: 1 addition & 1 deletion argilla/src/argilla/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2.3.0dev0"
__version__ = "2.3.0"
21 changes: 19 additions & 2 deletions argilla/src/argilla/datasets/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
from typing import Optional, Union
from uuid import UUID, uuid4

try:
from typing import Self
except ImportError:
from typing_extensions import Self

from argilla._api import DatasetsAPI
from argilla._exceptions import NotFoundError, SettingsError
from argilla._exceptions import NotFoundError, SettingsError, ForbiddenError
from argilla._models import DatasetModel
from argilla._resource import Resource
from argilla.client import Argilla
Expand Down Expand Up @@ -157,7 +162,16 @@ def create(self) -> "Dataset":
Returns:
Dataset: The created dataset object.
"""
super().create()
try:
super().create()
except ForbiddenError as e:
settings_url = f"{self._client.api_url}/user-settings"
user_role = self._client.me.role.value
user_name = self._client.me.username
workspace_name = self.workspace.name
message = f"""User '{user_name}' is not authorized to create a dataset in workspace '{workspace_name}'
with role '{user_role}'. Go to {settings_url} to view your role."""
raise ForbiddenError(message) from e
try:
return self._publish()
except Exception as e:
Expand Down Expand Up @@ -277,3 +291,6 @@ def _sanitize_name(cls, name: str):
for character in ["/", "\\", ".", ",", ";", ":", "-", "+", "="]:
name = name.replace(character, "-")
return name

def _with_client(self, client: Argilla) -> "Self":
    """Delegate to the parent ``_with_client`` with ``client`` and return its result.

    NOTE(review): the override adds no logic of its own; presumably it exists
    only to narrow the annotated return type to ``Self`` — confirm.
    """
    bound = super()._with_client(client=client)
    return bound
18 changes: 17 additions & 1 deletion argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,18 @@ def __init__(self, record: Record, fields: Optional[Dict[str, FieldValue]] = Non
self.record = record

def to_dict(self) -> dict:
return {key: cast_image(value) if self._is_image(key) else value for key, value in self.items()}
fields = {}

for key, value in self.items():
if value is None:
continue
elif self._is_image(key):
fields[key] = cast_image(value)
elif self._is_chat(key):
fields[key] = [message.model_dump() if not isinstance(message, dict) else message for message in value]
else:
fields[key] = value
return fields

def __getitem__(self, key: str) -> FieldValue:
value = super().__getitem__(key)
Expand All @@ -311,6 +322,11 @@ def _is_image(self, key: str) -> bool:
return False
return self.record.dataset.settings.schema[key].type == "image"

def _is_chat(self, key: str) -> bool:
if not self.record.dataset:
return False
return self.record.dataset.settings.schema[key].type == "chat"


class RecordMetadata(dict):
"""This is a container class for the metadata of a Record."""
Expand Down
19 changes: 19 additions & 0 deletions argilla/tests/integration/test_export_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def dataset(client) -> rg.Dataset:
fields=[
rg.TextField(name="text"),
rg.ImageField(name="image"),
rg.ChatField(name="chat"),
],
questions=[
rg.LabelQuestion(name="label", labels=["positive", "negative"]),
Expand All @@ -58,18 +59,36 @@ def mock_data() -> List[dict[str, Any]]:
{
"text": "Hello World, how are you?",
"image": "http://mock.url/image",
"chat": [
{
"role": "user",
"content": "Hello World, how are you?",
}
],
"label": "positive",
"id": uuid.uuid4(),
},
{
"text": "Hello World, how are you?",
"image": "http://mock.url/image",
"chat": [
{
"role": "user",
"content": "Hello World, how are you?",
}
],
"label": "negative",
"id": uuid.uuid4(),
},
{
"text": "Hello World, how are you?",
"image": "http://mock.url/image",
"chat": [
{
"role": "user",
"content": "Hello World, how are you?",
}
],
"label": "positive",
"id": uuid.uuid4(),
},
Expand Down
Loading

0 comments on commit c746743

Please sign in to comment.