Skip to content

Commit

Permalink
fix: Add rolling users validation for oncall shift API
Browse files Browse the repository at this point in the history
  • Loading branch information
ravishankar15 committed Sep 20, 2024
1 parent d9b1196 commit 3f6219b
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 3 deletions.
2 changes: 1 addition & 1 deletion engine/apps/api/serializers/on_call_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OnCallShiftSerializer(EagerLoadingMixin, serializers.ModelSerializer):
allow_null=True,
required=False,
child=UsersFilteredByOrganizationField(
queryset=User.objects, required=False, allow_null=True
queryset=User.objects, db_verification=True, required=False, allow_null=True
), # todo: filter by team?
)
updated_shift = serializers.CharField(read_only=True, allow_null=True, source="updated_shift.public_primary_key")
Expand Down
77 changes: 77 additions & 0 deletions engine/apps/api/tests/test_oncall_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,48 @@ def test_update_future_on_call_shift_removing_users(
assert response.data["rolling_users"][0] == "User(s) are required"


@pytest.mark.django_db
def test_update_on_call_shift_invalid_rolling_users(
on_call_shift_internal_api_setup,
make_on_call_shift,
make_user_auth_headers,
):
token, user1, _, _, schedule = on_call_shift_internal_api_setup

client = APIClient()
start_date = (timezone.now() + timezone.timedelta(days=1)).replace(microsecond=0)

name = "Test Shift Rotation"
on_call_shift = make_on_call_shift(
schedule.organization,
shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT,
schedule=schedule,
name=name,
start=start_date,
duration=timezone.timedelta(hours=1),
rotation_start=start_date,
rolling_users=[{user1.pk: user1.public_primary_key}],
)
data_to_update = {
"name": name,
"priority_level": 2,
"shift_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"shift_end": (start_date + timezone.timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%SZ"),
"rotation_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"until": None,
"frequency": None,
"interval": None,
"by_day": None,
"rolling_users": [["fuzz"]],
}

url = reverse("api-internal:oncall_shifts-detail", kwargs={"pk": on_call_shift.public_primary_key})
response = client.put(url, data=data_to_update, format="json", **make_user_auth_headers(user1, token))

assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"rolling_users": {"0": ["User does not exist fuzz"]}}


@pytest.mark.django_db
def test_update_started_on_call_shift(
on_call_shift_internal_api_setup,
Expand Down Expand Up @@ -1202,6 +1244,41 @@ def test_create_on_call_shift_invalid_data_rolling_users(
assert response.data["rolling_users"][0] == "Cannot set multiple user groups for non-recurrent shifts"


@pytest.mark.django_db
def test_create_on_call_shift_invalid_rolling_users(on_call_shift_internal_api_setup, make_user_auth_headers):
token, user1, user2, _, schedule = on_call_shift_internal_api_setup
client = APIClient()
url = reverse("api-internal:oncall_shifts-list")
start_date = timezone.now().replace(microsecond=0, tzinfo=None)

data = {
"name": "Test Shift",
"type": CustomOnCallShift.TYPE_ROLLING_USERS_EVENT,
"schedule": schedule.public_primary_key,
"priority_level": 1,
"shift_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"shift_end": (start_date + timezone.timedelta(hours=2)).strftime("%Y-%m-%dT%H:%M:%SZ"),
"rotation_start": start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
"until": None,
"frequency": 1,
"interval": 1,
"by_day": [
CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.MONDAY],
CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.FRIDAY],
],
"week_start": CustomOnCallShift.ICAL_WEEKDAY_MAP[CustomOnCallShift.MONDAY],
"rolling_users": [[user1.public_primary_key], [user2.public_primary_key, "fuzz"]],
}

with patch("apps.schedules.models.CustomOnCallShift.refresh_schedule") as mock_refresh_schedule:
response = client.post(url, data, format="json", **make_user_auth_headers(user1, token))

expected_payload = {"rolling_users": {"1": ["User does not exist fuzz"]}}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == expected_payload
mock_refresh_schedule.assert_not_called()


