Skip to content

Commit

Permalink
Merge pull request #166 from I-GUIDE/159-store-submission-type-in-discovery
Browse files Browse the repository at this point in the history

Store dataset source storage location in discovery collection
  • Loading branch information
pkdash authored Jun 8, 2024
2 parents b0fef37 + 6589b43 commit d0b360b
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
106 changes: 106 additions & 0 deletions api/models/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import List, Optional, TYPE_CHECKING
Expand All @@ -16,6 +17,56 @@ class SubmissionType(str, Enum):
IGUIDE_FORM = 'IGUIDE_FORM'


class StorageProvider(str, Enum):
    """Storage providers that can host a submission's dataset content.

    Each member's value is the human-readable name recorded in the
    discovery collection (note "Google Drive" contains a space, so the
    value is not always identical to the member name).
    """
    # NOTE(review): member naming mixes UPPER_CASE and PascalCase; renaming
    # members would break external references, so they are left as-is.
    AWS = "AWS"
    GCP = "GCP"
    Azure = "Azure"
    GoogleDrive = "Google Drive"
    Dropbox = "Dropbox"
    OneDrive = "OneDrive"
    Box = "Box"
    CUAHSI = "CUAHSI"


@dataclass
class ContentStorage:
    """Pairing of a URL hostname pattern with a storage provider's display name.

    Attributes:
        url_pattern: hostname fragment that identifies the provider
            (e.g. ``"amazonaws.com"``).
        storage_name: display name recorded in the discovery collection
            (e.g. ``"AWS"``, ``"Google Drive"``).
    """

    url_pattern: str
    storage_name: str

    @classmethod
    def get_storage(cls, storage_provider: "StorageProvider") -> Optional["ContentStorage"]:
        """Return the ``ContentStorage`` for *storage_provider*.

        Returns ``None`` for an unrecognized provider — same contract as the
        original if-chain, which fell through without returning.
        """
        # (url_pattern, storage_name) per supported provider; a mapping keeps
        # the data in one place instead of eight near-identical branches.
        known_storages = {
            StorageProvider.AWS: ("amazonaws.com", "AWS"),
            StorageProvider.GCP: ("storage.googleapis.com", "GCP"),
            StorageProvider.Azure: ("blob.core.windows.net", "Azure"),
            StorageProvider.GoogleDrive: ("drive.google.com", "Google Drive"),
            StorageProvider.Dropbox: ("dropbox.com", "Dropbox"),
            StorageProvider.OneDrive: ("onedrive.live.com", "OneDrive"),
            StorageProvider.Box: ("app.box.com", "Box"),
            StorageProvider.CUAHSI: ("minio.cuahsi.io", "CUAHSI"),
        }
        entry = known_storages.get(storage_provider)
        return cls(*entry) if entry is not None else None

    def get_storage_name(self, url: Optional[str], repository_identifier: Optional[str]) -> Optional[str]:
        """Return ``storage_name`` if either location string matches ``url_pattern``.

        The repository identifier is checked before the URL (preserving the
        original precedence); returns ``None`` when neither matches or both
        are empty/``None``.
        """
        for location in (repository_identifier, url):
            if location and self.url_pattern in location:
                return self.storage_name
        return None


class S3Path(BaseModel):
path: str
bucket: str
Expand Down Expand Up @@ -43,6 +94,61 @@ class Submission(Document):
repository_identifier: Optional[str]
s3_path: Optional[S3Path]

@property
def content_location(self):
    """Best-effort name of where this submission's content is stored.

    Returns a known provider name (e.g. ``"AWS"``, ``"Google Drive"``) when
    one can be inferred, otherwise falls back to the submission's
    repository type.
    """
    if self.repository == SubmissionType.HYDROSHARE:
        # HydroShare submissions are stored in HydroShare itself.
        return self.repository

    if self.repository == SubmissionType.S3:
        # Classify S3 submissions by the endpoint hostname: AWS proper
        # vs. the CUAHSI-hosted MinIO instance.
        # Defensive: s3_path is Optional — fall back instead of raising.
        if self.s3_path is not None:
            endpoint_url = self.s3_path.endpoint_url.rstrip("/")
            for provider in (StorageProvider.AWS, StorageProvider.CUAHSI):
                storage = ContentStorage.get_storage(provider)
                if endpoint_url.endswith(storage.url_pattern):
                    return storage.storage_name
        return self.repository

    # For everything else, classify by the URL or repository identifier
    # against each known hosted-storage provider, in a fixed order.
    for provider in (
        StorageProvider.GCP,
        StorageProvider.Azure,
        StorageProvider.GoogleDrive,
        StorageProvider.Dropbox,
        StorageProvider.OneDrive,
        StorageProvider.Box,
    ):
        storage = ContentStorage.get_storage(provider)
        storage_name = storage.get_storage_name(self.url, self.repository_identifier)
        if storage_name:
            return storage_name

    return self.repository


class User(Document):
access_token: str
Expand Down
30 changes: 29 additions & 1 deletion tests/test_dataset_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from api.adapters.utils import RepositoryType
from api.models.catalog import Submission
from api.models.user import SubmissionType, User
from api.models.user import SubmissionType, User, S3Path

pytestmark = pytest.mark.asyncio

Expand Down Expand Up @@ -102,6 +102,11 @@ async def test_create_dataset_s3(client_test, dataset_data, test_user_access_tok
assert len(submission_response_data) == 1
assert submission_response_data[0]['repository'] == RepositoryType.S3
assert submission_response_data[0]['s3_path'] == s3_path
if object_store_type == "minio":
content_location = "CUAHSI"
else:
content_location = "AWS"
await _check_s3_submission(test_user_access_token, s3_path, content_location)


@pytest.mark.parametrize('object_store_type', ['s3', 'minio'])
Expand Down Expand Up @@ -186,6 +191,11 @@ async def test_update_dataset_s3(client_test, dataset_data, test_user_access_tok
assert len(submission_response_data) == 1
assert submission_response_data[0]['repository'] == RepositoryType.S3
assert submission_response_data[0]['s3_path'] == s3_path
if object_store_type == "minio":
content_location = "CUAHSI"
else:
content_location = "AWS"
await _check_s3_submission(test_user_access_token, s3_path, content_location)


@pytest.mark.asyncio
Expand Down Expand Up @@ -531,3 +541,21 @@ async def _check_hs_submission(hs_dataset, user_access_token, hs_published_res_i
assert user.submission(submission_id) is not None
assert user.submission(submission_id).repository == "HYDROSHARE"
assert user.submission(submission_id).repository_identifier == hs_published_res_id
assert user.submission(submission_id).content_location == "HYDROSHARE"


async def _check_s3_submission(user_access_token, s3_path, content_location="AWS"):
    """Assert that exactly one S3 submission exists and matches *s3_path*.

    Args:
        user_access_token: token identifying the test user whose submissions
            are checked.
        s3_path: dict of S3Path fields the submission is expected to carry.
        content_location: expected derived content location ("AWS" for
            amazonaws endpoints, "CUAHSI" for the MinIO endpoint).
    """
    s3_path = S3Path(**s3_path)
    # there should be exactly one related submission record in the db
    submissions = await Submission.find().to_list()
    assert len(submissions) == 1
    user = await User.find_one(User.access_token == user_access_token, fetch_links=True)
    assert len(user.submissions) == 1
    submission = user.submissions[0]
    # the linked submission must be retrievable by its identifier
    # (the original also asserted submission_id == user.submissions[0].identifier,
    # which was a tautology and has been dropped)
    looked_up = user.submission(submission.identifier)
    assert looked_up is not None
    assert looked_up.repository == "S3"
    assert looked_up.s3_path == s3_path
    assert looked_up.repository_identifier == s3_path.identifier
    assert submission.content_location == content_location
3 changes: 3 additions & 0 deletions triggers/update_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ async def watch_catalog(db: AsyncIOMotorClient):
submission: Submission = await Submission.find_one({"identifier": document["_id"]})
catalog_entry["registrationDate"] = submission.submitted
catalog_entry["name_for_sorting"] = str.lower(catalog_entry["name"])
catalog_entry["submission_type"] = submission.repository
# location of the dataset files e.g. AWS,GCP, Azure, Hydroshare, CUAHSI, etc.
catalog_entry["content_location"] = submission.content_location
await db["discovery"].find_one_and_replace(
{"_id": document["_id"]}, catalog_entry, upsert=True
)
Expand Down

0 comments on commit d0b360b

Please sign in to comment.