@pytest.mark.django_db
def test_create_on_call_shift_override_invalid_data(on_call_shift_internal_api_setup, make_user_auth_headers):
token, user1, _, _, schedule = on_call_shift_internal_api_setup
Expand Down
4 changes: 3 additions & 1 deletion engine/apps/public_api/serializers/on_call_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ class CustomOnCallShiftSerializer(EagerLoadingMixin, serializers.ModelSerializer
rolling_users = RollingUsersField(
allow_null=True,
required=False,
child=UsersFilteredByOrganizationField(queryset=User.objects, required=False, allow_null=True),
child=UsersFilteredByOrganizationField(
queryset=User.objects, db_verification=True, required=False, allow_null=True
),
)
rotation_start = serializers.DateTimeField(required=False)

Expand Down
61 changes: 61 additions & 0 deletions engine/apps/public_api/tests/test_on_call_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,38 @@ def test_create_on_call_shift_invalid_time_zone(make_organization_and_user_with_
assert response.json() == {"time_zone": ["Invalid timezone"]}


@pytest.mark.django_db
def test_create_on_call_shift_invalid_rolling_users(make_organization_and_user_with_token):
_, user, token = make_organization_and_user_with_token()
client = APIClient()

url = reverse("api-public:on_call_shifts-list")

start = timezone.now()
until = start + timezone.timedelta(days=30)
data = {
"team_id": None,
"name": "test name",
"type": "rolling_users",
"level": 1,
"start": start.strftime("%Y-%m-%dT%H:%M:%S"),
"rotation_start": start.strftime("%Y-%m-%dT%H:%M:%S"),
"duration": 10800,
"week_start": "MO",
"frequency": "weekly",
"interval": 2,
"until": until.strftime("%Y-%m-%dT%H:%M:%S"),
"by_day": ["MO", "WE", "FR"],
"time_zone": None,
"rolling_users": [[user.public_primary_key], ["fuzz"]],
}

response = client.post(url, data=data, format="json", HTTP_AUTHORIZATION=f"{token}")

assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"rolling_users": {"1": ["User does not exist fuzz"]}}


@pytest.mark.django_db
def test_update_on_call_shift(make_organization_and_user_with_token, make_on_call_shift, make_schedule):
organization, user, token = make_organization_and_user_with_token()
Expand Down Expand Up @@ -633,6 +665,35 @@ def test_update_on_call_shift_invalid_field(make_organization_and_user_with_toke
assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.django_db
def test_update_on_call_shift_invalid_rolling_users(make_organization_and_user_with_token, make_on_call_shift):
organization, user, token = make_organization_and_user_with_token()
client = APIClient()

start_date = timezone.now().replace(microsecond=0)
data = {
"start": start_date,
"rotation_start": start_date,
"duration": timezone.timedelta(seconds=7200),
"frequency": CustomOnCallShift.FREQUENCY_WEEKLY,
"interval": 2,
"by_day": ["MO", "FR"],
"rolling_users": [[user.public_primary_key]],
}

data_to_update = {"rolling_users": [[user.public_primary_key], ["fuzz"]]}

on_call_shift = make_on_call_shift(
organization=organization, shift_type=CustomOnCallShift.TYPE_ROLLING_USERS_EVENT, **data
)

url = reverse("api-public:on_call_shifts-detail", kwargs={"pk": on_call_shift.public_primary_key})

response = client.put(url, data=data_to_update, format="json", HTTP_AUTHORIZATION=f"{token}")

assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.django_db
def test_delete_on_call_shift(make_organization_and_user_with_token, make_on_call_shift):
organization, _, token = make_organization_and_user_with_token()
Expand Down
12 changes: 11 additions & 1 deletion engine/common/api_helpers/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class UsersFilteredByOrganizationField(serializers.Field):

def __init__(self, **kwargs):
self.queryset = kwargs.pop("queryset", None)
self.db_verification = kwargs.pop("db_verification", False)
super().__init__(**kwargs)

def to_representation(self, value):
Expand All @@ -102,7 +103,16 @@ def to_internal_value(self, data):
if not request or not queryset:
return None

return queryset.filter(organization=request.user.organization, public_primary_key__in=data).distinct()
users = queryset.filter(organization=request.user.organization, public_primary_key__in=data).distinct()
users_ppk = [u.public_primary_key for u in users]

if not self.db_verification:
return users

for d in data:
if d not in users_ppk:
raise ValidationError(f"User does not exist {d}")
return users


class IntegrationFilteredByOrganizationField(serializers.RelatedField):
Expand Down

0 comments on commit 3f6219b

Please sign in to comment.