diff --git a/.run/batch_document_ocr.run.xml b/.run/batch_document_ocr.run.xml
index b2fe5cd7..851d42a6 100644
--- a/.run/batch_document_ocr.run.xml
+++ b/.run/batch_document_ocr.run.xml
@@ -14,7 +14,7 @@
-
+
diff --git a/.run/marie gateway.run.xml b/.run/marie gateway - params.run.xml
similarity index 78%
rename from .run/marie gateway.run.xml
rename to .run/marie gateway - params.run.xml
index bf379305..8040840f 100644
--- a/.run/marie gateway.run.xml
+++ b/.run/marie gateway - params.run.xml
@@ -1,6 +1,7 @@
-
+
+
@@ -13,7 +14,7 @@
-
+
diff --git a/.run/marie gateway-FILE.run.xml b/.run/marie gateway-FILE.run.xml
new file mode 100644
index 00000000..79a15dee
--- /dev/null
+++ b/.run/marie gateway-FILE.run.xml
@@ -0,0 +1,25 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README-GB.md b/README-GB.md
index c42c3c5e..1ba8cef2 100644
--- a/README-GB.md
+++ b/README-GB.md
@@ -476,6 +476,9 @@ https://github.com/fioresxcat/VAT_245/tree/fa526ac7e2ce9bb392ca66bd86305d69caee7
# Table Transformer and Table Detection
https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
+# IDEAS
+dedoc
+https://github.com/ispras/dedoc/blob/master/dedoc/structure_constructors/abstract_structure_constructor.py
https://cloud.google.com/document-ai
@@ -505,4 +508,20 @@ event.json
act.secrets
```
MARIE_CORE_RELEASE_TOKEN=ghp_ABC
-```
\ No newline at end of file
+```
+
+## Pydantic
+```bash
+pydantic 1.10.15
+pydantic_core 2.10.1
+```
+
+
+
+
+
+# Rewriting history
+
+```bash
+ git filter-repo --mailmap mailmap --force
+```
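+
+For reference, each line of the `mailmap` file maps an old commit identity to the preferred one, e.g. `Proper Name <proper@example.com> <old@example.com>` (hypothetical addresses; one mapping per line).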
diff --git a/README.md b/README.md
index 9e1d4ec9..cfe938a5 100644
--- a/README.md
+++ b/README.md
@@ -127,6 +127,13 @@ aws s3 cp some_file.txt s3://mybucket --profile marie --endpoint-url http://loc
aws s3 --profile marie --endpoint-url=http://127.0.0.1:8000 ls --recursive s3://
```
+Remove files from the bucket
+```shell
+aws s3 rm s3://marie --recursive --profile marie --endpoint-url http://localhost:8000
+```
+
+
+
# Production setup
diff --git a/config/service/bones.yml b/config/service/bones.yml
new file mode 100644
index 00000000..74750304
--- /dev/null
+++ b/config/service/bones.yml
@@ -0,0 +1,278 @@
+jtype: Flow
+version: '1'
+protocol: grpc
+
+# Shared configuration
+shared_config:
+ storage: &storage
+ psql: &psql_conf_shared
+ provider: postgresql
+ hostname: 127.0.0.1
+ port: 5432
+ username: postgres
+ password: 123456
+ database: postgres
+ default_table: shared_docs
+
+ message: &message
+ amazon_mq : &amazon_mq_conf_shared
+ provider: amazon-rabbitmq
+ hostname: ${{ ENV.AWS_MQ_HOSTNAME }}
+ port: 15672
+ username: ${{ ENV.AWS_MQ_USERNAME }}
+ password: ${{ ENV.AWS_MQ_PASSWORD }}
+ tls: True
+ virtualhost: /
+
+
+ rabbitmq : &rabbitmq_conf_shared
+ provider: rabbitmq
+ hostname: ${{ ENV.RABBIT_MQ_HOSTNAME }}
+ port: ${{ ENV.RABBIT_MQ_PORT }}
+ username: ${{ ENV.RABBIT_MQ_USERNAME }}
+ password: ${{ ENV.RABBIT_MQ_PASSWORD }}
+ tls: False
+ virtualhost: /
+
+
+# Toast event tracking system
+# It can be backed by a message queue and/or a database
+toast:
+ native:
+ enabled: True
+ path: /tmp/marie/events.json
+ rabbitmq:
+ <<: *rabbitmq_conf_shared
+ enabled : False
+ psql:
+ <<: *psql_conf_shared
+ default_table: event_tracking
+ enabled : False
+
+
+# Document Storage
+# The storage service is used to store the data that is being processed
+# Storage can be backed by an S3-compatible object store or PostgreSQL
+
+storage:
+ # S3 configuration. Will be used only if value of backend is "s3"
+ s3:
+ enabled: False
+ metadata_only: False # If True, only metadata will be stored in the storage backend
+ # api endpoint to connect to. use AWS S3 or any S3 compatible object storage endpoint.
+ endpoint_url: ${{ ENV.S3_ENDPOINT_URL }}
+ # optional.
+ # access key id when using static credentials.
+ access_key_id: ${{ ENV.S3_ACCESS_KEY_ID }}
+ # optional.
+ # secret key when using static credentials.
+ secret_access_key: ${{ ENV.S3_SECRET_ACCESS_KEY }}
+ # Bucket name in s3
+ bucket_name: ${{ ENV.S3_BUCKET_NAME }}
+ # optional.
+ # Example: "region: us-east-2"
+ region: ${{ ENV.S3_REGION }}
+ # optional.
+ # enable if endpoint is http
+ insecure: True
+ # optional.
+ # enable if you want to use path style requests
+ addressing_style: path
+
+ # postgresql configuration. Will be used only if value of backend is "psql"
+ psql:
+ <<: *psql_conf_shared
+ default_table: store_metadata
+ enabled : False
+
+# Job Queue scheduler
+scheduler:
+ psql:
+ <<: *psql_conf_shared
+ default_table: job_queue
+ enabled : True
+
+# FLOW / GATEWAY configuration
+
+with:
+ port:
+ - 51000
+ - 52000
+ protocol:
+ - http
+ - grpc
+ discovery: True
+ discovery_host: 127.0.0.1
+ discovery_port: 8500
+
+ host: 127.0.0.1
+
+ # monitoring
+ monitoring: true
+ port_monitoring: 57843
+
+ event_tracking: True
+
+ expose_endpoints:
+ /document/extract:
+ methods: ["POST"]
+ summary: Extract data-POC
+ tags:
+ - extract
+ /status:
+ methods: ["POST"]
+ summary: Status
+ tags:
+ - extract
+
+ /text/status:
+ methods: ["POST"]
+ summary: Extract data
+ tags:
+ - extract
+
+ /ner/extract:
+ methods: ["POST"]
+ summary: Extract NER
+ tags:
+ - ner
+
+ /document/classify:
+ methods: ["POST"]
+ summary: Classify document at page level
+ tags:
+ - classify
+
+prefetch: 4
+
+executors:
+# - name: extract_executor
+# uses:
+# jtype: TextExtractionExecutorMock
+# metas:
+# py_modules:
+# - marie.executor.text
+# timeout_ready: 3000000
+# replicas: 1
+## replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+ - name: extract_t
+ uses:
+ jtype: TextExtractionExecutor
+# jtype: TextExtractionExecutorMock
+ with:
+ storage:
+ # postgresql configuration. Will be used only if value of backend is "psql"
+ psql:
+ <<: *psql_conf_shared
+ default_table: extract_metadata
+ enabled: True
+ pipeline:
+ name: 'default'
+ page_classifier:
+ - model_name_or_path: 'marie/lmv3-medical-document-classification'
+ type: 'transformers'
+ device: 'cuda'
+ enabled: False
+ name: 'medical_page_classifier'
+ - model_name_or_path: 'marie/lmv3-medical-document-payer'
+ type: 'transformers'
+ enabled: False
+ device: 'cuda'
+ name: 'medical_payer_classifier'
+ page_indexer:
+ - model_name_or_path: 'marie/layoutlmv3-medical-document-indexer'
+ enabled: False
+ type: 'transformers'
+ device: 'cuda'
+ name: 'page_indexer_patient'
+ filter:
+ type: 'regex'
+ pattern: '.*'
+ page_splitter:
+ model_name_or_path: 'marie/layoutlmv3-medical-document-splitter'
+ enabled: True
+ metas:
+ py_modules:
+ - marie.executor.text
+ timeout_ready: 3000000
+ replicas: 1
+ # replicas: ${{ CONTEXT.gpu_device_count }}
+ env:
+ CUDA_VISIBLE_DEVICES: RR
+
+# - name: extract_xyz
+# uses:
+# jtype: TextExtractionExecutorMock
+# metas:
+# py_modules:
+# - marie.executor.text
+# timeout_ready: 3000000
+# replicas: 1
+## replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+# - name: ner_t
+# uses:
+# jtype: NerExtractionExecutor
+# with:
+# model_name_or_path : 'rms/layoutlmv3-large-corr-ner'
+# <<: *psql_conf_shared
+# storage_enabled : False
+# metas:
+# py_modules:
+## - marie_server.executors.ner.mserve_torch
+# - marie.executor.ner
+# timeout_ready: 3000000
+## replicas: 1
+# replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+# - name: document_classifier
+# uses:
+# jtype: DocumentClassificationExecutor
+# with:
+# model_name_or_path :
+# - 'marie/layoutlmv3-document-classification'
+# - 'marie/layoutlmv3-document-classification'
+# <<: *psql_conf_shared
+# storage_enabled : False
+# metas:
+# py_modules:
+# - marie.executor.classifier
+# timeout_ready: 3000000
+## replicas: 1
+# replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+##
+# - name: overlay_t
+# uses:
+# jtype: OverlayExecutor
+# with:
+# model_name_or_path : 'rms/holder'
+# <<: *storage_conf
+# storage_enabled : True
+# metas:
+# py_modules:
+# - marie.executor.overlay
+# timeout_ready: 3000000
+# replicas: 1
+
+# Authentication and Authorization configuration
+
+auth:
+ keys:
+ - name : service-A
+ api_key : mas_0aPJ9Q9nUO1Ac1vJTfffXEXs9FyGLf9BzfYgZ_RaHm707wmbfHJNPQ
+ enabled : True
+ roles : [admin, user]
+
+ - name : service-B
+ api_key : mau_t6qDi1BcL1NkLI8I6iM8z1va0nZP01UQ6LWecpbDz6mbxWgIIIZPfQ
+ enabled : True
+ roles : [admin, user]
diff --git a/config/service/deployment.yml b/config/service/deployment.yml
index 7710894c..5e22979d 100644
--- a/config/service/deployment.yml
+++ b/config/service/deployment.yml
@@ -5,11 +5,15 @@ with:
metas:
py_modules:
- marie.executor.text
+ replicas: 1
+ name: ext_dep
+ protocol: [grpc]
- replicas: 3
- name: extract_exec
- protocol: [grpc, http]
-# port: [54321, 54322, 54323]
-# port: [51000, 52000, 53000]
-# discovery: True
-# discovery_host:
\ No newline at end of file
+# # monitoring
+# monitoring: true
+# port_monitoring: 57843
+
+# host: 0.0.0.0
+# port: [52000]
+
+ #port: [52000, 51000]
diff --git a/config/service/gateway.yml b/config/service/gateway.yml
new file mode 100644
index 00000000..1e57ce25
--- /dev/null
+++ b/config/service/gateway.yml
@@ -0,0 +1,16 @@
+!HTTPGateway
+gateway:
+ title: "Marie Gateway"
+ cors: true
+ protocol:
+ - http
+ - grpc
+
+ # monitoring
+ monitoring: true
+ port_monitoring: 57843
+
+ host: 0.0.0.0
+ # host: 192.168.1.11
+ port: [52000]
+
diff --git a/config/service/marie-dev.yml b/config/service/marie-dev.yml
new file mode 100644
index 00000000..4b62279b
--- /dev/null
+++ b/config/service/marie-dev.yml
@@ -0,0 +1,277 @@
+jtype: Flow
+version: '1'
+protocol: grpc
+
+# Shared configuration
+shared_config:
+ storage: &storage
+ psql: &psql_conf_shared
+ provider: postgresql
+ hostname: 127.0.0.1
+ port: 5432
+ username: postgres
+ password: 123456
+ database: postgres
+ default_table: shared_docs
+
+ message: &message
+ amazon_mq : &amazon_mq_conf_shared
+ provider: amazon-rabbitmq
+ hostname: ${{ ENV.AWS_MQ_HOSTNAME }}
+ port: 15672
+ username: ${{ ENV.AWS_MQ_USERNAME }}
+ password: ${{ ENV.AWS_MQ_PASSWORD }}
+ tls: True
+ virtualhost: /
+
+
+ rabbitmq : &rabbitmq_conf_shared
+ provider: rabbitmq
+ hostname: ${{ ENV.RABBIT_MQ_HOSTNAME }}
+ port: ${{ ENV.RABBIT_MQ_PORT }}
+ username: ${{ ENV.RABBIT_MQ_USERNAME }}
+ password: ${{ ENV.RABBIT_MQ_PASSWORD }}
+ tls: False
+ virtualhost: /
+
+
+# Toast event tracking system
+# It can be backed by a message queue and/or a database
+toast:
+ native:
+ enabled: True
+ path: /tmp/marie/events.json
+ rabbitmq:
+ <<: *rabbitmq_conf_shared
+ enabled : True
+ psql:
+ <<: *psql_conf_shared
+ default_table: event_tracking
+ enabled : True
+
+# Document Storage
+# The storage service is used to store the data that is being processed
+# Storage can be backed by an S3-compatible object store or PostgreSQL
+
+storage:
+ # S3 configuration. Will be used only if value of backend is "s3"
+ s3:
+ enabled: True
+ metadata_only: False # If True, only metadata will be stored in the storage backend
+ # api endpoint to connect to. use AWS S3 or any S3 compatible object storage endpoint.
+ endpoint_url: ${{ ENV.S3_ENDPOINT_URL }}
+ # optional.
+ # access key id when using static credentials.
+ access_key_id: ${{ ENV.S3_ACCESS_KEY_ID }}
+ # optional.
+ # secret key when using static credentials.
+ secret_access_key: ${{ ENV.S3_SECRET_ACCESS_KEY }}
+ # Bucket name in s3
+ bucket_name: ${{ ENV.S3_BUCKET_NAME }}
+ # optional.
+ # Example: "region: us-east-2"
+ region: ${{ ENV.S3_REGION }}
+ # optional.
+ # enable if endpoint is http
+ insecure: True
+ # optional.
+ # enable if you want to use path style requests
+ addressing_style: path
+
+ # postgresql configuration. Will be used only if value of backend is "psql"
+ psql:
+ <<: *psql_conf_shared
+ default_table: store_metadata
+ enabled : False
+
+# Job Queue scheduler
+scheduler:
+ psql:
+ <<: *psql_conf_shared
+ default_table: job_queue
+ enabled : True
+
+# FLOW / GATEWAY configuration
+
+with:
+ port:
+ - 51000
+ - 52000
+ protocol:
+ - http
+ - grpc
+ discovery: True
+ discovery_host: 127.0.0.1
+ discovery_port: 8500
+
+ host: 127.0.0.1
+
+ # monitoring
+ monitoring: true
+ port_monitoring: 57843
+
+ event_tracking: True
+
+ expose_endpoints:
+ /document/extract:
+ methods: ["POST"]
+ summary: Extract data-POC
+ tags:
+ - extract
+ /status:
+ methods: ["POST"]
+ summary: Status
+ tags:
+ - extract
+
+ /text/status:
+ methods: ["POST"]
+ summary: Extract data
+ tags:
+ - extract
+
+ /ner/extract:
+ methods: ["POST"]
+ summary: Extract NER
+ tags:
+ - ner
+
+ /document/classify:
+ methods: ["POST"]
+ summary: Classify document at page level
+ tags:
+ - classify
+
+prefetch: 4
+
+executors:
+# - name: extract_executor
+# uses:
+# jtype: TextExtractionExecutorMock
+# metas:
+# py_modules:
+# - marie.executor.text
+# timeout_ready: 3000000
+# replicas: 1
+## replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+ - name: extract_t
+ uses:
+# jtype: TextExtractionExecutor
+ jtype: TextExtractionExecutorMock
+ with:
+ storage:
+ # postgresql configuration. Will be used only if value of backend is "psql"
+ psql:
+ <<: *psql_conf_shared
+ default_table: extract_metadata
+ enabled: True
+ pipeline:
+ name: 'default'
+ page_classifier:
+ - model_name_or_path: 'marie/lmv3-medical-document-classification'
+ type: 'transformers'
+ device: 'cuda'
+ enabled: False
+ name: 'medical_page_classifier'
+ - model_name_or_path: 'marie/lmv3-medical-document-payer'
+ type: 'transformers'
+ enabled: False
+ device: 'cuda'
+ name: 'medical_payer_classifier'
+ page_indexer:
+ - model_name_or_path: 'marie/layoutlmv3-medical-document-indexer'
+ enabled: False
+ type: 'transformers'
+ device: 'cuda'
+ name: 'page_indexer_patient'
+ filter:
+ type: 'regex'
+ pattern: '.*'
+ page_splitter:
+ model_name_or_path: 'marie/layoutlmv3-medical-document-splitter'
+ enabled: True
+ metas:
+ py_modules:
+ - marie.executor.text
+ timeout_ready: 3000000
+ replicas: 4
+ # replicas: ${{ CONTEXT.gpu_device_count }}
+ env:
+ CUDA_VISIBLE_DEVICES: RR
+
+# - name: extract_xyz
+# uses:
+# jtype: TextExtractionExecutorMock
+# metas:
+# py_modules:
+# - marie.executor.text
+# timeout_ready: 3000000
+# replicas: 1
+## replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+# - name: ner_t
+# uses:
+# jtype: NerExtractionExecutor
+# with:
+# model_name_or_path : 'rms/layoutlmv3-large-corr-ner'
+# <<: *psql_conf_shared
+# storage_enabled : False
+# metas:
+# py_modules:
+## - marie_server.executors.ner.mserve_torch
+# - marie.executor.ner
+# timeout_ready: 3000000
+## replicas: 1
+# replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+
+# - name: document_classifier
+# uses:
+# jtype: DocumentClassificationExecutor
+# with:
+# model_name_or_path :
+# - 'marie/layoutlmv3-document-classification'
+# - 'marie/layoutlmv3-document-classification'
+# <<: *psql_conf_shared
+# storage_enabled : False
+# metas:
+# py_modules:
+# - marie.executor.classifier
+# timeout_ready: 3000000
+## replicas: 1
+# replicas: ${{ CONTEXT.gpu_device_count }}
+# env :
+# CUDA_VISIBLE_DEVICES: RR
+##
+# - name: overlay_t
+# uses:
+# jtype: OverlayExecutor
+# with:
+# model_name_or_path : 'rms/holder'
+# <<: *storage_conf
+# storage_enabled : True
+# metas:
+# py_modules:
+# - marie.executor.overlay
+# timeout_ready: 3000000
+# replicas: 1
+
+# Authentication and Authorization configuration
+
+auth:
+ keys:
+ - name : service-A
+ api_key : mas_0aPJ9Q9nUO1Ac1vJTfffXEXs9FyGLf9BzfYgZ_RaHm707wmbfHJNPQ
+ enabled : True
+ roles : [admin, user]
+
+ - name : service-B
+ api_key : mau_t6qDi1BcL1NkLI8I6iM8z1va0nZP01UQ6LWecpbDz6mbxWgIIIZPfQ
+ enabled : True
+ roles : [admin, user]
diff --git a/config/service/marie.yml b/config/service/marie.yml
index 43661cb7..de42db83 100644
--- a/config/service/marie.yml
+++ b/config/service/marie.yml
@@ -142,7 +142,7 @@ with:
tags:
- classify
-prefetch: 4
+prefetch: 1
executors:
# - name: extract_executor
@@ -159,8 +159,8 @@ executors:
- name: extract_t
uses:
- jtype: TextExtractionExecutor
-# jtype: TextExtractionExecutorMock
+# jtype: TextExtractionExecutor
+ jtype: TextExtractionExecutorMock
with:
storage:
# postgresql configuration. Will be used only if value of backend is "psql"
diff --git a/config/zoo/unilm/dit/object_detection/document_boundary/maskrcnn/maskrcnn_dit_base.yaml b/config/zoo/unilm/dit/object_detection/document_boundary/maskrcnn/maskrcnn_dit_base.yaml
index f2b69ba0..53017f2a 100644
--- a/config/zoo/unilm/dit/object_detection/document_boundary/maskrcnn/maskrcnn_dit_base.yaml
+++ b/config/zoo/unilm/dit/object_detection/document_boundary/maskrcnn/maskrcnn_dit_base.yaml
@@ -3,7 +3,8 @@ MODEL:
PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
PIXEL_STD: [ 127.5, 127.5, 127.5 ]
# WEIGHTS: "https://layoutlm.blob.core.windows.net/dit/dit-pts/dit-base-224-p16-500k-62d53a.pth"
- WEIGHTS: "unilm/dit/object_detection/document_boundary/model_final.pth" # 040102024
+# WEIGHTS: "unilm/dit/object_detection/document_boundary/model_final.pth" # 040102024
+ WEIGHTS: "unilm/dit/object_detection/document_boundary/model_0039999.pth" # 040102024
VIT:
NAME: "dit_base_patch16"
diff --git a/config/zoo/unilm/dit/text_detection/mask_rcnn_dit_prod.yaml b/config/zoo/unilm/dit/text_detection/mask_rcnn_dit_prod.yaml
index 8389b0e4..e90e67ed 100644
--- a/config/zoo/unilm/dit/text_detection/mask_rcnn_dit_prod.yaml
+++ b/config/zoo/unilm/dit/text_detection/mask_rcnn_dit_prod.yaml
@@ -8,7 +8,8 @@ MODEL:
# WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE/model_0095999.pth" # GOOD- IN PROD AS OF 04/08/2024
# WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE-04082024/model_0109999.pth" # GOOD
# WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE-04082024/model_0145999.pth" # 04-11-2024
- WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE-04082024/model_0155999.pth" # 04-16-2024
+# WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE-04082024/model_0155999.pth" # 04-16-2024
+ WEIGHTS: "unilm/dit/text_detection/tuned-4000-LARGE-05302024/model_0147999.pth" # 05-30-2024
VIT:
NAME: "dit_large_patch16"
OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
diff --git a/docs/docs/guides/service-discovery.md b/docs/docs/guides/service-discovery.md
new file mode 100644
index 00000000..fb7c2034
--- /dev/null
+++ b/docs/docs/guides/service-discovery.md
@@ -0,0 +1,60 @@
+# Service Discovery
+
+Service discovery is a mechanism that allows services to find and communicate with each other.
+This guide shows how to use the `EtcdServiceResolver` to resolve services registered in etcd.
+
+## Install etcd
+
+```bash
+docker run -d -p 2379:2379 --name etcd \
+-v /usr/share/ca-certificates/:/etc/ssl/certs \
+quay.io/coreos/etcd:v3.5.14 /usr/local/bin/etcd -advertise-client-urls \
+http://0.0.0.0:2379 -listen-client-urls http://0.0.0.0:2379
+```
+
+Verify the installation by running the following command:
+
+```bash
+docker exec -it etcd etcdctl version
+```
+
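+Write a couple of test keys and read them back: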
+```bash
+docker exec -it etcd etcdctl put 'hello' 'value-1'
+docker exec -it etcd etcdctl put 'world' 'value-2'
+
+docker exec -it etcd etcdctl get "" --prefix=true
+docker exec -it etcd etcdctl get "" --from-key
+```
+
+
+## Purge the etcd data
+
+```bash
+docker exec -it etcd etcdctl del "" --from-key=true
+```
+
+
+## Install the `etcd3` package
+
+Install the `etcd3` package from source, as the `GRPC` version of `marie` is not compatible with the current release of the `etcd3` package.
+
+```bash
+git clone git@github.com:kragniz/python-etcd3.git
+cd python-etcd3
+python setup.py install
+pip show etcd3
+```
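+
+A minimal smoke test of the installed client, assuming etcd is reachable on `127.0.0.1:2379` as set up above:
+
+```python
+import etcd3
+
+# connect to the local, non-TLS etcd started earlier
+client = etcd3.client(host="127.0.0.1", port=2379)
+
+# write a key and read it back; get() returns a (value, metadata) tuple
+client.put("hello", "value-1")
+value, _ = client.get("hello")
+print(value)  # b'value-1'
+```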
+
+## Test the registry and resolver
+
+Start the resolver and the registry services.
+
+```bash
+python ./marie/serve/discovery/resolver.py --port 2379 --host 0.0.0.0 --service-key gateway/service_test
+```
+
+```bash
+python ./registry.py --port 2379 --host 0.0.0.0 --service-key gateway/service_test --service-addr 127.0.0.1:5001 --my-id service001
+```
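+
+Conceptually, the registry publishes the service address under the service key with a TTL lease, and the resolver reads every entry under that key. A rough sketch using the plain `etcd3` client (illustrative only; the actual registry/resolver classes in `marie.serve.discovery` may expose a different API):
+
+```python
+import etcd3
+
+client = etcd3.client(host="127.0.0.1", port=2379)
+
+# "register": publish the address under the service key, bound to a 10s lease
+lease = client.lease(ttl=10)
+client.put("gateway/service_test/service001", "127.0.0.1:5001", lease=lease)
+
+# "resolve": enumerate every address registered under the service key
+for value, meta in client.get_prefix("gateway/service_test"):
+    print(meta.key.decode(), "->", value.decode())
+```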
+
+
diff --git a/examples/batch_document_ocr.py b/examples/batch_document_ocr.py
index 7d772c4b..cc134e70 100644
--- a/examples/batch_document_ocr.py
+++ b/examples/batch_document_ocr.py
@@ -99,10 +99,22 @@ def process_request(
"page_cleaner": {
"enabled": False,
},
+ "page_boundary": {
+ "enabled": False,
+ },
+ "template_matching": {
+ "enabled": False,
+ "definition_id": "120791",
+ },
}
],
}
+ # Hostname: 'GEXT-04'
+ # VirtualHost: '/'
+ # Port: '5672'
+ # TLS: 'false'
+ # docker run -d --name extract-rabbitmq rabbitmq:3-management-alpine
# Upload file to api
logger.info(f"Uploading to marie-ai for processing : {file}")
diff --git a/marie/boxes/dit/ulim_dit_box_processor.py b/marie/boxes/dit/ulim_dit_box_processor.py
index a4fc7bb8..8f284890 100644
--- a/marie/boxes/dit/ulim_dit_box_processor.py
+++ b/marie/boxes/dit/ulim_dit_box_processor.py
@@ -76,9 +76,11 @@ def _convert_boxes(boxes):
def visualize_bboxes(
- image: Union[np.ndarray, PIL.Image.Image], bboxes: np.ndarray, format="xyxy",
- blackout=False,
- blackout_color=(0, 0, 0, 255),
+ image: Union[np.ndarray, PIL.Image.Image],
+ bboxes: np.ndarray,
+ format="xyxy",
+ blackout=False,
+ blackout_color=(0, 0, 0, 255),
) -> PIL.Image:
"""Visualize bounding boxes on the image#
Args:
@@ -104,7 +106,7 @@ def visualize_bboxes(
if format == "xywh":
box = [box[0], box[1], box[0] + box[2], box[1] + box[3]]
- fill_color = (
+ fill_color_rgbaa = (
int(np.random.random() * 256),
int(np.random.random() * 256),
int(np.random.random() * 256),
@@ -112,13 +114,13 @@ def visualize_bboxes(
)
width = 1
if blackout:
- fill_color = blackout_color
+ fill_color_rgbaa = blackout_color
width = 0
draw.rectangle(
box,
outline="#993300",
- fill=fill_color,
+ fill=fill_color_rgbaa,
width=width,
)
@@ -149,24 +151,18 @@ def is_surrounded_by_black(image):
def is_mostly_on_black_background(image, threshold=0.5):
- # Convert the image to grayscale
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-
- # Count the number of black pixels
black_pixels = np.sum(gray == 0)
-
- # Calculate the total number of pixels in the image
total_pixels = gray.size
-
- # Calculate the ratio of black pixels to total pixels
ratio = black_pixels / total_pixels
-
- # If the ratio is above the threshold, the image is mostly on a black background
return ratio > threshold
def blackout_bboxes(
- image: Union[np.ndarray, PIL.Image.Image], bboxes: np.ndarray, bbox_format="xyxy", fill_color=(255, 255, 255)
+ image: Union[np.ndarray, PIL.Image.Image],
+ bboxes: np.ndarray,
+ bbox_format="xyxy",
+ fill_color=(255, 255, 255),
) -> np.ndarray:
"""
Blackout bounding boxes on the image.
@@ -184,9 +180,7 @@ def blackout_bboxes(
"""
if image is None:
- raise ValueError(
- "Input image can't be empty"
- )
+ raise ValueError("Input image can't be empty")
if isinstance(image, PIL.Image.Image):
image = np.array(image)
@@ -196,10 +190,10 @@ def blackout_bboxes(
box = [int(x) for x in box]
if bbox_format == "xywh":
box = [box[0], box[1], box[0] + box[2], box[1] + box[3]]
- snippet = image[box[1]:box[3], box[0]:box[2]]
+ snippet = image[box[1] : box[3], box[0] : box[2]]
if is_surrounded_by_black(snippet) or is_mostly_on_black_background(snippet):
continue
- image[box[1]:box[3], box[0]:box[2], :] = fill_color
+ image[box[1] : box[3], box[0] : box[2], :] = fill_color
return image
@@ -295,7 +289,7 @@ def lines_from_bboxes(image, bboxes):
def crop_to_content_box(
- frame: np.ndarray, content_aware=False
+ frame: np.ndarray, content_aware=False
) -> Tuple[np.ndarray, np.ndarray]:
"""
Crop given image to content and return new box with the offset.
@@ -352,7 +346,7 @@ def crop_to_content_box(
h = indices[0].max() - y
w = indices[1].max() - x
- cropped = frame[y: y + h + 1, x: x + w + 1].copy()
+ cropped = frame[y : y + h + 1, x : x + w + 1].copy()
dt = time.time() - start
# create offset box in LTRB format (left, top, right, bottom) from XYWH format
offset = [x, y, img_w - w, img_h - h]
@@ -388,11 +382,11 @@ class BoxProcessorUlimDit(BoxProcessor):
"""
def __init__(
- self,
- work_dir: str = "/tmp/boxes",
- models_dir: str = __model_path__,
- cuda: bool = False,
- refinement: bool = True,
+ self,
+ work_dir: str = "/tmp/boxes",
+ models_dir: str = __model_path__,
+ cuda: bool = False,
+ refinement: bool = True,
):
super().__init__(work_dir, models_dir, cuda)
self.logger = MarieLogger(self.__class__.__name__)
@@ -427,8 +421,16 @@ def psm_word(self, image):
raise Exception("Not implemented : PSM_WORD")
return self.psm_sparse(image)
+ @torch.no_grad()
def psm_sparse_step(self, image: np.ndarray, adj_x: int, adj_y: int):
-
+ """
+ Perform a step in the PSM sparse pipeline.
+
+ :param image: The input image.
+ :param adj_x: The adjustment in the x-direction.
+ :param adj_y: The adjustment in the y-direction.
+ :return: A tuple containing the bounding boxes, predicted classes, and scores.
+ """
try:
rp = self.predictor(image)
# detach the predictions from GPU to avoid memory leak
@@ -502,11 +504,12 @@ def psm_sparse_step(self, image: np.ndarray, adj_x: int, adj_y: int):
return [], [], []
def psm_sparse(
- self,
- image: np.ndarray,
- bbox_optimization: Optional[bool] = False,
- bbox_context_aware: Optional[bool] = True,
- enable_visualization: Optional[bool] = False,
+ self,
+ image: np.ndarray,
+ bbox_optimization: Optional[bool] = False,
+ bbox_context_aware: Optional[bool] = True,
+ bbox_refinement: Optional[bool] = None,
+ enable_visualization: Optional[bool] = False,
):
try:
self.logger.debug(f"Starting box predictions : {image.shape}")
@@ -518,8 +521,8 @@ def psm_sparse(
adj_y = 0
# Both height and width are smaller than the minimum size then frame the image
if (
- image.shape[0] < self.min_size_test[0]
- or image.shape[1] < self.min_size_test[1]
+ image.shape[0] < self.min_size_test[0]
+ or image.shape[1] < self.min_size_test[1]
):
self.logger.debug(
f"Image size is too small : {image.shape}, resizing to {self.min_size_test}"
@@ -535,22 +538,36 @@ def psm_sparse(
adj_x = coord[0]
adj_y = coord[1]
+ # check if we need to perform refinement and have an override to disable it
refinement_steps = 1
+ default_refinement = self.refinement
+ if bbox_refinement is not None:
+ self.refinement = bbox_refinement
if self.refinement:
refinement_steps = 3
-
+ self.logger.info(f"Refinement : {self.refinement}")
+ self.refinement = default_refinement
bboxes, classes, scores = [], [], []
refinement_image = image
hash_id = hash_frames_fast(frames=[refinement_image])
- ensure_exists("/tmp/boxes")
- enable_visualization = True
+
+ if enable_visualization:
+ ensure_exists("/tmp/boxes")
+
for i in range(refinement_steps):
try:
- bboxes_, classes_, scores_ = self.psm_sparse_step(refinement_image, adj_x, adj_y)
+ bboxes_, classes_, scores_ = self.psm_sparse_step(
+ refinement_image, adj_x, adj_y
+ )
refinement_copy = copy.deepcopy(refinement_image)
- refinement_image = blackout_bboxes(refinement_image, bboxes_, bbox_format="xyxy")
+ refinement_image = blackout_bboxes(
+ refinement_image, bboxes_, bbox_format="xyxy"
+ )
if enable_visualization:
- cv2.imwrite(f"/tmp/boxes/refinement_image_{hash_id}_{i}.png", refinement_image)
+ cv2.imwrite(
+ f"/tmp/boxes/refinement_image_{hash_id}_{i}.png",
+ refinement_image,
+ )
if i == 0:
bboxes.extend(bboxes_)
@@ -593,7 +610,7 @@ def psm_sparse(
x0, y0, x1, y1 = box
w = x1 - x0
h = y1 - y0
- snippet = image[y0: y0 + h, x0: x0 + w:]
+ snippet = image[y0 : y0 + h, x0 : x0 + w :]
offset, cropped = crop_to_content_box(
snippet, content_aware=bbox_context_aware
)
@@ -656,13 +673,14 @@ def psm_multiline(self, image):
@torch.no_grad()
def extract_bounding_boxes(
- self,
- _id,
- key,
- img: Union[np.ndarray, PIL.Image.Image],
- psm=PSMode.SPARSE,
- bbox_optimization: Optional[bool] = False,
- bbox_context_aware: Optional[bool] = True,
+ self,
+ _id,
+ key,
+ img: Union[np.ndarray, PIL.Image.Image],
+ psm=PSMode.SPARSE,
+ bbox_optimization: Optional[bool] = False,
+ bbox_context_aware: Optional[bool] = True,
+ bbox_refinement: Optional[bool] = None,
) -> Tuple[Any, Any, Any, Any, Any]:
if img is None:
raise Exception("Input image can't be empty")
@@ -690,7 +708,7 @@ def extract_bounding_boxes(
# Page Segmentation Model
if psm == PSMode.SPARSE:
bboxes, polys, scores, lines_bboxes, classes = self.psm_sparse(
- image, bbox_optimization, bbox_context_aware
+ image, bbox_optimization, bbox_context_aware, bbox_refinement
)
# elif psm == PSMode.WORD:
# bboxes, polys, scores, lines_bboxes, classes = self.psm_word(image_norm)
@@ -746,7 +764,7 @@ def extract_bounding_boxes(
# self.logger.debug(f" index = {i} box_adj = {box_adj} : {h} , {w} > {box}")
# Class 0 == Text
if classes[i] == 0:
- snippet = img[y0: y0 + h, x0: x0 + w:]
+ snippet = img[y0 : y0 + h, x0 : x0 + w :]
line_number = find_line_number(lines_bboxes, box_adj)
fragments.append(snippet)
rect_from_poly.append(box_adj)
diff --git a/marie/clients/__init__.py b/marie/clients/__init__.py
index 785fdeaf..3ddb97a7 100644
--- a/marie/clients/__init__.py
+++ b/marie/clients/__init__.py
@@ -1,4 +1,5 @@
"""Module wrapping the Client of Jina."""
+
import argparse
from typing import TYPE_CHECKING, List, Optional, Union, overload
@@ -72,9 +73,7 @@ def Client(
# overload_inject_end_client
-def Client(
- args: Optional['argparse.Namespace'] = None, **kwargs
-) -> Union[
+def Client(args: Optional['argparse.Namespace'] = None, **kwargs) -> Union[
'AsyncWebSocketClient',
'WebSocketClient',
'AsyncGRPCClient',
diff --git a/marie/clients/base/__init__.py b/marie/clients/base/__init__.py
index 1d6a262b..845a1deb 100644
--- a/marie/clients/base/__init__.py
+++ b/marie/clients/base/__init__.py
@@ -1,4 +1,5 @@
"""Module containing the Base Client for Jina."""
+
import abc
import argparse
import inspect
@@ -48,9 +49,11 @@ def __init__(
os.unsetenv('https_proxy')
self._inputs = None
self._setup_instrumentation(
- name=self.args.name
- if hasattr(self.args, 'name')
- else self.__class__.__name__,
+ name=(
+ self.args.name
+ if hasattr(self.args, 'name')
+ else self.__class__.__name__
+ ),
tracing=self.args.tracing,
traces_exporter_host=self.args.traces_exporter_host,
traces_exporter_port=self.args.traces_exporter_port,
@@ -180,8 +183,7 @@ async def _get_results(
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
**kwargs,
- ):
- ...
+ ): ...
@abc.abstractmethod
def _is_flow_ready(self, **kwargs) -> bool:
diff --git a/marie/clients/base/grpc.py b/marie/clients/base/grpc.py
index 64eb7979..5a3f1089 100644
--- a/marie/clients/base/grpc.py
+++ b/marie/clients/base/grpc.py
@@ -142,7 +142,9 @@ async def _get_results(
compression=self.compression,
**kwargs,
)
- async for response in stream_rpc.stream_rpc_with_retry():
+ async for (
+ response
+ ) in stream_rpc.stream_rpc_with_retry():
yield response
else:
unary_rpc = UnaryRpc(
diff --git a/marie/clients/base/websocket.py b/marie/clients/base/websocket.py
index dde458b8..a37d2c32 100644
--- a/marie/clients/base/websocket.py
+++ b/marie/clients/base/websocket.py
@@ -1,4 +1,5 @@
"""A module for the websockets-based Client for Jina."""
+
import asyncio
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Dict, Optional, Tuple
@@ -131,9 +132,9 @@ async def _get_results(
)
)
- request_buffer: Dict[
- str, asyncio.Future
- ] = dict() # maps request_ids to futures (tasks)
+ request_buffer: Dict[str, asyncio.Future] = (
+ dict()
+ ) # maps request_ids to futures (tasks)
def _result_handler(result):
return result
diff --git a/marie/components/document_registration/unilm_dit.py b/marie/components/document_registration/unilm_dit.py
index 92fd0bf4..90c5b33e 100644
--- a/marie/components/document_registration/unilm_dit.py
+++ b/marie/components/document_registration/unilm_dit.py
@@ -1,5 +1,6 @@
import argparse
import os
+from collections.abc import Iterable
from typing import List, Optional, Union
import cv2
@@ -9,6 +10,7 @@
from detectron2.utils.visualizer import ColorMode, Visualizer
from ditod import add_vit_config
from docarray import DocList
+from torchvision.ops.boxes import batched_nms
from tqdm import tqdm
from marie.constants import __config_dir__, __model_path__
@@ -375,8 +377,32 @@ def predict_document_image(
return [default_prediction]
if len(boxes) > 1:
- self.logger.warning(f"Multiple boxes detected, skipping segmentation.")
- return [default_prediction]
+
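+            # drop low-confidence detections, then apply NMS and keep the single highest-scoring box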
+ min_score = 0.7
+ indices = np.where(scores > min_score)
+ scores = scores[indices]
+ boxes = boxes[indices]
+ classes = classes[indices]
+
+ if len(boxes) == 0:
+ self.logger.warning(
+ f"No segmentation boxes predicted. No boxes above threshold."
+ )
+ return [default_prediction]
+
+ post_nms_topk = 1
+ nms_thresh = 0.5
+ keep = batched_nms(
+ torch.tensor(boxes.astype(np.float32)),
+ torch.tensor(scores.astype(np.float32)),
+ torch.tensor(classes),
+ nms_thresh,
+ )
+
+ keep = keep[:post_nms_topk]
+ boxes = [boxes[keep]]
+ scores = [scores[keep]]
+ classes = [classes[keep]]
boundary_bbox = [int(x) for x in boxes[0]] # xyxy format
# TODO : add this as a parameter
diff --git a/marie/components/template_matching/base.py b/marie/components/template_matching/base.py
index cf9e3484..47f33eab 100644
--- a/marie/components/template_matching/base.py
+++ b/marie/components/template_matching/base.py
@@ -255,9 +255,7 @@ def run(
prediction_time_end = time.time() - prediction_time_start
durations_in_seconds["prediction"] = prediction_time_end
self.logger.debug(
- "Slice-prediction performed in",
- durations_in_seconds["prediction"],
- "seconds.",
+ f"Slice-prediction performed in {durations_in_seconds['prediction']} seconds."
)
for prediction in predictions:
diff --git a/marie/constants.py b/marie/constants.py
index 0daa2b1c..197801f5 100644
--- a/marie/constants.py
+++ b/marie/constants.py
@@ -51,6 +51,7 @@
__default_composite_gateway__ = "CompositeGateway"
__default_websocket_gateway__ = "WebSocketGateway"
__default_grpc_gateway__ = "GRPCGateway"
+__dynamic_base_gateway_hubble__ = "marie/hubble-gateway"
__default_endpoint__ = "/default"
__ready_msg__ = "ready and listening"
__stop_msg__ = "terminated"
diff --git a/marie/enums.py b/marie/enums.py
index ccce1638..fbbe3d71 100644
--- a/marie/enums.py
+++ b/marie/enums.py
@@ -269,6 +269,7 @@ class ProviderType(BetterEnum):
NONE = 0 #: no provider
SAGEMAKER = 1 #: AWS SageMaker
+ AZURE = 2 #: AZURE
def replace_enum_to_str(obj):
diff --git a/marie/helper.py b/marie/helper.py
index fa62dfcb..97443caa 100644
--- a/marie/helper.py
+++ b/marie/helper.py
@@ -947,7 +947,8 @@ def get_full_version() -> Optional[Tuple[Dict, Dict]]:
except:
__hubble_version__ = 'not-available'
try:
- from jcloud import __version__ as __jcloud_version__
+ # from jcloud import __version__ as __jcloud_version__
+ raise ImportError('jcloud is not available')
except:
__jcloud_version__ = 'not-available'
diff --git a/marie/orchestrate/deployments/__init__.py b/marie/orchestrate/deployments/__init__.py
index 4f5fd86e..d710ef23 100644
--- a/marie/orchestrate/deployments/__init__.py
+++ b/marie/orchestrate/deployments/__init__.py
@@ -609,9 +609,11 @@ def _get_connection_list_for_flow(self) -> List[str]:
# there is no head, add the worker connection information instead
ports = self.ports
hosts = [
- __docker_host__
- if host_is_local(host) and in_docker() and self._is_docker
- else host
+ (
+ __docker_host__
+ if host_is_local(host) and in_docker() and self._is_docker
+ else host
+ )
for host in self.hosts
]
return [
@@ -1137,9 +1139,11 @@ def start(self) -> "Deployment":
deployment_args=self.args,
args=self.pod_args["pods"][shard_id],
head_pod=self.head_pod,
- name=f"{self.name}-replica-set-{shard_id}"
- if num_shards > 1
- else f"{self.name}-replica-set",
+ name=(
+ f'{self.name}-replica-set-{shard_id}'
+ if num_shards > 1
+ else f'{self.name}-replica-set'
+ ),
)
self.enter_context(self.shards[shard_id])
diff --git a/marie/orchestrate/flow/base.py b/marie/orchestrate/flow/base.py
index 10447cdf..ba8e2273 100644
--- a/marie/orchestrate/flow/base.py
+++ b/marie/orchestrate/flow/base.py
@@ -283,7 +283,7 @@ def __init__(
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET'].
- :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER'].
+ :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE'].
:param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param py_modules: The customized python modules need to be imported before loading the gateway
@@ -984,7 +984,7 @@ def add(
:param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535]
:param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64")
:param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET'].
- :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER'].
+ :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE'].
:param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider.
:param py_modules: The customized python modules need to be imported before loading the executor
@@ -1416,7 +1416,7 @@ def config_gateway(
Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default)
:param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET'].
- :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER'].
+ :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE'].
:param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider.
:param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy
:param py_modules: The customized python modules need to be imported before loading the gateway
@@ -1787,10 +1787,8 @@ def build(self, copy_flow: bool = False, **kwargs) -> 'Flow':
op_flow._deployment_nodes[GATEWAY_NAME].args.graph_description = json.dumps(
op_flow._get_graph_representation()
)
- op_flow._deployment_nodes[
- GATEWAY_NAME
- ].args.deployments_addresses = json.dumps(
- op_flow._get_deployments_addresses()
+ op_flow._deployment_nodes[GATEWAY_NAME].args.deployments_addresses = (
+ json.dumps(op_flow._get_deployments_addresses())
)
op_flow._deployment_nodes[GATEWAY_NAME].update_pod_args()
@@ -1872,19 +1870,25 @@ def start(self):
runtime_args = self._deployment_nodes[GATEWAY_NAME].args
- for gport, gprotocol in zip(port_gateway, protocol_gateway):
- # TODO : Need to implement GRPC and WEBSOCKET
- if gprotocol == ProtocolType.HTTP:
- self._setup_service_discovery(
- name=f"marie-{GATEWAY_NAME}",
- host=self.host if self.host != '0.0.0.0' else get_internal_ip(),
- port=gport,
- scheme=runtime_args.scheme if 'scheme' in runtime_args else 'http',
- discovery=runtime_args.discovery,
- discovery_host=runtime_args.discovery_host,
- discovery_port=runtime_args.discovery_port,
- discovery_watchdog_interval=runtime_args.discovery_watchdog_interval,
- )
+ if runtime_args.discovery:
+ for gport, gprotocol in zip(port_gateway, protocol_gateway):
+ if gprotocol in (ProtocolType.HTTP, ProtocolType.GRPC):
+ self._setup_service_discovery(
+ protocol=gprotocol,
+ name=f"marie-{GATEWAY_NAME}",
+ host=self.host if self.host != '0.0.0.0' else get_internal_ip(),
+ port=gport,
+ scheme=(
+ runtime_args.scheme if 'scheme' in runtime_args else 'http'
+ ),
+ discovery=runtime_args.discovery,
+ discovery_host=runtime_args.discovery_host,
+ discovery_port=runtime_args.discovery_port,
+ discovery_watchdog_interval=runtime_args.discovery_watchdog_interval,
+ runtime_args=runtime_args,
+ )
+ else:
+ self.logger.warning('Service Discovery is disabled for gateway.')
self._build_level = FlowBuildLevel.RUNNING
@@ -2031,7 +2035,6 @@ async def _f():
if not running_in_event_loop:
asyncio.get_event_loop().run_until_complete(_async_wait_all())
else:
- # TODO: the same logic that one fails all other fail should be done also here
for k, v in self:
wait_ready_threads.append(
threading.Thread(target=_wait_ready, args=(k, v), daemon=True)
@@ -2444,7 +2447,6 @@ def _get_summary_table(self, all_panels: List[Panel]):
http_ext_table.add_row(':books:', 'Redoc', redoc_link)
if self.gateway_args.expose_graphql_endpoint:
-
http_ext_table.add_row(':strawberry:', 'GraphQL UI', graphql_ui_link)
if True or self.gateway_args.discovery:
@@ -2466,6 +2468,23 @@ def _get_summary_table(self, all_panels: List[Panel]):
)
)
+ pod_ext_table = self._init_table()
+
+ for name, deployment in self:
+ for replica in deployment.pod_args['pods'][0]:
+ pod_ext_table.add_row(
+ ':lock:',
+ replica.name,
+ f'[link=://{replica.host}:{replica.port[0]}]{replica.host}:{replica.port[0]}[/]',
+ )
+ all_panels.append(
+ Panel(
+ pod_ext_table,
+ title=':gem: [b]Deployment Nodes[/]',
+ expand=False,
+ )
+ )
+
if self.monitoring:
monitor_ext_table = self._init_table()
diff --git a/marie/pipe/extract_pipeline.py b/marie/pipe/extract_pipeline.py
index d4c8f838..7c2fa876 100644
--- a/marie/pipe/extract_pipeline.py
+++ b/marie/pipe/extract_pipeline.py
@@ -337,14 +337,10 @@ def execute_frames_pipeline(
# burst frames into individual images
burst_frames(ref_id, frames, root_asset_dir)
-
- print(f"before boundary : {len(frames)}")
frames, boundary_meta = self.boundary(
ref_id, frames, root_asset_dir, enabled=page_boundary_enabled
)
- print(f"after boundary : {len(frames)}")
-
clean_frames = self.segment(
ref_id, frames, root_asset_dir, enabled=page_cleaner_enabled
)
diff --git a/marie/proto/docarray_v1/pb/jina_pb2.py b/marie/proto/docarray_v1/pb/jina_pb2.py
index 92af6004..28830e99 100644
--- a/marie/proto/docarray_v1/pb/jina_pb2.py
+++ b/marie/proto/docarray_v1/pb/jina_pb2.py
@@ -6,6 +6,7 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
+
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -17,81 +18,83 @@
import docarray.proto.pb.docarray_pb2 as docarray__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\xa0\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a\x63\n\x10\x44\x61taContentProto\x12,\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb9\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12)\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 
\x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\xa0\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a\x63\n\x10\x44\x61taContentProto\x12,\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x1c.docarray.DocumentArrayProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb9\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12)\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x17.docarray.DocumentProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 
\x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3'
+)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jina_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _JINAINFOPROTO_JINAENTRY._options = None
- _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
- _JINAINFOPROTO_ENVSENTRY._options = None
- _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
- _ROUTEPROTO._serialized_start=129
- _ROUTEPROTO._serialized_end=288
- _JINAINFOPROTO._serialized_start=291
- _JINAINFOPROTO._serialized_end=486
- _JINAINFOPROTO_JINAENTRY._serialized_start=398
- _JINAINFOPROTO_JINAENTRY._serialized_end=441
- _JINAINFOPROTO_ENVSENTRY._serialized_start=443
- _JINAINFOPROTO_ENVSENTRY._serialized_end=486
- _HEADERPROTO._serialized_start=489
- _HEADERPROTO._serialized_end=687
- _ENDPOINTSPROTO._serialized_start=689
- _ENDPOINTSPROTO._serialized_end=791
- _STATUSPROTO._serialized_start=794
- _STATUSPROTO._serialized_end=1043
- _STATUSPROTO_EXCEPTIONPROTO._serialized_start=927
- _STATUSPROTO_EXCEPTIONPROTO._serialized_end=1005
- _STATUSPROTO_STATUSCODE._serialized_start=1007
- _STATUSPROTO_STATUSCODE._serialized_end=1043
- _RELATEDENTITY._serialized_start=1045
- _RELATEDENTITY._serialized_end=1139
- _DATAREQUESTPROTO._serialized_start=1142
- _DATAREQUESTPROTO._serialized_end=1430
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start=1331
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end=1430
- _SINGLEDOCUMENTREQUESTPROTO._serialized_start=1433
- _SINGLEDOCUMENTREQUESTPROTO._serialized_end=1618
- _DATAREQUESTPROTOWODATA._serialized_start=1621
- _DATAREQUESTPROTOWODATA._serialized_end=1759
- _DATAREQUESTLISTPROTO._serialized_start=1761
- _DATAREQUESTLISTPROTO._serialized_end=1825
- _SNAPSHOTID._serialized_start=1827
- _SNAPSHOTID._serialized_end=1854
- _RESTOREID._serialized_start=1856
- _RESTOREID._serialized_end=1882
- _SNAPSHOTSTATUSPROTO._serialized_start=1885
- _SNAPSHOTSTATUSPROTO._serialized_end=2124
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_start=2011
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_end=2124
- _RESTORESNAPSHOTSTATUSPROTO._serialized_start=2127
- _RESTORESNAPSHOTSTATUSPROTO._serialized_end=2329
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start=2243
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end=2329
- _RESTORESNAPSHOTCOMMAND._serialized_start=2331
- _RESTORESNAPSHOTCOMMAND._serialized_end=2378
- _JINADATAREQUESTRPC._serialized_start=2380
- _JINADATAREQUESTRPC._serialized_end=2470
- _JINASINGLEDATAREQUESTRPC._serialized_start=2472
- _JINASINGLEDATAREQUESTRPC._serialized_end=2571
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_start=2573
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_end=2689
- _JINARPC._serialized_start=2691
- _JINARPC._serialized_end=2762
- _JINADISCOVERENDPOINTSRPC._serialized_start=2764
- _JINADISCOVERENDPOINTSRPC._serialized_end=2860
- _JINAGATEWAYDRYRUNRPC._serialized_start=2862
- _JINAGATEWAYDRYRUNRPC._serialized_end=2940
- _JINAINFORPC._serialized_start=2942
- _JINAINFORPC._serialized_end=3013
- _JINAEXECUTORSNAPSHOT._serialized_start=3015
- _JINAEXECUTORSNAPSHOT._serialized_end=3102
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start=3104
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end=3200
- _JINAEXECUTORRESTORE._serialized_start=3202
- _JINAEXECUTORRESTORE._serialized_end=3300
- _JINAEXECUTORRESTOREPROGRESS._serialized_start=3302
- _JINAEXECUTORRESTOREPROGRESS._serialized_end=3402
+ DESCRIPTOR._options = None
+ _JINAINFOPROTO_JINAENTRY._options = None
+ _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
+ _JINAINFOPROTO_ENVSENTRY._options = None
+ _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
+ _ROUTEPROTO._serialized_start = 129
+ _ROUTEPROTO._serialized_end = 288
+ _JINAINFOPROTO._serialized_start = 291
+ _JINAINFOPROTO._serialized_end = 486
+ _JINAINFOPROTO_JINAENTRY._serialized_start = 398
+ _JINAINFOPROTO_JINAENTRY._serialized_end = 441
+ _JINAINFOPROTO_ENVSENTRY._serialized_start = 443
+ _JINAINFOPROTO_ENVSENTRY._serialized_end = 486
+ _HEADERPROTO._serialized_start = 489
+ _HEADERPROTO._serialized_end = 687
+ _ENDPOINTSPROTO._serialized_start = 689
+ _ENDPOINTSPROTO._serialized_end = 791
+ _STATUSPROTO._serialized_start = 794
+ _STATUSPROTO._serialized_end = 1043
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_start = 927
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_end = 1005
+ _STATUSPROTO_STATUSCODE._serialized_start = 1007
+ _STATUSPROTO_STATUSCODE._serialized_end = 1043
+ _RELATEDENTITY._serialized_start = 1045
+ _RELATEDENTITY._serialized_end = 1139
+ _DATAREQUESTPROTO._serialized_start = 1142
+ _DATAREQUESTPROTO._serialized_end = 1430
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start = 1331
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end = 1430
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_start = 1433
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_end = 1618
+ _DATAREQUESTPROTOWODATA._serialized_start = 1621
+ _DATAREQUESTPROTOWODATA._serialized_end = 1759
+ _DATAREQUESTLISTPROTO._serialized_start = 1761
+ _DATAREQUESTLISTPROTO._serialized_end = 1825
+ _SNAPSHOTID._serialized_start = 1827
+ _SNAPSHOTID._serialized_end = 1854
+ _RESTOREID._serialized_start = 1856
+ _RESTOREID._serialized_end = 1882
+ _SNAPSHOTSTATUSPROTO._serialized_start = 1885
+ _SNAPSHOTSTATUSPROTO._serialized_end = 2124
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2011
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2124
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_start = 2127
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_end = 2329
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2243
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2329
+ _RESTORESNAPSHOTCOMMAND._serialized_start = 2331
+ _RESTORESNAPSHOTCOMMAND._serialized_end = 2378
+ _JINADATAREQUESTRPC._serialized_start = 2380
+ _JINADATAREQUESTRPC._serialized_end = 2470
+ _JINASINGLEDATAREQUESTRPC._serialized_start = 2472
+ _JINASINGLEDATAREQUESTRPC._serialized_end = 2571
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_start = 2573
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_end = 2689
+ _JINARPC._serialized_start = 2691
+ _JINARPC._serialized_end = 2762
+ _JINADISCOVERENDPOINTSRPC._serialized_start = 2764
+ _JINADISCOVERENDPOINTSRPC._serialized_end = 2860
+ _JINAGATEWAYDRYRUNRPC._serialized_start = 2862
+ _JINAGATEWAYDRYRUNRPC._serialized_end = 2940
+ _JINAINFORPC._serialized_start = 2942
+ _JINAINFORPC._serialized_end = 3013
+ _JINAEXECUTORSNAPSHOT._serialized_start = 3015
+ _JINAEXECUTORSNAPSHOT._serialized_end = 3102
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start = 3104
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end = 3200
+ _JINAEXECUTORRESTORE._serialized_start = 3202
+ _JINAEXECUTORRESTORE._serialized_end = 3300
+ _JINAEXECUTORRESTOREPROGRESS._serialized_start = 3302
+ _JINAEXECUTORRESTOREPROGRESS._serialized_end = 3402
# @@protoc_insertion_point(module_scope)
diff --git a/marie/proto/docarray_v1/pb/jina_pb2_grpc.py b/marie/proto/docarray_v1/pb/jina_pb2_grpc.py
index f52ce19e..f571beae 100644
--- a/marie/proto/docarray_v1/pb/jina_pb2_grpc.py
+++ b/marie/proto/docarray_v1/pb/jina_pb2_grpc.py
@@ -18,10 +18,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_data = channel.unary_unary(
- '/jina.JinaDataRequestRPC/process_data',
- request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaDataRequestRPC/process_data',
+ request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaDataRequestRPCServicer(object):
@@ -30,8 +30,7 @@ class JinaDataRequestRPCServicer(object):
"""
def process_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -39,39 +38,52 @@ def process_data(self, request, context):
def add_JinaDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_data,
- request_deserializer=jina__pb2.DataRequestListProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_data,
+ request_deserializer=jina__pb2.DataRequestListProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
"""
@staticmethod
- def process_data(request,
+ def process_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDataRequestRPC/process_data',
+ '/jina.JinaDataRequestRPC/process_data',
jina__pb2.DataRequestListProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDataRequestRPCStub(object):
@@ -87,10 +99,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_single_data = channel.unary_unary(
- '/jina.JinaSingleDataRequestRPC/process_single_data',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaSingleDataRequestRPCServicer(object):
@@ -100,8 +112,7 @@ class JinaSingleDataRequestRPCServicer(object):
"""
def process_single_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -109,18 +120,19 @@ def process_single_data(self, request, context):
def add_JinaSingleDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_single_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_single_data,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_single_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_single_data,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -128,21 +140,33 @@ class JinaSingleDataRequestRPC(object):
"""
@staticmethod
- def process_single_data(request,
+ def process_single_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaSingleDataRequestRPC/process_single_data',
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDocumentRequestRPCStub(object):
@@ -158,10 +182,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.stream_doc = channel.unary_stream(
- '/jina.JinaSingleDocumentRequestRPC/stream_doc',
- request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- )
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ )
class JinaSingleDocumentRequestRPCServicer(object):
@@ -171,8 +195,7 @@ class JinaSingleDocumentRequestRPCServicer(object):
"""
def stream_doc(self, request, context):
- """Used for streaming one document to the Executors
- """
+ """Used for streaming one document to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -180,18 +203,19 @@ def stream_doc(self, request, context):
def add_JinaSingleDocumentRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'stream_doc': grpc.unary_stream_rpc_method_handler(
- servicer.stream_doc,
- request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- ),
+ 'stream_doc': grpc.unary_stream_rpc_method_handler(
+ servicer.stream_doc,
+ request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDocumentRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -199,21 +223,33 @@ class JinaSingleDocumentRequestRPC(object):
"""
@staticmethod
- def stream_doc(request,
+ def stream_doc(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_stream(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_stream(request, target, '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
jina__pb2.SingleDocumentRequestProto.SerializeToString,
jina__pb2.SingleDocumentRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaRPCStub(object):
@@ -228,10 +264,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Call = channel.stream_stream(
- '/jina.JinaRPC/Call',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaRPC/Call',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaRPCServicer(object):
@@ -240,8 +276,7 @@ class JinaRPCServicer(object):
"""
def Call(self, request_iterator, context):
- """Pass in a Request and a filled Request with matches will be returned.
- """
+ """Pass in a Request and a filled Request with matches will be returned."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -249,39 +284,52 @@ def Call(self, request_iterator, context):
def add_JinaRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'Call': grpc.stream_stream_rpc_method_handler(
- servicer.Call,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'Call': grpc.stream_stream_rpc_method_handler(
+ servicer.Call,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaRPC', rpc_method_handlers)
+ 'jina.JinaRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaRPC(object):
"""*
jina streaming gRPC service.
"""
@staticmethod
- def Call(request_iterator,
+ def Call(
+ request_iterator,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.stream_stream(
+ request_iterator,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.stream_stream(request_iterator, target, '/jina.JinaRPC/Call',
+ '/jina.JinaRPC/Call',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaDiscoverEndpointsRPCStub(object):
@@ -296,10 +344,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.endpoint_discovery = channel.unary_unary(
- '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.EndpointsProto.FromString,
- )
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.EndpointsProto.FromString,
+ )
class JinaDiscoverEndpointsRPCServicer(object):
@@ -316,39 +364,52 @@ def endpoint_discovery(self, request, context):
def add_JinaDiscoverEndpointsRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
- servicer.endpoint_discovery,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.EndpointsProto.SerializeToString,
- ),
+ 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
+ servicer.endpoint_discovery,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.EndpointsProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers)
+ 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDiscoverEndpointsRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def endpoint_discovery(request,
+ def endpoint_discovery(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.EndpointsProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaGatewayDryRunRPCStub(object):
@@ -363,10 +424,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.dry_run = channel.unary_unary(
- '/jina.JinaGatewayDryRunRPC/dry_run',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.StatusProto.FromString,
- )
+ '/jina.JinaGatewayDryRunRPC/dry_run',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.StatusProto.FromString,
+ )
class JinaGatewayDryRunRPCServicer(object):
@@ -383,39 +444,52 @@ def dry_run(self, request, context):
def add_JinaGatewayDryRunRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'dry_run': grpc.unary_unary_rpc_method_handler(
- servicer.dry_run,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.StatusProto.SerializeToString,
- ),
+ 'dry_run': grpc.unary_unary_rpc_method_handler(
+ servicer.dry_run,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.StatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaGatewayDryRunRPC', rpc_method_handlers)
+ 'jina.JinaGatewayDryRunRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaGatewayDryRunRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def dry_run(request,
+ def dry_run(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaGatewayDryRunRPC/dry_run',
+ '/jina.JinaGatewayDryRunRPC/dry_run',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.StatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaInfoRPCStub(object):
@@ -430,10 +504,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self._status = channel.unary_unary(
- '/jina.JinaInfoRPC/_status',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.JinaInfoProto.FromString,
- )
+ '/jina.JinaInfoRPC/_status',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.JinaInfoProto.FromString,
+ )
class JinaInfoRPCServicer(object):
@@ -450,39 +524,52 @@ def _status(self, request, context):
def add_JinaInfoRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- '_status': grpc.unary_unary_rpc_method_handler(
- servicer._status,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
- ),
+ '_status': grpc.unary_unary_rpc_method_handler(
+ servicer._status,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaInfoRPC', rpc_method_handlers)
+ 'jina.JinaInfoRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaInfoRPC(object):
"""*
jina gRPC service to expose information about running jina version and environment.
"""
@staticmethod
- def _status(request,
+ def _status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaInfoRPC/_status',
+ '/jina.JinaInfoRPC/_status',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.JinaInfoProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotStub(object):
@@ -497,10 +584,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot = channel.unary_unary(
- '/jina.JinaExecutorSnapshot/snapshot',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshot/snapshot',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotServicer(object):
@@ -517,39 +604,52 @@ def snapshot(self, request, context):
def add_JinaExecutorSnapshotServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshot', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshot', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshot(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot(request,
+ def snapshot(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshot/snapshot',
+ '/jina.JinaExecutorSnapshot/snapshot',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotProgressStub(object):
@@ -564,10 +664,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot_status = channel.unary_unary(
- '/jina.JinaExecutorSnapshotProgress/snapshot_status',
- request_serializer=jina__pb2.SnapshotId.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ request_serializer=jina__pb2.SnapshotId.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotProgressServicer(object):
@@ -584,39 +684,52 @@ def snapshot_status(self, request, context):
def add_JinaExecutorSnapshotProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot_status': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot_status,
- request_deserializer=jina__pb2.SnapshotId.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot_status': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot_status,
+ request_deserializer=jina__pb2.SnapshotId.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshotProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot_status(request,
+ def snapshot_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
jina__pb2.SnapshotId.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreStub(object):
@@ -631,10 +744,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore = channel.unary_unary(
- '/jina.JinaExecutorRestore/restore',
- request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestore/restore',
+ request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreServicer(object):
@@ -651,39 +764,52 @@ def restore(self, request, context):
def add_JinaExecutorRestoreServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore': grpc.unary_unary_rpc_method_handler(
- servicer.restore,
- request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore': grpc.unary_unary_rpc_method_handler(
+ servicer.restore,
+ request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestore', rpc_method_handlers)
+ 'jina.JinaExecutorRestore', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestore(object):
"""*
jina gRPC service to trigger a restore at the Executor Runtime.
"""
@staticmethod
- def restore(request,
+ def restore(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestore/restore',
+ '/jina.JinaExecutorRestore/restore',
jina__pb2.RestoreSnapshotCommand.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreProgressStub(object):
@@ -698,10 +824,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore_status = channel.unary_unary(
- '/jina.JinaExecutorRestoreProgress/restore_status',
- request_serializer=jina__pb2.RestoreId.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestoreProgress/restore_status',
+ request_serializer=jina__pb2.RestoreId.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreProgressServicer(object):
@@ -718,36 +844,49 @@ def restore_status(self, request, context):
def add_JinaExecutorRestoreProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore_status': grpc.unary_unary_rpc_method_handler(
- servicer.restore_status,
- request_deserializer=jina__pb2.RestoreId.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore_status': grpc.unary_unary_rpc_method_handler(
+ servicer.restore_status,
+ request_deserializer=jina__pb2.RestoreId.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestoreProgress', rpc_method_handlers)
+ 'jina.JinaExecutorRestoreProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestoreProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def restore_status(request,
+ def restore_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestoreProgress/restore_status',
+ '/jina.JinaExecutorRestoreProgress/restore_status',
jina__pb2.RestoreId.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
diff --git a/marie/proto/docarray_v2/pb/jina_pb2.py b/marie/proto/docarray_v2/pb/jina_pb2.py
index 5b19af00..37adc470 100644
--- a/marie/proto/docarray_v2/pb/jina_pb2.py
+++ b/marie/proto/docarray_v2/pb/jina_pb2.py
@@ -6,6 +6,7 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
+
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -17,81 +18,83 @@
import docarray.proto.pb.docarray_pb2 as docarray__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\x9a\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a]\n\x10\x44\x61taContentProto\x12&\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb4\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12$\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x12.docarray.DocProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 
\x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\x9a\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a]\n\x10\x44\x61taContentProto\x12&\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb4\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12$\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x12.docarray.DocProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 
\x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3'
+)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'jina_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _JINAINFOPROTO_JINAENTRY._options = None
- _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
- _JINAINFOPROTO_ENVSENTRY._options = None
- _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
- _ROUTEPROTO._serialized_start=129
- _ROUTEPROTO._serialized_end=288
- _JINAINFOPROTO._serialized_start=291
- _JINAINFOPROTO._serialized_end=486
- _JINAINFOPROTO_JINAENTRY._serialized_start=398
- _JINAINFOPROTO_JINAENTRY._serialized_end=441
- _JINAINFOPROTO_ENVSENTRY._serialized_start=443
- _JINAINFOPROTO_ENVSENTRY._serialized_end=486
- _HEADERPROTO._serialized_start=489
- _HEADERPROTO._serialized_end=687
- _ENDPOINTSPROTO._serialized_start=689
- _ENDPOINTSPROTO._serialized_end=791
- _STATUSPROTO._serialized_start=794
- _STATUSPROTO._serialized_end=1043
- _STATUSPROTO_EXCEPTIONPROTO._serialized_start=927
- _STATUSPROTO_EXCEPTIONPROTO._serialized_end=1005
- _STATUSPROTO_STATUSCODE._serialized_start=1007
- _STATUSPROTO_STATUSCODE._serialized_end=1043
- _RELATEDENTITY._serialized_start=1045
- _RELATEDENTITY._serialized_end=1139
- _DATAREQUESTPROTO._serialized_start=1142
- _DATAREQUESTPROTO._serialized_end=1424
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start=1331
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end=1424
- _SINGLEDOCUMENTREQUESTPROTO._serialized_start=1427
- _SINGLEDOCUMENTREQUESTPROTO._serialized_end=1607
- _DATAREQUESTPROTOWODATA._serialized_start=1610
- _DATAREQUESTPROTOWODATA._serialized_end=1748
- _DATAREQUESTLISTPROTO._serialized_start=1750
- _DATAREQUESTLISTPROTO._serialized_end=1814
- _SNAPSHOTID._serialized_start=1816
- _SNAPSHOTID._serialized_end=1843
- _RESTOREID._serialized_start=1845
- _RESTOREID._serialized_end=1871
- _SNAPSHOTSTATUSPROTO._serialized_start=1874
- _SNAPSHOTSTATUSPROTO._serialized_end=2113
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_start=2000
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_end=2113
- _RESTORESNAPSHOTSTATUSPROTO._serialized_start=2116
- _RESTORESNAPSHOTSTATUSPROTO._serialized_end=2318
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start=2232
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end=2318
- _RESTORESNAPSHOTCOMMAND._serialized_start=2320
- _RESTORESNAPSHOTCOMMAND._serialized_end=2367
- _JINADATAREQUESTRPC._serialized_start=2369
- _JINADATAREQUESTRPC._serialized_end=2459
- _JINASINGLEDATAREQUESTRPC._serialized_start=2461
- _JINASINGLEDATAREQUESTRPC._serialized_end=2560
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_start=2562
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_end=2678
- _JINARPC._serialized_start=2680
- _JINARPC._serialized_end=2751
- _JINADISCOVERENDPOINTSRPC._serialized_start=2753
- _JINADISCOVERENDPOINTSRPC._serialized_end=2849
- _JINAGATEWAYDRYRUNRPC._serialized_start=2851
- _JINAGATEWAYDRYRUNRPC._serialized_end=2929
- _JINAINFORPC._serialized_start=2931
- _JINAINFORPC._serialized_end=3002
- _JINAEXECUTORSNAPSHOT._serialized_start=3004
- _JINAEXECUTORSNAPSHOT._serialized_end=3091
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start=3093
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end=3189
- _JINAEXECUTORRESTORE._serialized_start=3191
- _JINAEXECUTORRESTORE._serialized_end=3289
- _JINAEXECUTORRESTOREPROGRESS._serialized_start=3291
- _JINAEXECUTORRESTOREPROGRESS._serialized_end=3391
+ DESCRIPTOR._options = None
+ _JINAINFOPROTO_JINAENTRY._options = None
+ _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
+ _JINAINFOPROTO_ENVSENTRY._options = None
+ _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
+ _ROUTEPROTO._serialized_start = 129
+ _ROUTEPROTO._serialized_end = 288
+ _JINAINFOPROTO._serialized_start = 291
+ _JINAINFOPROTO._serialized_end = 486
+ _JINAINFOPROTO_JINAENTRY._serialized_start = 398
+ _JINAINFOPROTO_JINAENTRY._serialized_end = 441
+ _JINAINFOPROTO_ENVSENTRY._serialized_start = 443
+ _JINAINFOPROTO_ENVSENTRY._serialized_end = 486
+ _HEADERPROTO._serialized_start = 489
+ _HEADERPROTO._serialized_end = 687
+ _ENDPOINTSPROTO._serialized_start = 689
+ _ENDPOINTSPROTO._serialized_end = 791
+ _STATUSPROTO._serialized_start = 794
+ _STATUSPROTO._serialized_end = 1043
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_start = 927
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_end = 1005
+ _STATUSPROTO_STATUSCODE._serialized_start = 1007
+ _STATUSPROTO_STATUSCODE._serialized_end = 1043
+ _RELATEDENTITY._serialized_start = 1045
+ _RELATEDENTITY._serialized_end = 1139
+ _DATAREQUESTPROTO._serialized_start = 1142
+ _DATAREQUESTPROTO._serialized_end = 1424
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start = 1331
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end = 1424
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_start = 1427
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_end = 1607
+ _DATAREQUESTPROTOWODATA._serialized_start = 1610
+ _DATAREQUESTPROTOWODATA._serialized_end = 1748
+ _DATAREQUESTLISTPROTO._serialized_start = 1750
+ _DATAREQUESTLISTPROTO._serialized_end = 1814
+ _SNAPSHOTID._serialized_start = 1816
+ _SNAPSHOTID._serialized_end = 1843
+ _RESTOREID._serialized_start = 1845
+ _RESTOREID._serialized_end = 1871
+ _SNAPSHOTSTATUSPROTO._serialized_start = 1874
+ _SNAPSHOTSTATUSPROTO._serialized_end = 2113
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2000
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2113
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_start = 2116
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_end = 2318
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2232
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2318
+ _RESTORESNAPSHOTCOMMAND._serialized_start = 2320
+ _RESTORESNAPSHOTCOMMAND._serialized_end = 2367
+ _JINADATAREQUESTRPC._serialized_start = 2369
+ _JINADATAREQUESTRPC._serialized_end = 2459
+ _JINASINGLEDATAREQUESTRPC._serialized_start = 2461
+ _JINASINGLEDATAREQUESTRPC._serialized_end = 2560
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_start = 2562
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_end = 2678
+ _JINARPC._serialized_start = 2680
+ _JINARPC._serialized_end = 2751
+ _JINADISCOVERENDPOINTSRPC._serialized_start = 2753
+ _JINADISCOVERENDPOINTSRPC._serialized_end = 2849
+ _JINAGATEWAYDRYRUNRPC._serialized_start = 2851
+ _JINAGATEWAYDRYRUNRPC._serialized_end = 2929
+ _JINAINFORPC._serialized_start = 2931
+ _JINAINFORPC._serialized_end = 3002
+ _JINAEXECUTORSNAPSHOT._serialized_start = 3004
+ _JINAEXECUTORSNAPSHOT._serialized_end = 3091
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start = 3093
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end = 3189
+ _JINAEXECUTORRESTORE._serialized_start = 3191
+ _JINAEXECUTORRESTORE._serialized_end = 3289
+ _JINAEXECUTORRESTOREPROGRESS._serialized_start = 3291
+ _JINAEXECUTORRESTOREPROGRESS._serialized_end = 3391
# @@protoc_insertion_point(module_scope)
diff --git a/marie/proto/docarray_v2/pb/jina_pb2_grpc.py b/marie/proto/docarray_v2/pb/jina_pb2_grpc.py
index f52ce19e..f571beae 100644
--- a/marie/proto/docarray_v2/pb/jina_pb2_grpc.py
+++ b/marie/proto/docarray_v2/pb/jina_pb2_grpc.py
@@ -18,10 +18,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_data = channel.unary_unary(
- '/jina.JinaDataRequestRPC/process_data',
- request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaDataRequestRPC/process_data',
+ request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaDataRequestRPCServicer(object):
@@ -30,8 +30,7 @@ class JinaDataRequestRPCServicer(object):
"""
def process_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -39,39 +38,52 @@ def process_data(self, request, context):
def add_JinaDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_data,
- request_deserializer=jina__pb2.DataRequestListProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_data,
+ request_deserializer=jina__pb2.DataRequestListProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
"""
@staticmethod
- def process_data(request,
+ def process_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDataRequestRPC/process_data',
+ '/jina.JinaDataRequestRPC/process_data',
jina__pb2.DataRequestListProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDataRequestRPCStub(object):
@@ -87,10 +99,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_single_data = channel.unary_unary(
- '/jina.JinaSingleDataRequestRPC/process_single_data',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaSingleDataRequestRPCServicer(object):
@@ -100,8 +112,7 @@ class JinaSingleDataRequestRPCServicer(object):
"""
def process_single_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -109,18 +120,19 @@ def process_single_data(self, request, context):
def add_JinaSingleDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_single_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_single_data,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_single_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_single_data,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -128,21 +140,33 @@ class JinaSingleDataRequestRPC(object):
"""
@staticmethod
- def process_single_data(request,
+ def process_single_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaSingleDataRequestRPC/process_single_data',
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDocumentRequestRPCStub(object):
@@ -158,10 +182,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.stream_doc = channel.unary_stream(
- '/jina.JinaSingleDocumentRequestRPC/stream_doc',
- request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- )
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ )
class JinaSingleDocumentRequestRPCServicer(object):
@@ -171,8 +195,7 @@ class JinaSingleDocumentRequestRPCServicer(object):
"""
def stream_doc(self, request, context):
- """Used for streaming one document to the Executors
- """
+ """Used for streaming one document to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -180,18 +203,19 @@ def stream_doc(self, request, context):
def add_JinaSingleDocumentRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'stream_doc': grpc.unary_stream_rpc_method_handler(
- servicer.stream_doc,
- request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- ),
+ 'stream_doc': grpc.unary_stream_rpc_method_handler(
+ servicer.stream_doc,
+ request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDocumentRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -199,21 +223,33 @@ class JinaSingleDocumentRequestRPC(object):
"""
@staticmethod
- def stream_doc(request,
+ def stream_doc(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_stream(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_stream(request, target, '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
jina__pb2.SingleDocumentRequestProto.SerializeToString,
jina__pb2.SingleDocumentRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaRPCStub(object):
@@ -228,10 +264,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Call = channel.stream_stream(
- '/jina.JinaRPC/Call',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaRPC/Call',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaRPCServicer(object):
@@ -240,8 +276,7 @@ class JinaRPCServicer(object):
"""
def Call(self, request_iterator, context):
- """Pass in a Request and a filled Request with matches will be returned.
- """
+ """Pass in a Request and a filled Request with matches will be returned."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -249,39 +284,52 @@ def Call(self, request_iterator, context):
def add_JinaRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'Call': grpc.stream_stream_rpc_method_handler(
- servicer.Call,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'Call': grpc.stream_stream_rpc_method_handler(
+ servicer.Call,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaRPC', rpc_method_handlers)
+ 'jina.JinaRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaRPC(object):
"""*
jina streaming gRPC service.
"""
@staticmethod
- def Call(request_iterator,
+ def Call(
+ request_iterator,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.stream_stream(
+ request_iterator,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.stream_stream(request_iterator, target, '/jina.JinaRPC/Call',
+ '/jina.JinaRPC/Call',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaDiscoverEndpointsRPCStub(object):
@@ -296,10 +344,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.endpoint_discovery = channel.unary_unary(
- '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.EndpointsProto.FromString,
- )
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.EndpointsProto.FromString,
+ )
class JinaDiscoverEndpointsRPCServicer(object):
@@ -316,39 +364,52 @@ def endpoint_discovery(self, request, context):
def add_JinaDiscoverEndpointsRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
- servicer.endpoint_discovery,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.EndpointsProto.SerializeToString,
- ),
+ 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
+ servicer.endpoint_discovery,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.EndpointsProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers)
+ 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDiscoverEndpointsRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def endpoint_discovery(request,
+ def endpoint_discovery(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.EndpointsProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaGatewayDryRunRPCStub(object):
@@ -363,10 +424,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.dry_run = channel.unary_unary(
- '/jina.JinaGatewayDryRunRPC/dry_run',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.StatusProto.FromString,
- )
+ '/jina.JinaGatewayDryRunRPC/dry_run',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.StatusProto.FromString,
+ )
class JinaGatewayDryRunRPCServicer(object):
@@ -383,39 +444,52 @@ def dry_run(self, request, context):
def add_JinaGatewayDryRunRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'dry_run': grpc.unary_unary_rpc_method_handler(
- servicer.dry_run,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.StatusProto.SerializeToString,
- ),
+ 'dry_run': grpc.unary_unary_rpc_method_handler(
+ servicer.dry_run,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.StatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaGatewayDryRunRPC', rpc_method_handlers)
+ 'jina.JinaGatewayDryRunRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaGatewayDryRunRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def dry_run(request,
+ def dry_run(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaGatewayDryRunRPC/dry_run',
+ '/jina.JinaGatewayDryRunRPC/dry_run',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.StatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaInfoRPCStub(object):
@@ -430,10 +504,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self._status = channel.unary_unary(
- '/jina.JinaInfoRPC/_status',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.JinaInfoProto.FromString,
- )
+ '/jina.JinaInfoRPC/_status',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.JinaInfoProto.FromString,
+ )
class JinaInfoRPCServicer(object):
@@ -450,39 +524,52 @@ def _status(self, request, context):
def add_JinaInfoRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- '_status': grpc.unary_unary_rpc_method_handler(
- servicer._status,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
- ),
+ '_status': grpc.unary_unary_rpc_method_handler(
+ servicer._status,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaInfoRPC', rpc_method_handlers)
+ 'jina.JinaInfoRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaInfoRPC(object):
"""*
jina gRPC service to expose information about running jina version and environment.
"""
@staticmethod
- def _status(request,
+ def _status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaInfoRPC/_status',
+ '/jina.JinaInfoRPC/_status',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.JinaInfoProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotStub(object):
@@ -497,10 +584,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot = channel.unary_unary(
- '/jina.JinaExecutorSnapshot/snapshot',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshot/snapshot',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotServicer(object):
@@ -517,39 +604,52 @@ def snapshot(self, request, context):
def add_JinaExecutorSnapshotServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshot', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshot', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshot(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot(request,
+ def snapshot(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshot/snapshot',
+ '/jina.JinaExecutorSnapshot/snapshot',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotProgressStub(object):
@@ -564,10 +664,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot_status = channel.unary_unary(
- '/jina.JinaExecutorSnapshotProgress/snapshot_status',
- request_serializer=jina__pb2.SnapshotId.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ request_serializer=jina__pb2.SnapshotId.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotProgressServicer(object):
@@ -584,39 +684,52 @@ def snapshot_status(self, request, context):
def add_JinaExecutorSnapshotProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot_status': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot_status,
- request_deserializer=jina__pb2.SnapshotId.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot_status': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot_status,
+ request_deserializer=jina__pb2.SnapshotId.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshotProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot_status(request,
+ def snapshot_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
jina__pb2.SnapshotId.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreStub(object):
@@ -631,10 +744,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore = channel.unary_unary(
- '/jina.JinaExecutorRestore/restore',
- request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestore/restore',
+ request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreServicer(object):
@@ -651,39 +764,52 @@ def restore(self, request, context):
def add_JinaExecutorRestoreServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore': grpc.unary_unary_rpc_method_handler(
- servicer.restore,
- request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore': grpc.unary_unary_rpc_method_handler(
+ servicer.restore,
+ request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestore', rpc_method_handlers)
+ 'jina.JinaExecutorRestore', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestore(object):
"""*
jina gRPC service to trigger a restore at the Executor Runtime.
"""
@staticmethod
- def restore(request,
+ def restore(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestore/restore',
+ '/jina.JinaExecutorRestore/restore',
jina__pb2.RestoreSnapshotCommand.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreProgressStub(object):
@@ -698,10 +824,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore_status = channel.unary_unary(
- '/jina.JinaExecutorRestoreProgress/restore_status',
- request_serializer=jina__pb2.RestoreId.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestoreProgress/restore_status',
+ request_serializer=jina__pb2.RestoreId.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreProgressServicer(object):
@@ -718,36 +844,49 @@ def restore_status(self, request, context):
def add_JinaExecutorRestoreProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore_status': grpc.unary_unary_rpc_method_handler(
- servicer.restore_status,
- request_deserializer=jina__pb2.RestoreId.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore_status': grpc.unary_unary_rpc_method_handler(
+ servicer.restore_status,
+ request_deserializer=jina__pb2.RestoreId.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestoreProgress', rpc_method_handlers)
+ 'jina.JinaExecutorRestoreProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestoreProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def restore_status(request,
+ def restore_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestoreProgress/restore_status',
+ '/jina.JinaExecutorRestoreProgress/restore_status',
jina__pb2.RestoreId.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
diff --git a/marie/proto/docarray_v2/pb2/jina_pb2.py b/marie/proto/docarray_v2/pb2/jina_pb2.py
index 828b0ed5..b1dc9b7c 100644
--- a/marie/proto/docarray_v2/pb2/jina_pb2.py
+++ b/marie/proto/docarray_v2/pb2/jina_pb2.py
@@ -7,6 +7,7 @@
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
+
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -18,8 +19,9 @@
import docarray.proto.pb2.docarray_pb2 as docarray__pb2
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\x9a\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a]\n\x10\x44\x61taContentProto\x12&\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb4\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12$\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x12.docarray.DocProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 \x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3')
-
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\njina.proto\x12\x04jina\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x0e\x64ocarray.proto\"\x9f\x01\n\nRouteProto\x12\x10\n\x08\x65xecutor\x18\x01 \x01(\t\x12.\n\nstart_time\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12,\n\x08\x65nd_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12!\n\x06status\x18\x04 \x01(\x0b\x32\x11.jina.StatusProto\"\xc3\x01\n\rJinaInfoProto\x12+\n\x04jina\x18\x01 \x03(\x0b\x32\x1d.jina.JinaInfoProto.JinaEntry\x12+\n\x04\x65nvs\x18\x02 \x03(\x0b\x32\x1d.jina.JinaInfoProto.EnvsEntry\x1a+\n\tJinaEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a+\n\tEnvsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc6\x01\n\x0bHeaderProto\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12!\n\x06status\x18\x02 \x01(\x0b\x32\x11.jina.StatusProto\x12\x1a\n\rexec_endpoint\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x1c\n\x0ftarget_executor\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x07timeout\x18\x05 \x01(\rH\x02\x88\x01\x01\x42\x10\n\x0e_exec_endpointB\x12\n\x10_target_executorB\n\n\x08_timeout\"f\n\x0e\x45ndpointsProto\x12\x11\n\tendpoints\x18\x01 \x03(\t\x12\x17\n\x0fwrite_endpoints\x18\x02 \x03(\t\x12(\n\x07schemas\x18\x03 \x01(\x0b\x32\x17.google.protobuf.Struct\"\xf9\x01\n\x0bStatusProto\x12*\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x1c.jina.StatusProto.StatusCode\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x33\n\texception\x18\x03 \x01(\x0b\x32 .jina.StatusProto.ExceptionProto\x1aN\n\x0e\x45xceptionProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\x12\x0e\n\x06stacks\x18\x03 \x03(\t\x12\x10\n\x08\x65xecutor\x18\x04 \x01(\t\"$\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\"^\n\rRelatedEntity\x12\n\n\x02id\x18\x01 \x01(\t\x12\x0f\n\x07\x61\x64\x64ress\x18\x02 \x01(\t\x12\x0c\n\x04port\x18\x03 \x01(\r\x12\x15\n\x08shard_id\x18\x04 \x01(\rH\x00\x88\x01\x01\x42\x0b\n\t_shard_id\"\x9a\x02\n\x10\x44\x61taRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12\x35\n\x04\x64\x61ta\x18\x04 \x01(\x0b\x32\'.jina.DataRequestProto.DataContentProto\x1a]\n\x10\x44\x61taContentProto\x12&\n\x04\x64ocs\x18\x01 \x01(\x0b\x32\x16.docarray.DocListProtoH\x00\x12\x14\n\ndocs_bytes\x18\x02 \x01(\x0cH\x00\x42\x0b\n\tdocuments\"\xb4\x01\n\x1aSingleDocumentRequestProto\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\x12$\n\x08\x64ocument\x18\x04 \x01(\x0b\x32\x12.docarray.DocProto\"\x8a\x01\n\x16\x44\x61taRequestProtoWoData\x12!\n\x06header\x18\x01 \x01(\x0b\x32\x11.jina.HeaderProto\x12+\n\nparameters\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12 \n\x06routes\x18\x03 \x03(\x0b\x32\x10.jina.RouteProto\"@\n\x14\x44\x61taRequestListProto\x12(\n\x08requests\x18\x01 \x03(\x0b\x32\x16.jina.DataRequestProto\"\x1b\n\nSnapshotId\x12\r\n\x05value\x18\x01 \x01(\t\"\x1a\n\tRestoreId\x12\r\n\x05value\x18\x01 \x01(\t\"\xef\x01\n\x13SnapshotStatusProto\x12\x1c\n\x02id\x18\x01 \x01(\x0b\x32\x10.jina.SnapshotId\x12\x30\n\x06status\x18\x02 \x01(\x0e\x32 .jina.SnapshotStatusProto.Status\x12\x15\n\rsnapshot_file\x18\x03 \x01(\t\"q\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\n\n\x06QUEUED\x10\x01\x12\r\n\tSCHEDULED\x10\x02\x12\x0b\n\x07RUNNING\x10\x03\x12\r\n\tSUCCEEDED\x10\x04\x12\n\n\x06\x46\x41ILED\x10\x05\x12\r\n\tNOT_FOUND\x10\x06\"\xca\x01\n\x1aRestoreSnapshotStatusProto\x12\x1b\n\x02id\x18\x01 \x01(\x0b\x32\x0f.jina.RestoreId\x12\x37\n\x06status\x18\x02 \x01(\x0e\x32\'.jina.RestoreSnapshotStatusProto.Status\"V\n\x06Status\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0b\n\x07RUNNING\x10\x01\x12\r\n\tSUCCEEDED\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\r\n\tNOT_FOUND\x10\x06\"/\n\x16RestoreSnapshotCommand\x12\x15\n\rsnapshot_file\x18\x01 \x01(\t2Z\n\x12JinaDataRequestRPC\x12\x44\n\x0cprocess_data\x12\x1a.jina.DataRequestListProto\x1a\x16.jina.DataRequestProto\"\x00\x32\x63\n\x18JinaSingleDataRequestRPC\x12G\n\x13process_single_data\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00\x32t\n\x1cJinaSingleDocumentRequestRPC\x12T\n\nstream_doc\x12 .jina.SingleDocumentRequestProto\x1a .jina.SingleDocumentRequestProto\"\x00\x30\x01\x32G\n\x07JinaRPC\x12<\n\x04\x43\x61ll\x12\x16.jina.DataRequestProto\x1a\x16.jina.DataRequestProto\"\x00(\x01\x30\x01\x32`\n\x18JinaDiscoverEndpointsRPC\x12\x44\n\x12\x65ndpoint_discovery\x12\x16.google.protobuf.Empty\x1a\x14.jina.EndpointsProto\"\x00\x32N\n\x14JinaGatewayDryRunRPC\x12\x36\n\x07\x64ry_run\x12\x16.google.protobuf.Empty\x1a\x11.jina.StatusProto\"\x00\x32G\n\x0bJinaInfoRPC\x12\x38\n\x07_status\x12\x16.google.protobuf.Empty\x1a\x13.jina.JinaInfoProto\"\x00\x32W\n\x14JinaExecutorSnapshot\x12?\n\x08snapshot\x12\x16.google.protobuf.Empty\x1a\x19.jina.SnapshotStatusProto\"\x00\x32`\n\x1cJinaExecutorSnapshotProgress\x12@\n\x0fsnapshot_status\x12\x10.jina.SnapshotId\x1a\x19.jina.SnapshotStatusProto\"\x00\x32\x62\n\x13JinaExecutorRestore\x12K\n\x07restore\x12\x1c.jina.RestoreSnapshotCommand\x1a .jina.RestoreSnapshotStatusProto\"\x00\x32\x64\n\x1bJinaExecutorRestoreProgress\x12\x45\n\x0erestore_status\x12\x0f.jina.RestoreId\x1a .jina.RestoreSnapshotStatusProto\"\x00\x62\x06proto3'
+)
_ROUTEPROTO = DESCRIPTOR.message_types_by_name['RouteProto']
@@ -32,237 +34,319 @@
_STATUSPROTO_EXCEPTIONPROTO = _STATUSPROTO.nested_types_by_name['ExceptionProto']
_RELATEDENTITY = DESCRIPTOR.message_types_by_name['RelatedEntity']
_DATAREQUESTPROTO = DESCRIPTOR.message_types_by_name['DataRequestProto']
-_DATAREQUESTPROTO_DATACONTENTPROTO = _DATAREQUESTPROTO.nested_types_by_name['DataContentProto']
-_SINGLEDOCUMENTREQUESTPROTO = DESCRIPTOR.message_types_by_name['SingleDocumentRequestProto']
+_DATAREQUESTPROTO_DATACONTENTPROTO = _DATAREQUESTPROTO.nested_types_by_name[
+ 'DataContentProto'
+]
+_SINGLEDOCUMENTREQUESTPROTO = DESCRIPTOR.message_types_by_name[
+ 'SingleDocumentRequestProto'
+]
_DATAREQUESTPROTOWODATA = DESCRIPTOR.message_types_by_name['DataRequestProtoWoData']
_DATAREQUESTLISTPROTO = DESCRIPTOR.message_types_by_name['DataRequestListProto']
_SNAPSHOTID = DESCRIPTOR.message_types_by_name['SnapshotId']
_RESTOREID = DESCRIPTOR.message_types_by_name['RestoreId']
_SNAPSHOTSTATUSPROTO = DESCRIPTOR.message_types_by_name['SnapshotStatusProto']
-_RESTORESNAPSHOTSTATUSPROTO = DESCRIPTOR.message_types_by_name['RestoreSnapshotStatusProto']
+_RESTORESNAPSHOTSTATUSPROTO = DESCRIPTOR.message_types_by_name[
+ 'RestoreSnapshotStatusProto'
+]
_RESTORESNAPSHOTCOMMAND = DESCRIPTOR.message_types_by_name['RestoreSnapshotCommand']
_STATUSPROTO_STATUSCODE = _STATUSPROTO.enum_types_by_name['StatusCode']
_SNAPSHOTSTATUSPROTO_STATUS = _SNAPSHOTSTATUSPROTO.enum_types_by_name['Status']
-_RESTORESNAPSHOTSTATUSPROTO_STATUS = _RESTORESNAPSHOTSTATUSPROTO.enum_types_by_name['Status']
-RouteProto = _reflection.GeneratedProtocolMessageType('RouteProto', (_message.Message,), {
- 'DESCRIPTOR' : _ROUTEPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.RouteProto)
- })
+_RESTORESNAPSHOTSTATUSPROTO_STATUS = _RESTORESNAPSHOTSTATUSPROTO.enum_types_by_name[
+ 'Status'
+]
+RouteProto = _reflection.GeneratedProtocolMessageType(
+ 'RouteProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _ROUTEPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.RouteProto)
+ },
+)
_sym_db.RegisterMessage(RouteProto)
-JinaInfoProto = _reflection.GeneratedProtocolMessageType('JinaInfoProto', (_message.Message,), {
-
- 'JinaEntry' : _reflection.GeneratedProtocolMessageType('JinaEntry', (_message.Message,), {
- 'DESCRIPTOR' : _JINAINFOPROTO_JINAENTRY,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.JinaInfoProto.JinaEntry)
- })
- ,
-
- 'EnvsEntry' : _reflection.GeneratedProtocolMessageType('EnvsEntry', (_message.Message,), {
- 'DESCRIPTOR' : _JINAINFOPROTO_ENVSENTRY,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.JinaInfoProto.EnvsEntry)
- })
- ,
- 'DESCRIPTOR' : _JINAINFOPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.JinaInfoProto)
- })
+JinaInfoProto = _reflection.GeneratedProtocolMessageType(
+ 'JinaInfoProto',
+ (_message.Message,),
+ {
+ 'JinaEntry': _reflection.GeneratedProtocolMessageType(
+ 'JinaEntry',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _JINAINFOPROTO_JINAENTRY,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.JinaInfoProto.JinaEntry)
+ },
+ ),
+ 'EnvsEntry': _reflection.GeneratedProtocolMessageType(
+ 'EnvsEntry',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _JINAINFOPROTO_ENVSENTRY,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.JinaInfoProto.EnvsEntry)
+ },
+ ),
+ 'DESCRIPTOR': _JINAINFOPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.JinaInfoProto)
+ },
+)
_sym_db.RegisterMessage(JinaInfoProto)
_sym_db.RegisterMessage(JinaInfoProto.JinaEntry)
_sym_db.RegisterMessage(JinaInfoProto.EnvsEntry)
-HeaderProto = _reflection.GeneratedProtocolMessageType('HeaderProto', (_message.Message,), {
- 'DESCRIPTOR' : _HEADERPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.HeaderProto)
- })
+HeaderProto = _reflection.GeneratedProtocolMessageType(
+ 'HeaderProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _HEADERPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.HeaderProto)
+ },
+)
_sym_db.RegisterMessage(HeaderProto)
-EndpointsProto = _reflection.GeneratedProtocolMessageType('EndpointsProto', (_message.Message,), {
- 'DESCRIPTOR' : _ENDPOINTSPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.EndpointsProto)
- })
+EndpointsProto = _reflection.GeneratedProtocolMessageType(
+ 'EndpointsProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _ENDPOINTSPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.EndpointsProto)
+ },
+)
_sym_db.RegisterMessage(EndpointsProto)
-StatusProto = _reflection.GeneratedProtocolMessageType('StatusProto', (_message.Message,), {
-
- 'ExceptionProto' : _reflection.GeneratedProtocolMessageType('ExceptionProto', (_message.Message,), {
- 'DESCRIPTOR' : _STATUSPROTO_EXCEPTIONPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.StatusProto.ExceptionProto)
- })
- ,
- 'DESCRIPTOR' : _STATUSPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.StatusProto)
- })
+StatusProto = _reflection.GeneratedProtocolMessageType(
+ 'StatusProto',
+ (_message.Message,),
+ {
+ 'ExceptionProto': _reflection.GeneratedProtocolMessageType(
+ 'ExceptionProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _STATUSPROTO_EXCEPTIONPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.StatusProto.ExceptionProto)
+ },
+ ),
+ 'DESCRIPTOR': _STATUSPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.StatusProto)
+ },
+)
_sym_db.RegisterMessage(StatusProto)
_sym_db.RegisterMessage(StatusProto.ExceptionProto)
-RelatedEntity = _reflection.GeneratedProtocolMessageType('RelatedEntity', (_message.Message,), {
- 'DESCRIPTOR' : _RELATEDENTITY,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.RelatedEntity)
- })
+RelatedEntity = _reflection.GeneratedProtocolMessageType(
+ 'RelatedEntity',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _RELATEDENTITY,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.RelatedEntity)
+ },
+)
_sym_db.RegisterMessage(RelatedEntity)
-DataRequestProto = _reflection.GeneratedProtocolMessageType('DataRequestProto', (_message.Message,), {
-
- 'DataContentProto' : _reflection.GeneratedProtocolMessageType('DataContentProto', (_message.Message,), {
- 'DESCRIPTOR' : _DATAREQUESTPROTO_DATACONTENTPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.DataRequestProto.DataContentProto)
- })
- ,
- 'DESCRIPTOR' : _DATAREQUESTPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.DataRequestProto)
- })
+DataRequestProto = _reflection.GeneratedProtocolMessageType(
+ 'DataRequestProto',
+ (_message.Message,),
+ {
+ 'DataContentProto': _reflection.GeneratedProtocolMessageType(
+ 'DataContentProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _DATAREQUESTPROTO_DATACONTENTPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.DataRequestProto.DataContentProto)
+ },
+ ),
+ 'DESCRIPTOR': _DATAREQUESTPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.DataRequestProto)
+ },
+)
_sym_db.RegisterMessage(DataRequestProto)
_sym_db.RegisterMessage(DataRequestProto.DataContentProto)
-SingleDocumentRequestProto = _reflection.GeneratedProtocolMessageType('SingleDocumentRequestProto', (_message.Message,), {
- 'DESCRIPTOR' : _SINGLEDOCUMENTREQUESTPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.SingleDocumentRequestProto)
- })
+SingleDocumentRequestProto = _reflection.GeneratedProtocolMessageType(
+ 'SingleDocumentRequestProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _SINGLEDOCUMENTREQUESTPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.SingleDocumentRequestProto)
+ },
+)
_sym_db.RegisterMessage(SingleDocumentRequestProto)
-DataRequestProtoWoData = _reflection.GeneratedProtocolMessageType('DataRequestProtoWoData', (_message.Message,), {
- 'DESCRIPTOR' : _DATAREQUESTPROTOWODATA,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.DataRequestProtoWoData)
- })
+DataRequestProtoWoData = _reflection.GeneratedProtocolMessageType(
+ 'DataRequestProtoWoData',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _DATAREQUESTPROTOWODATA,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.DataRequestProtoWoData)
+ },
+)
_sym_db.RegisterMessage(DataRequestProtoWoData)
-DataRequestListProto = _reflection.GeneratedProtocolMessageType('DataRequestListProto', (_message.Message,), {
- 'DESCRIPTOR' : _DATAREQUESTLISTPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.DataRequestListProto)
- })
+DataRequestListProto = _reflection.GeneratedProtocolMessageType(
+ 'DataRequestListProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _DATAREQUESTLISTPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.DataRequestListProto)
+ },
+)
_sym_db.RegisterMessage(DataRequestListProto)
-SnapshotId = _reflection.GeneratedProtocolMessageType('SnapshotId', (_message.Message,), {
- 'DESCRIPTOR' : _SNAPSHOTID,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.SnapshotId)
- })
+SnapshotId = _reflection.GeneratedProtocolMessageType(
+ 'SnapshotId',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _SNAPSHOTID,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.SnapshotId)
+ },
+)
_sym_db.RegisterMessage(SnapshotId)
-RestoreId = _reflection.GeneratedProtocolMessageType('RestoreId', (_message.Message,), {
- 'DESCRIPTOR' : _RESTOREID,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.RestoreId)
- })
+RestoreId = _reflection.GeneratedProtocolMessageType(
+ 'RestoreId',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _RESTOREID,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.RestoreId)
+ },
+)
_sym_db.RegisterMessage(RestoreId)
-SnapshotStatusProto = _reflection.GeneratedProtocolMessageType('SnapshotStatusProto', (_message.Message,), {
- 'DESCRIPTOR' : _SNAPSHOTSTATUSPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.SnapshotStatusProto)
- })
+SnapshotStatusProto = _reflection.GeneratedProtocolMessageType(
+ 'SnapshotStatusProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _SNAPSHOTSTATUSPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.SnapshotStatusProto)
+ },
+)
_sym_db.RegisterMessage(SnapshotStatusProto)
-RestoreSnapshotStatusProto = _reflection.GeneratedProtocolMessageType('RestoreSnapshotStatusProto', (_message.Message,), {
- 'DESCRIPTOR' : _RESTORESNAPSHOTSTATUSPROTO,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.RestoreSnapshotStatusProto)
- })
+RestoreSnapshotStatusProto = _reflection.GeneratedProtocolMessageType(
+ 'RestoreSnapshotStatusProto',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _RESTORESNAPSHOTSTATUSPROTO,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.RestoreSnapshotStatusProto)
+ },
+)
_sym_db.RegisterMessage(RestoreSnapshotStatusProto)
-RestoreSnapshotCommand = _reflection.GeneratedProtocolMessageType('RestoreSnapshotCommand', (_message.Message,), {
- 'DESCRIPTOR' : _RESTORESNAPSHOTCOMMAND,
- '__module__' : 'jina_pb2'
- # @@protoc_insertion_point(class_scope:jina.RestoreSnapshotCommand)
- })
+RestoreSnapshotCommand = _reflection.GeneratedProtocolMessageType(
+ 'RestoreSnapshotCommand',
+ (_message.Message,),
+ {
+ 'DESCRIPTOR': _RESTORESNAPSHOTCOMMAND,
+ '__module__': 'jina_pb2',
+ # @@protoc_insertion_point(class_scope:jina.RestoreSnapshotCommand)
+ },
+)
_sym_db.RegisterMessage(RestoreSnapshotCommand)
_JINADATAREQUESTRPC = DESCRIPTOR.services_by_name['JinaDataRequestRPC']
_JINASINGLEDATAREQUESTRPC = DESCRIPTOR.services_by_name['JinaSingleDataRequestRPC']
-_JINASINGLEDOCUMENTREQUESTRPC = DESCRIPTOR.services_by_name['JinaSingleDocumentRequestRPC']
+_JINASINGLEDOCUMENTREQUESTRPC = DESCRIPTOR.services_by_name[
+ 'JinaSingleDocumentRequestRPC'
+]
_JINARPC = DESCRIPTOR.services_by_name['JinaRPC']
_JINADISCOVERENDPOINTSRPC = DESCRIPTOR.services_by_name['JinaDiscoverEndpointsRPC']
_JINAGATEWAYDRYRUNRPC = DESCRIPTOR.services_by_name['JinaGatewayDryRunRPC']
_JINAINFORPC = DESCRIPTOR.services_by_name['JinaInfoRPC']
_JINAEXECUTORSNAPSHOT = DESCRIPTOR.services_by_name['JinaExecutorSnapshot']
-_JINAEXECUTORSNAPSHOTPROGRESS = DESCRIPTOR.services_by_name['JinaExecutorSnapshotProgress']
+_JINAEXECUTORSNAPSHOTPROGRESS = DESCRIPTOR.services_by_name[
+ 'JinaExecutorSnapshotProgress'
+]
_JINAEXECUTORRESTORE = DESCRIPTOR.services_by_name['JinaExecutorRestore']
-_JINAEXECUTORRESTOREPROGRESS = DESCRIPTOR.services_by_name['JinaExecutorRestoreProgress']
+_JINAEXECUTORRESTOREPROGRESS = DESCRIPTOR.services_by_name[
+ 'JinaExecutorRestoreProgress'
+]
if _descriptor._USE_C_DESCRIPTORS == False:
- DESCRIPTOR._options = None
- _JINAINFOPROTO_JINAENTRY._options = None
- _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
- _JINAINFOPROTO_ENVSENTRY._options = None
- _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
- _ROUTEPROTO._serialized_start=129
- _ROUTEPROTO._serialized_end=288
- _JINAINFOPROTO._serialized_start=291
- _JINAINFOPROTO._serialized_end=486
- _JINAINFOPROTO_JINAENTRY._serialized_start=398
- _JINAINFOPROTO_JINAENTRY._serialized_end=441
- _JINAINFOPROTO_ENVSENTRY._serialized_start=443
- _JINAINFOPROTO_ENVSENTRY._serialized_end=486
- _HEADERPROTO._serialized_start=489
- _HEADERPROTO._serialized_end=687
- _ENDPOINTSPROTO._serialized_start=689
- _ENDPOINTSPROTO._serialized_end=791
- _STATUSPROTO._serialized_start=794
- _STATUSPROTO._serialized_end=1043
- _STATUSPROTO_EXCEPTIONPROTO._serialized_start=927
- _STATUSPROTO_EXCEPTIONPROTO._serialized_end=1005
- _STATUSPROTO_STATUSCODE._serialized_start=1007
- _STATUSPROTO_STATUSCODE._serialized_end=1043
- _RELATEDENTITY._serialized_start=1045
- _RELATEDENTITY._serialized_end=1139
- _DATAREQUESTPROTO._serialized_start=1142
- _DATAREQUESTPROTO._serialized_end=1424
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start=1331
- _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end=1424
- _SINGLEDOCUMENTREQUESTPROTO._serialized_start=1427
- _SINGLEDOCUMENTREQUESTPROTO._serialized_end=1607
- _DATAREQUESTPROTOWODATA._serialized_start=1610
- _DATAREQUESTPROTOWODATA._serialized_end=1748
- _DATAREQUESTLISTPROTO._serialized_start=1750
- _DATAREQUESTLISTPROTO._serialized_end=1814
- _SNAPSHOTID._serialized_start=1816
- _SNAPSHOTID._serialized_end=1843
- _RESTOREID._serialized_start=1845
- _RESTOREID._serialized_end=1871
- _SNAPSHOTSTATUSPROTO._serialized_start=1874
- _SNAPSHOTSTATUSPROTO._serialized_end=2113
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_start=2000
- _SNAPSHOTSTATUSPROTO_STATUS._serialized_end=2113
- _RESTORESNAPSHOTSTATUSPROTO._serialized_start=2116
- _RESTORESNAPSHOTSTATUSPROTO._serialized_end=2318
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start=2232
- _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end=2318
- _RESTORESNAPSHOTCOMMAND._serialized_start=2320
- _RESTORESNAPSHOTCOMMAND._serialized_end=2367
- _JINADATAREQUESTRPC._serialized_start=2369
- _JINADATAREQUESTRPC._serialized_end=2459
- _JINASINGLEDATAREQUESTRPC._serialized_start=2461
- _JINASINGLEDATAREQUESTRPC._serialized_end=2560
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_start=2562
- _JINASINGLEDOCUMENTREQUESTRPC._serialized_end=2678
- _JINARPC._serialized_start=2680
- _JINARPC._serialized_end=2751
- _JINADISCOVERENDPOINTSRPC._serialized_start=2753
- _JINADISCOVERENDPOINTSRPC._serialized_end=2849
- _JINAGATEWAYDRYRUNRPC._serialized_start=2851
- _JINAGATEWAYDRYRUNRPC._serialized_end=2929
- _JINAINFORPC._serialized_start=2931
- _JINAINFORPC._serialized_end=3002
- _JINAEXECUTORSNAPSHOT._serialized_start=3004
- _JINAEXECUTORSNAPSHOT._serialized_end=3091
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start=3093
- _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end=3189
- _JINAEXECUTORRESTORE._serialized_start=3191
- _JINAEXECUTORRESTORE._serialized_end=3289
- _JINAEXECUTORRESTOREPROGRESS._serialized_start=3291
- _JINAEXECUTORRESTOREPROGRESS._serialized_end=3391
+ DESCRIPTOR._options = None
+ _JINAINFOPROTO_JINAENTRY._options = None
+ _JINAINFOPROTO_JINAENTRY._serialized_options = b'8\001'
+ _JINAINFOPROTO_ENVSENTRY._options = None
+ _JINAINFOPROTO_ENVSENTRY._serialized_options = b'8\001'
+ _ROUTEPROTO._serialized_start = 129
+ _ROUTEPROTO._serialized_end = 288
+ _JINAINFOPROTO._serialized_start = 291
+ _JINAINFOPROTO._serialized_end = 486
+ _JINAINFOPROTO_JINAENTRY._serialized_start = 398
+ _JINAINFOPROTO_JINAENTRY._serialized_end = 441
+ _JINAINFOPROTO_ENVSENTRY._serialized_start = 443
+ _JINAINFOPROTO_ENVSENTRY._serialized_end = 486
+ _HEADERPROTO._serialized_start = 489
+ _HEADERPROTO._serialized_end = 687
+ _ENDPOINTSPROTO._serialized_start = 689
+ _ENDPOINTSPROTO._serialized_end = 791
+ _STATUSPROTO._serialized_start = 794
+ _STATUSPROTO._serialized_end = 1043
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_start = 927
+ _STATUSPROTO_EXCEPTIONPROTO._serialized_end = 1005
+ _STATUSPROTO_STATUSCODE._serialized_start = 1007
+ _STATUSPROTO_STATUSCODE._serialized_end = 1043
+ _RELATEDENTITY._serialized_start = 1045
+ _RELATEDENTITY._serialized_end = 1139
+ _DATAREQUESTPROTO._serialized_start = 1142
+ _DATAREQUESTPROTO._serialized_end = 1424
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_start = 1331
+ _DATAREQUESTPROTO_DATACONTENTPROTO._serialized_end = 1424
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_start = 1427
+ _SINGLEDOCUMENTREQUESTPROTO._serialized_end = 1607
+ _DATAREQUESTPROTOWODATA._serialized_start = 1610
+ _DATAREQUESTPROTOWODATA._serialized_end = 1748
+ _DATAREQUESTLISTPROTO._serialized_start = 1750
+ _DATAREQUESTLISTPROTO._serialized_end = 1814
+ _SNAPSHOTID._serialized_start = 1816
+ _SNAPSHOTID._serialized_end = 1843
+ _RESTOREID._serialized_start = 1845
+ _RESTOREID._serialized_end = 1871
+ _SNAPSHOTSTATUSPROTO._serialized_start = 1874
+ _SNAPSHOTSTATUSPROTO._serialized_end = 2113
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2000
+ _SNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2113
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_start = 2116
+ _RESTORESNAPSHOTSTATUSPROTO._serialized_end = 2318
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_start = 2232
+ _RESTORESNAPSHOTSTATUSPROTO_STATUS._serialized_end = 2318
+ _RESTORESNAPSHOTCOMMAND._serialized_start = 2320
+ _RESTORESNAPSHOTCOMMAND._serialized_end = 2367
+ _JINADATAREQUESTRPC._serialized_start = 2369
+ _JINADATAREQUESTRPC._serialized_end = 2459
+ _JINASINGLEDATAREQUESTRPC._serialized_start = 2461
+ _JINASINGLEDATAREQUESTRPC._serialized_end = 2560
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_start = 2562
+ _JINASINGLEDOCUMENTREQUESTRPC._serialized_end = 2678
+ _JINARPC._serialized_start = 2680
+ _JINARPC._serialized_end = 2751
+ _JINADISCOVERENDPOINTSRPC._serialized_start = 2753
+ _JINADISCOVERENDPOINTSRPC._serialized_end = 2849
+ _JINAGATEWAYDRYRUNRPC._serialized_start = 2851
+ _JINAGATEWAYDRYRUNRPC._serialized_end = 2929
+ _JINAINFORPC._serialized_start = 2931
+ _JINAINFORPC._serialized_end = 3002
+ _JINAEXECUTORSNAPSHOT._serialized_start = 3004
+ _JINAEXECUTORSNAPSHOT._serialized_end = 3091
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_start = 3093
+ _JINAEXECUTORSNAPSHOTPROGRESS._serialized_end = 3189
+ _JINAEXECUTORRESTORE._serialized_start = 3191
+ _JINAEXECUTORRESTORE._serialized_end = 3289
+ _JINAEXECUTORRESTOREPROGRESS._serialized_start = 3291
+ _JINAEXECUTORRESTOREPROGRESS._serialized_end = 3391
# @@protoc_insertion_point(module_scope)
diff --git a/marie/proto/docarray_v2/pb2/jina_pb2_grpc.py b/marie/proto/docarray_v2/pb2/jina_pb2_grpc.py
index f52ce19e..f571beae 100644
--- a/marie/proto/docarray_v2/pb2/jina_pb2_grpc.py
+++ b/marie/proto/docarray_v2/pb2/jina_pb2_grpc.py
@@ -18,10 +18,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_data = channel.unary_unary(
- '/jina.JinaDataRequestRPC/process_data',
- request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaDataRequestRPC/process_data',
+ request_serializer=jina__pb2.DataRequestListProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaDataRequestRPCServicer(object):
@@ -30,8 +30,7 @@ class JinaDataRequestRPCServicer(object):
"""
def process_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -39,39 +38,52 @@ def process_data(self, request, context):
def add_JinaDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_data,
- request_deserializer=jina__pb2.DataRequestListProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_data,
+ request_deserializer=jina__pb2.DataRequestListProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
"""
@staticmethod
- def process_data(request,
+ def process_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDataRequestRPC/process_data',
+ '/jina.JinaDataRequestRPC/process_data',
jina__pb2.DataRequestListProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDataRequestRPCStub(object):
@@ -87,10 +99,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.process_single_data = channel.unary_unary(
- '/jina.JinaSingleDataRequestRPC/process_single_data',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaSingleDataRequestRPCServicer(object):
@@ -100,8 +112,7 @@ class JinaSingleDataRequestRPCServicer(object):
"""
def process_single_data(self, request, context):
- """Used for passing DataRequests to the Executors
- """
+ """Used for passing DataRequests to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -109,18 +120,19 @@ def process_single_data(self, request, context):
def add_JinaSingleDataRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'process_single_data': grpc.unary_unary_rpc_method_handler(
- servicer.process_single_data,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'process_single_data': grpc.unary_unary_rpc_method_handler(
+ servicer.process_single_data,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDataRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDataRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDataRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -128,21 +140,33 @@ class JinaSingleDataRequestRPC(object):
"""
@staticmethod
- def process_single_data(request,
+ def process_single_data(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaSingleDataRequestRPC/process_single_data',
+ '/jina.JinaSingleDataRequestRPC/process_single_data',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaSingleDocumentRequestRPCStub(object):
@@ -158,10 +182,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.stream_doc = channel.unary_stream(
- '/jina.JinaSingleDocumentRequestRPC/stream_doc',
- request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- )
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ request_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ )
class JinaSingleDocumentRequestRPCServicer(object):
@@ -171,8 +195,7 @@ class JinaSingleDocumentRequestRPCServicer(object):
"""
def stream_doc(self, request, context):
- """Used for streaming one document to the Executors
- """
+ """Used for streaming one document to the Executors"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -180,18 +203,19 @@ def stream_doc(self, request, context):
def add_JinaSingleDocumentRequestRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'stream_doc': grpc.unary_stream_rpc_method_handler(
- servicer.stream_doc,
- request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
- response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
- ),
+ 'stream_doc': grpc.unary_stream_rpc_method_handler(
+ servicer.stream_doc,
+ request_deserializer=jina__pb2.SingleDocumentRequestProto.FromString,
+ response_serializer=jina__pb2.SingleDocumentRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers)
+ 'jina.JinaSingleDocumentRequestRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaSingleDocumentRequestRPC(object):
"""*
jina gRPC service for DataRequests.
@@ -199,21 +223,33 @@ class JinaSingleDocumentRequestRPC(object):
"""
@staticmethod
- def stream_doc(request,
+ def stream_doc(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_stream(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_stream(request, target, '/jina.JinaSingleDocumentRequestRPC/stream_doc',
+ '/jina.JinaSingleDocumentRequestRPC/stream_doc',
jina__pb2.SingleDocumentRequestProto.SerializeToString,
jina__pb2.SingleDocumentRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaRPCStub(object):
@@ -228,10 +264,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Call = channel.stream_stream(
- '/jina.JinaRPC/Call',
- request_serializer=jina__pb2.DataRequestProto.SerializeToString,
- response_deserializer=jina__pb2.DataRequestProto.FromString,
- )
+ '/jina.JinaRPC/Call',
+ request_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ response_deserializer=jina__pb2.DataRequestProto.FromString,
+ )
class JinaRPCServicer(object):
@@ -240,8 +276,7 @@ class JinaRPCServicer(object):
"""
def Call(self, request_iterator, context):
- """Pass in a Request and a filled Request with matches will be returned.
- """
+ """Pass in a Request and a filled Request with matches will be returned."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
@@ -249,39 +284,52 @@ def Call(self, request_iterator, context):
def add_JinaRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'Call': grpc.stream_stream_rpc_method_handler(
- servicer.Call,
- request_deserializer=jina__pb2.DataRequestProto.FromString,
- response_serializer=jina__pb2.DataRequestProto.SerializeToString,
- ),
+ 'Call': grpc.stream_stream_rpc_method_handler(
+ servicer.Call,
+ request_deserializer=jina__pb2.DataRequestProto.FromString,
+ response_serializer=jina__pb2.DataRequestProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaRPC', rpc_method_handlers)
+ 'jina.JinaRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaRPC(object):
"""*
jina streaming gRPC service.
"""
@staticmethod
- def Call(request_iterator,
+ def Call(
+ request_iterator,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.stream_stream(
+ request_iterator,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.stream_stream(request_iterator, target, '/jina.JinaRPC/Call',
+ '/jina.JinaRPC/Call',
jina__pb2.DataRequestProto.SerializeToString,
jina__pb2.DataRequestProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaDiscoverEndpointsRPCStub(object):
@@ -296,10 +344,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.endpoint_discovery = channel.unary_unary(
- '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.EndpointsProto.FromString,
- )
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.EndpointsProto.FromString,
+ )
class JinaDiscoverEndpointsRPCServicer(object):
@@ -316,39 +364,52 @@ def endpoint_discovery(self, request, context):
def add_JinaDiscoverEndpointsRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
- servicer.endpoint_discovery,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.EndpointsProto.SerializeToString,
- ),
+ 'endpoint_discovery': grpc.unary_unary_rpc_method_handler(
+ servicer.endpoint_discovery,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.EndpointsProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers)
+ 'jina.JinaDiscoverEndpointsRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaDiscoverEndpointsRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def endpoint_discovery(request,
+ def endpoint_discovery(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
+ '/jina.JinaDiscoverEndpointsRPC/endpoint_discovery',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.EndpointsProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaGatewayDryRunRPCStub(object):
@@ -363,10 +424,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.dry_run = channel.unary_unary(
- '/jina.JinaGatewayDryRunRPC/dry_run',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.StatusProto.FromString,
- )
+ '/jina.JinaGatewayDryRunRPC/dry_run',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.StatusProto.FromString,
+ )
class JinaGatewayDryRunRPCServicer(object):
@@ -383,39 +444,52 @@ def dry_run(self, request, context):
def add_JinaGatewayDryRunRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'dry_run': grpc.unary_unary_rpc_method_handler(
- servicer.dry_run,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.StatusProto.SerializeToString,
- ),
+ 'dry_run': grpc.unary_unary_rpc_method_handler(
+ servicer.dry_run,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.StatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaGatewayDryRunRPC', rpc_method_handlers)
+ 'jina.JinaGatewayDryRunRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaGatewayDryRunRPC(object):
"""*
jina gRPC service to expose Endpoints from Executors.
"""
@staticmethod
- def dry_run(request,
+ def dry_run(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaGatewayDryRunRPC/dry_run',
+ '/jina.JinaGatewayDryRunRPC/dry_run',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.StatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaInfoRPCStub(object):
@@ -430,10 +504,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self._status = channel.unary_unary(
- '/jina.JinaInfoRPC/_status',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.JinaInfoProto.FromString,
- )
+ '/jina.JinaInfoRPC/_status',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.JinaInfoProto.FromString,
+ )
class JinaInfoRPCServicer(object):
@@ -450,39 +524,52 @@ def _status(self, request, context):
def add_JinaInfoRPCServicer_to_server(servicer, server):
rpc_method_handlers = {
- '_status': grpc.unary_unary_rpc_method_handler(
- servicer._status,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
- ),
+ '_status': grpc.unary_unary_rpc_method_handler(
+ servicer._status,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.JinaInfoProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaInfoRPC', rpc_method_handlers)
+ 'jina.JinaInfoRPC', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaInfoRPC(object):
"""*
jina gRPC service to expose information about running jina version and environment.
"""
@staticmethod
- def _status(request,
+ def _status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaInfoRPC/_status',
+ '/jina.JinaInfoRPC/_status',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.JinaInfoProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotStub(object):
@@ -497,10 +584,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot = channel.unary_unary(
- '/jina.JinaExecutorSnapshot/snapshot',
- request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshot/snapshot',
+ request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotServicer(object):
@@ -517,39 +604,52 @@ def snapshot(self, request, context):
def add_JinaExecutorSnapshotServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot,
- request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot,
+ request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshot', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshot', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshot(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot(request,
+ def snapshot(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshot/snapshot',
+ '/jina.JinaExecutorSnapshot/snapshot',
google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorSnapshotProgressStub(object):
@@ -564,10 +664,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.snapshot_status = channel.unary_unary(
- '/jina.JinaExecutorSnapshotProgress/snapshot_status',
- request_serializer=jina__pb2.SnapshotId.SerializeToString,
- response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ request_serializer=jina__pb2.SnapshotId.SerializeToString,
+ response_deserializer=jina__pb2.SnapshotStatusProto.FromString,
+ )
class JinaExecutorSnapshotProgressServicer(object):
@@ -584,39 +684,52 @@ def snapshot_status(self, request, context):
def add_JinaExecutorSnapshotProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'snapshot_status': grpc.unary_unary_rpc_method_handler(
- servicer.snapshot_status,
- request_deserializer=jina__pb2.SnapshotId.FromString,
- response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
- ),
+ 'snapshot_status': grpc.unary_unary_rpc_method_handler(
+ servicer.snapshot_status,
+ request_deserializer=jina__pb2.SnapshotId.FromString,
+ response_serializer=jina__pb2.SnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers)
+ 'jina.JinaExecutorSnapshotProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorSnapshotProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def snapshot_status(request,
+ def snapshot_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorSnapshotProgress/snapshot_status',
+ '/jina.JinaExecutorSnapshotProgress/snapshot_status',
jina__pb2.SnapshotId.SerializeToString,
jina__pb2.SnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreStub(object):
@@ -631,10 +744,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore = channel.unary_unary(
- '/jina.JinaExecutorRestore/restore',
- request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestore/restore',
+ request_serializer=jina__pb2.RestoreSnapshotCommand.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreServicer(object):
@@ -651,39 +764,52 @@ def restore(self, request, context):
def add_JinaExecutorRestoreServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore': grpc.unary_unary_rpc_method_handler(
- servicer.restore,
- request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore': grpc.unary_unary_rpc_method_handler(
+ servicer.restore,
+ request_deserializer=jina__pb2.RestoreSnapshotCommand.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestore', rpc_method_handlers)
+ 'jina.JinaExecutorRestore', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestore(object):
"""*
jina gRPC service to trigger a restore at the Executor Runtime.
"""
@staticmethod
- def restore(request,
+ def restore(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestore/restore',
+ '/jina.JinaExecutorRestore/restore',
jina__pb2.RestoreSnapshotCommand.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
class JinaExecutorRestoreProgressStub(object):
@@ -698,10 +824,10 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.restore_status = channel.unary_unary(
- '/jina.JinaExecutorRestoreProgress/restore_status',
- request_serializer=jina__pb2.RestoreId.SerializeToString,
- response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
- )
+ '/jina.JinaExecutorRestoreProgress/restore_status',
+ request_serializer=jina__pb2.RestoreId.SerializeToString,
+ response_deserializer=jina__pb2.RestoreSnapshotStatusProto.FromString,
+ )
class JinaExecutorRestoreProgressServicer(object):
@@ -718,36 +844,49 @@ def restore_status(self, request, context):
def add_JinaExecutorRestoreProgressServicer_to_server(servicer, server):
rpc_method_handlers = {
- 'restore_status': grpc.unary_unary_rpc_method_handler(
- servicer.restore_status,
- request_deserializer=jina__pb2.RestoreId.FromString,
- response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
- ),
+ 'restore_status': grpc.unary_unary_rpc_method_handler(
+ servicer.restore_status,
+ request_deserializer=jina__pb2.RestoreId.FromString,
+ response_serializer=jina__pb2.RestoreSnapshotStatusProto.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
- 'jina.JinaExecutorRestoreProgress', rpc_method_handlers)
+ 'jina.JinaExecutorRestoreProgress', rpc_method_handlers
+ )
server.add_generic_rpc_handlers((generic_handler,))
- # This class is part of an EXPERIMENTAL API.
+# This class is part of an EXPERIMENTAL API.
class JinaExecutorRestoreProgress(object):
"""*
jina gRPC service to trigger a snapshot at the Executor Runtime.
"""
@staticmethod
- def restore_status(request,
+ def restore_status(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.unary_unary(request, target, '/jina.JinaExecutorRestoreProgress/restore_status',
+ '/jina.JinaExecutorRestoreProgress/restore_status',
jina__pb2.RestoreId.SerializeToString,
jina__pb2.RestoreSnapshotStatusProto.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
diff --git a/marie/resources/logging.default.yml b/marie/resources/logging.default.yml
index 807f2da9..4ad9c797 100644
--- a/marie/resources/logging.default.yml
+++ b/marie/resources/logging.default.yml
@@ -2,7 +2,7 @@ handlers: # enabled handlers, order does not matter
# - StreamHandler
# - FileHandler
- RichHandler
-level: INFO # set verbose level
+level: DEBUG # set verbose level
configs:
FileHandler:
format: '%(asctime)s:{name:>15}@%(process)2d[%(levelname).1s]:%(message)s'
diff --git a/marie/serve/consensus/go.mod b/marie/serve/consensus/go.mod
index fc854188..666b8d86 100644
--- a/marie/serve/consensus/go.mod
+++ b/marie/serve/consensus/go.mod
@@ -11,7 +11,7 @@ require (
github.com/hashicorp/raft v1.3.11
github.com/hashicorp/raft-boltdb v0.0.0-20220329195025-15018e9b97e0
google.golang.org/grpc v1.56.3
- google.golang.org/protobuf v1.30.0
+ google.golang.org/protobuf v1.33.0
)
require (
diff --git a/marie/serve/discovery/__init__.py b/marie/serve/discovery/__init__.py
index e670b020..ac26a531 100644
--- a/marie/serve/discovery/__init__.py
+++ b/marie/serve/discovery/__init__.py
@@ -1,76 +1,141 @@
import threading
-import time
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import requests
+from marie._annotations import deprecated
+from marie.enums import ProtocolType
from marie.helper import get_internal_ip
from marie.importer import ImportExtensions
+from marie.serve.discovery.address import JsonAddress
+from marie.serve.discovery.registry import EtcdServiceRegistry
+from marie.utils.timer import RepeatedTimer
if TYPE_CHECKING: # pragma: no cover
import consul
-class RepeatedTimer(object):
- def __init__(self, interval, function, *args, **kwargs):
- self._timer = None
- self.interval = interval
- self.function = function
- self.args = args
- self.kwargs = kwargs
- self.is_running = False
- self.next_call = time.time()
- self.start()
-
- def _run(self):
- self.is_running = False
- self.start()
- self.function(*self.args, **self.kwargs)
-
- def start(self):
- if not self.is_running:
- self.next_call += self.interval
- self._timer = threading.Timer(self.next_call - time.time(), self._run)
- self._timer.start()
- self.is_running = True
-
- def stop(self):
- self._timer.cancel()
- self.is_running = False
-
-
class DiscoveryServiceMixin:
"""Instrumentation mixin for Service Discovery handling"""
def _setup_service_discovery(
self,
+ protocol: ProtocolType,
name: str,
host: str,
port: int,
- scheme: Optional[str] = 'http',
+ scheme: Optional[str] = "http",
discovery: Optional[bool] = False,
- discovery_host: Optional[str] = '0.0.0.0',
+ discovery_host: Optional[str] = "0.0.0.0",
discovery_port: Optional[int] = 8500,
- discovery_scheme: Optional[str] = 'http',
+ discovery_scheme: Optional[str] = "http",
discovery_watchdog_interval: Optional[int] = 60,
+ runtime_args: Optional[Dict] = None,
) -> None:
if self.logger is None:
raise Exception("Expected logger to be configured")
- self.sd_state = 'started'
+ if protocol == ProtocolType.GRPC:
+ self._setup_service_discovery_etcd(
+ name=name,
+ host=host,
+ port=port,
+ scheme=scheme,
+ discovery=discovery,
+ discovery_host=discovery_host,
+ discovery_port=discovery_port,
+ discovery_scheme=discovery_scheme,
+ discovery_watchdog_interval=discovery_watchdog_interval,
+ runtime_args=runtime_args,
+ )
+        elif protocol == ProtocolType.HTTP:  # DEPRECATED: consul-based HTTP discovery
+ self._setup_service_discovery_consul(
+ name=name,
+ host=host,
+ port=port,
+ scheme=scheme,
+ discovery=discovery,
+ discovery_host=discovery_host,
+ discovery_port=discovery_port,
+ discovery_scheme=discovery_scheme,
+ discovery_watchdog_interval=discovery_watchdog_interval,
+ )
+ else:
+ raise NotImplementedError(f"Protocol {protocol} is not supported")
+
+ def _setup_service_discovery_etcd(
+ self,
+ name: str,
+ host: str,
+ port: int,
+ scheme: Optional[str] = "http",
+ discovery: Optional[bool] = False,
+ discovery_host: Optional[str] = "0.0.0.0",
+ discovery_port: Optional[int] = 8500,
+ discovery_scheme: Optional[str] = "http",
+ discovery_watchdog_interval: Optional[int] = 60,
+ runtime_args: Optional[Dict] = None,
+ ) -> None:
+ if self.logger is None:
+ raise Exception("Expected logger to be configured")
+ if runtime_args is None:
+ raise Exception("Expected runtime_args to be configured")
+
+ self.logger.info("Setting up service discovery ETCD ...")
+ self.sd_state = "started"
self.discovery_host = discovery_host
self.discovery_port = discovery_port
self.discovery_scheme = discovery_scheme
+ deployments_addresses = runtime_args.deployments_addresses
+ scheme = "grpc"
+ ctrl_address = f"{scheme}://{host}:{port}"
+ ctrl_address = f"{host}:{port}"
+ service_name = "gateway/service_test"
+
+ self.logger.info(f"Deployments addresses: {deployments_addresses}")
+
+        # NOTE: the etcd endpoint is hard-coded here; discovery_host/discovery_port are not yet wired in
+        etcd_registry = EtcdServiceRegistry(
+ "0.0.0.0",
+ 2379,
+ heartbeat_time=5,
+ )
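+        # With heartbeat_time=5 and the 6 second lease requested below, the registration
+        # is refreshed every few seconds and expires roughly one missed heartbeat after
+        # this gateway stops.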
+ lease = etcd_registry.register(
+ [service_name],
+ ctrl_address,
+ 6,
+ addr_cls=JsonAddress,
+ metadata=deployments_addresses,
+ )
+
+ self.logger.info(f"Lease ID: {lease.id}")
+
+ @deprecated
+ def _setup_service_discovery_consul(
+ self,
+ name: str,
+ host: str,
+ port: int,
+ scheme: Optional[str] = "http",
+ discovery: Optional[bool] = False,
+ discovery_host: Optional[str] = "0.0.0.0",
+ discovery_port: Optional[int] = 8500,
+ discovery_scheme: Optional[str] = "http",
+ discovery_watchdog_interval: Optional[int] = 60,
+ ) -> None:
+
+        # Consul-based discovery is deprecated; short-circuit here until the code below is removed.
+        if True:
+            return
if discovery:
with ImportExtensions(
required=True,
- help_text='You need to install the `python-consul` to use the service discovery functionality of marie',
+ help_text="You need to install the `python-consul` to use the service discovery functionality of marie",
):
import consul
# Ban advertising 0.0.0.0 or setting it as a service address #2961
- if host == '0.0.0.0':
+ if host == "0.0.0.0":
host = get_internal_ip()
def _watchdog_target():
@@ -88,7 +153,7 @@ def _watchdog_target():
t = threading.Thread(target=_watchdog_target, daemon=True)
t.start()
- def _is_discovery_online(self, client: Union['consul.Consul', None]) -> bool:
+ def _is_discovery_online(self, client: Union["consul.Consul", None]) -> bool:
"""Check if service discovery is online"""
if client is None:
return False
@@ -105,9 +170,9 @@ def _teardown_service_discovery(
self,
) -> None:
"""Teardown service discovery, by unregistering existing service from the catalog"""
- if self.sd_state != 'ready':
+ if self.sd_state != "ready":
return
- self.sd_state = 'stopping'
+ self.sd_state = "stopping"
try:
self.discovery_client.agent.service.deregister(self.service_id)
except Exception:
@@ -128,7 +193,7 @@ def _start_discovery_watchdog(
# Create new service id, otherwise we will re-register same id
self.service_id = f"{name}@{service_host}:{service_port}"
self.service_name = "traefik-system-ingress"
- self.sd_state = 'ready'
+ self.sd_state = "ready"
self.discovery_client, online = self._create_discovery_client(True)
def __register(_service_host, _service_port, _service_scheme):
@@ -163,7 +228,7 @@ def _verify_discovery_connection(
self,
discovery_host: str,
discovery_port: int = 8500,
- discovery_scheme: Optional[str] = 'http',
+ discovery_scheme: Optional[str] = "http",
) -> bool:
"""Verify consul connection
Exceptions throw such as ConnectionError will be captured
@@ -193,7 +258,7 @@ def _verify_discovery_connection(
def _create_discovery_client(
self,
verify: bool = True,
- ) -> Tuple[Union['consul.Consul', None], bool]:
+ ) -> Tuple[Union["consul.Consul", None], bool]:
"""Create new consul client"""
import consul
@@ -218,6 +283,8 @@ def _create_discovery_client(
def _get_service_node(self, service_name, service_id):
try:
index, nodes = self.discovery_client.catalog.service(service_name)
+ if nodes is None:
+ return None
for node in nodes:
if node["ServiceID"] == service_id:
return node
diff --git a/marie/serve/discovery/address.py b/marie/serve/discovery/address.py
new file mode 100644
index 00000000..100670d2
--- /dev/null
+++ b/marie/serve/discovery/address.py
@@ -0,0 +1,93 @@
+import abc
+import json
+
+__all__ = ["Address", "PlainAddress", "JsonAddress"]
+
+
+def b2str(i_b):
+ if isinstance(i_b, str):
+ return i_b
+ return i_b.decode()
+
+
+class Address(abc.ABC):
+ """gRPC service address."""
+
+ @abc.abstractmethod
+ def __init__(self, addr, metadata=None):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def add_value(self):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def delete_value(self):
+ raise NotImplementedError
+
+ @classmethod
+ @abc.abstractmethod
+ def from_value(cls, val, deserializer=None):
+ raise NotImplementedError
+
+
+class PlainAddress(Address):
+ """Plain text address."""
+
+ def __init__(self, addr, metadata=None):
+ self._addr = addr
+
+ def add_value(self):
+ return self._addr
+
+ def delete_value(self):
+ return self._addr
+
+ @classmethod
+ def from_value(cls, val, deserializer=None):
+ return b2str(val)
+
+
+class JsonAddress(Address):
+ """Json address."""
+
+ add_op = 0
+ delete_op = 1
+
+ def __init__(self, addr, metadata=None, serializer=json.dumps):
+ self._addr = addr
+ self._metadata = metadata or {}
+ self._serializer = serializer
+
+ def add_value(self):
+ return self._serializer(
+ {
+ "Op": self.add_op,
+ "Addr": self._addr,
+ "Metadata": self._serializer(self._metadata),
+ }
+ )
+
+ def delete_value(self):
+ return self._serializer(
+ {
+ "Op": self.delete_op,
+ "Addr": self._addr,
+ "Metadata": self._serializer(self._metadata),
+ }
+ )
+
+ @classmethod
+ def from_value(cls, val, deserializer=json.loads):
+ addr_val = deserializer(b2str(val))
+ addr_val["Metadata"] = deserializer(addr_val["Metadata"])
+ addr_op = addr_val["Op"]
+        # NOTE: the op-based (add/delete) return path below is intentionally disabled;
+        # callers always receive a reconstructed JsonAddress.
+        if False:
+            if addr_op == cls.add_op:
+                return True, addr_val["Addr"]
+            elif addr_op == cls.delete_op:
+                return False, addr_val["Addr"]
+
+            raise ValueError("invalid address value.")
+
+        return JsonAddress(addr_val["Addr"], addr_val["Metadata"])
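+
+
+# Example (illustrative): JsonAddress("10.0.0.1:51000", metadata={"dep": "gateway"}).add_value()
+# produces a JSON string such as
+#   {"Op": 0, "Addr": "10.0.0.1:51000", "Metadata": "{\"dep\": \"gateway\"}"}
+# i.e. Op marks add (0) / delete (1) and Metadata is itself JSON-encoded.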
diff --git a/marie/serve/discovery/etcd_client.py b/marie/serve/discovery/etcd_client.py
new file mode 100644
index 00000000..4cf6091a
--- /dev/null
+++ b/marie/serve/discovery/etcd_client.py
@@ -0,0 +1,436 @@
+import functools
+import logging
+import time
+from collections import namedtuple
+from typing import Callable, Dict, Mapping, Union
+from urllib.parse import quote as _quote
+from urllib.parse import unquote
+
+import etcd3
+import grpc
+from etcd3 import etcdrpc
+from etcd3.client import EtcdTokenCallCredentials
+from grpc._channel import _Rendezvous
+
+__all__ = ["EtcdClient", "Event"]
+
+from marie.excepts import RuntimeFailToStart
+
+Event = namedtuple("Event", "key event value")
+log = logging.getLogger(__name__)
+
+quote = functools.partial(_quote, safe="")
+
+
+def make_dict_from_pairs(key_prefix, pairs, path_sep="/"):
+ result = {}
+ len_prefix = len(key_prefix)
+ if isinstance(pairs, dict):
+ iterator = pairs.items()
+ else:
+ iterator = pairs
+ for k, v in iterator:
+ if not k.startswith(key_prefix):
+ continue
+ subkey = k[len_prefix:]
+ if subkey.startswith(path_sep):
+ subkey = subkey[1:]
+ path_components = subkey.split("/")
+ parent = result
+ for p in path_components[:-1]:
+ p = unquote(p)
+ if p not in parent:
+ parent[p] = {}
+ if p in parent and not isinstance(parent[p], dict):
+ root = parent[p]
+ parent[p] = {"": root}
+ parent = parent[p]
+ parent[unquote(path_components[-1])] = v
+ return result
+
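+# Example (illustrative): make_dict_from_pairs("cfg", [("cfg/db/host", "localhost"), ("cfg/db/port", "5432")])
+# yields {"db": {"host": "localhost", "port": "5432"}}.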
+
+def _slash(v: str):
+ return v.rstrip("/") + "/" if len(v) > 0 else ""
+
+
+def reauthenticate(etcd_sync, creds, executor):
+ # This code is taken from the constructor of etcd3.client.Etcd3Client class.
+ # Related issue: kragniz/python-etcd3#580
+ etcd_sync.auth_stub = etcdrpc.AuthStub(etcd_sync.channel)
+ auth_request = etcdrpc.AuthenticateRequest(
+ name=creds["user"],
+ password=creds["password"],
+ )
+ resp = etcd_sync.auth_stub.Authenticate(auth_request, etcd_sync.timeout)
+ etcd_sync.metadata = (("token", resp.token),)
+ etcd_sync.call_credentials = grpc.metadata_call_credentials(
+ EtcdTokenCallCredentials(resp.token)
+ )
+
+
+def reconn_reauth_adaptor(meth: Callable):
+ """
+ Retry connection and authentication for the given method.
+
+ :param meth: The method to be wrapped.
+ :return: The wrapped method.
+ """
+
+ @functools.wraps(meth)
+ def wrapped(self, *args, **kwargs):
+ num_reauth_tries = 0
+ num_reconn_tries = 0
+ while True:
+ try:
+ return meth(self, *args, **kwargs)
+ except etcd3.exceptions.ConnectionFailedError:
+ if num_reconn_tries >= 20:
+ log.warning(
+ "etcd3 connection failed more than %d times. retrying after 1 sec...",
+ num_reconn_tries,
+ )
+ else:
+ log.debug("etcd3 connection failed. retrying after 1 sec...")
+ time.sleep(1.0)
+ num_reconn_tries += 1
+ continue
+ except grpc.RpcError as e:
+ if (
+ e.code() == grpc.StatusCode.UNAUTHENTICATED
+ or (
+ e.code() == grpc.StatusCode.UNKNOWN
+ and "invalid auth token" in e.details()
+ )
+ ) and self._creds:
+ if num_reauth_tries > 0:
+ raise
+ reauthenticate(self.client, self._creds, None)
+ log.debug("etcd3 reauthenticated due to auth token expiration.")
+ num_reauth_tries += 1
+ continue
+ else:
+ raise
+
+ return wrapped
+
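+# reconn_reauth_adaptor is applied to every public EtcdClient operation below
+# (get/put/watch/...), so transient connection drops and expired auth tokens are
+# retried transparently.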
+
+# https://github.com/qqq-tech/backend.ai-common/blob/main/src/ai/backend/common/etcd.py
+class EtcdClient(object):
+ """A etcd client proxy."""
+
+ _suffer_status_code = (
+ grpc.StatusCode.UNAVAILABLE,
+ grpc.StatusCode.ABORTED,
+ grpc.StatusCode.RESOURCE_EXHAUSTED,
+ )
+
+ def __init__(
+ self,
+ etcd_host,
+ etcd_port,
+ namespace="marie",
+ credentials=None,
+ encoding="utf8",
+ retry_times=10,
+ ):
+ self.client = None # type: etcd3.client
+ self._host = etcd_host
+ self._port = etcd_port
+ self._client_idx = 0
+ self._cluster = None
+ self.encoding = encoding
+        self.retry_times = retry_times
+ self.ns = namespace
+ self._creds = credentials
+
+ self.connect()
+
+ def _mangle_key(self, key: str) -> bytes:
+ if key.startswith("/"):
+ key = key[1:]
+ return f"/{self.ns}/{key}".encode(self.encoding)
+
+ def _demangle_key(self, k: Union[bytes, str]) -> str:
+ if isinstance(k, bytes):
+ k = k.decode(self.encoding)
+ prefix = f"/{self.ns}/"
+ if k.startswith(prefix):
+ k = k[len(prefix) :]
+ return k
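+
+    # Example (illustrative): with the default namespace "marie",
+    # _mangle_key("jobs/x") -> b"/marie/jobs/x" and _demangle_key(b"/marie/jobs/x") -> "jobs/x".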
+
+ def call(self, method, *args, **kwargs):
+ """Etcd operation proxy method."""
+ if self._cluster is None:
+ raise RuntimeFailToStart("Etcd client not initialized.")
+
+ times = 0
+ while times < self.retry_times:
+ client = self._cluster[self._client_idx]
+ try:
+ ret = getattr(client, method)(*args, **kwargs)
+ return ret
+ except _Rendezvous as e:
+ if e.code() in self._suffer_status_code:
+ times += 1
+ self._client_idx = (self._client_idx + 1) % len(self._cluster)
+ log.info(f"Failed with exception {e}, retry after 1 second.")
+ time.sleep(1)
+                    continue
+                raise e  # raise exception if not in suffer status code
+ except Exception as e:
+ times += 1
+ log.info(f"Failed with exception {e}, retry after 1 second.")
+ time.sleep(1)
+
+ raise ValueError(f"Failed after {times} times.")
+
+ def _watch(
+ self, raw_key: bytes, event_callback: Callable, prefix: bool = False, **kwargs
+ ) -> int:
+ """Watch a key in etcd."""
+
+ print("Watching raw key:", raw_key)
+
+ def _watch_callback(response: etcd3.watch.WatchResponse):
+ if isinstance(response, grpc.RpcError):
+ if response.code() == grpc.StatusCode.UNAVAILABLE or (
+ response.code() == grpc.StatusCode.UNKNOWN
+ and "invalid auth token" not in response.details()
+ ):
+ # server restarting or terminated
+ return
+ else:
+ raise RuntimeError(f"Unexpected RPC Error: {response}")
+
+ for ev in response.events:
+ log.info(f"Received etcd event: {ev}")
+ if isinstance(ev, etcd3.events.PutEvent):
+ ev_type = "put"
+ elif isinstance(ev, etcd3.events.DeleteEvent):
+ ev_type = "delete"
+ else:
+ raise TypeError("Not recognized etcd event type.")
+ # etcd3 library uses a separate thread for its watchers.
+ event = Event(
+ self._demangle_key(ev.key),
+ ev_type,
+ ev.value.decode(self.encoding),
+ )
+ event_callback(self._demangle_key(ev.key), event)
+
+ try:
+ if prefix:
+ watch_id = self.client.add_watch_prefix_callback(
+ raw_key, _watch_callback, **kwargs
+ )
+ else:
+ watch_id = self.client.add_watch_callback(
+ raw_key, _watch_callback, **kwargs
+ )
+ return watch_id
+ except Exception as ex:
+ raise ex
+
+ @reconn_reauth_adaptor
+ def watch(self, key: str, callback, **kwargs):
+ scope_prefix = ""
+ mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key}")
+ return self._watch(mangled_key, callback, **kwargs)
+
+ @reconn_reauth_adaptor
+ def add_watch_prefix_callback(self, key_prefix: str, callback: Callable, **kwargs):
+ scope_prefix = ""
+ mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key_prefix}")
+ return self._watch(mangled_key, callback, prefix=True, **kwargs)
+
+ @reconn_reauth_adaptor
+    def get(self, key: str) -> Union[str, None]:
+ """
+ Get a single key from the etcd.
+ Returns ``None`` if the key does not exist.
+ The returned value may be an empty string if the value is a zero-length string.
+
+ :param key: The key. This must be quoted by the caller as needed.
+ :return:
+ """
+
+ mangled_key = self._mangle_key(key)
+ value, _ = self.client.get(mangled_key)
+ return value.decode(self.encoding) if value is not None else None
+
+ @reconn_reauth_adaptor
+ def get_all(self):
+ return self.client.get_all()
+
+ @reconn_reauth_adaptor
+ def put(self, key: str, val: str, lease=None) -> tuple:
+ """
+ Put a single key-value pair to the etcd.
+
+ :param key: The key. This must be quoted by the caller as needed.
+ :param val: The value.
+ :param lease: The lease ID.
+ :return: The key and value.
+ """
+
+ scope_prefix = ""
+ mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key}")
+ val = self.client.put(mangled_key, str(val).encode(self.encoding), lease=lease)
+ return mangled_key, val
+
+ @reconn_reauth_adaptor
+ def put_prefix(self, key: str, dict_obj: Mapping[str, str]):
+ """
+ Put a nested dict object under the given key prefix.
+ All keys in the dict object are automatically quoted to avoid conflicts with the path separator.
+
+ :param key: Prefix to put the given data. This must be quoted by the caller as needed.
+ :param dict_obj: Nested dictionary representing the data.
+ :return:
+ """
+ scope_prefix = ""
+ flattened_dict: Dict[str, str] = {}
+
+ def _flatten(prefix: str, inner_dict: Mapping[str, str]) -> None:
+ for k, v in inner_dict.items():
+ if k == "":
+ flattened_key = prefix
+ else:
+ flattened_key = prefix + "/" + quote(k)
+ if isinstance(v, dict):
+ _flatten(flattened_key, v)
+ else:
+ flattened_dict[flattened_key] = v
+
+ _flatten(key, dict_obj)
+
+ return self.client.transaction(
+ [],
+ [
+ self.client.transactions.put(
+ self._mangle_key(f"{_slash(scope_prefix)}{k}"),
+ str(v).encode(self.encoding),
+ )
+ for k, v in flattened_dict.items()
+ ],
+ [],
+ )
+
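+    # Example (illustrative): put_prefix("cfg", {"db": {"host": "localhost", "port": 5432}})
+    # writes /marie/cfg/db/host -> "localhost" and /marie/cfg/db/port -> "5432" in a single
+    # etcd transaction; get_prefix("cfg") below reconstructs the nested dict.
+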
+ @reconn_reauth_adaptor
+ def get_prefix(self, key_prefix: str, sort_order=None, sort_target="key") -> dict:
+ """
+ Retrieves all key-value pairs under the given key prefix as a nested dictionary.
+ All dictionary keys are automatically unquoted.
+ If a key has a value while it is also used as path prefix for other keys,
+ the value directly referenced by the key itself is included as a value in a dictionary
+ with the empty-string key.
+ :param key_prefix: Prefix to get the data. This must be quoted by the caller as needed.
+ :return: A dict object representing the data.
+ """
+ scope_prefix = ""
+ mangled_key_prefix = self._mangle_key(f"{_slash(scope_prefix)}{key_prefix}")
+ results = self.client.get_prefix(
+ mangled_key_prefix, sort_order=sort_order, sort_target=sort_target
+ )
+ pair_sets = {
+ self._demangle_key(k.key): v.decode(self.encoding) for v, k in results
+ }
+
+ return make_dict_from_pairs(
+ f"{_slash(scope_prefix)}{key_prefix}", pair_sets, "/"
+ )
+
+ @reconn_reauth_adaptor
+ def lease(self, ttl, lease_id=None):
+ """Create a new lease."""
+ return self.client.lease(ttl, lease_id=lease_id)
+
+ @reconn_reauth_adaptor
+ def delete(self, key: str):
+ scope_prefix = ""
+ mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key}")
+ return self.client.delete(mangled_key)
+
+ @reconn_reauth_adaptor
+ def delete_prefix(self, key_prefix: str):
+ scope_prefix = ""
+ mangled_key_prefix = self._mangle_key(f"{_slash(scope_prefix)}{key_prefix}")
+ return self.client.delete_prefix(mangled_key_prefix)
+
+ @reconn_reauth_adaptor
+ def replace(self, key: str, initial_val: str, new_val: str):
+ scope_prefix = ""
+ mangled_key = self._mangle_key(f"{_slash(scope_prefix)}{key}")
+ return self.client.replace(mangled_key, initial_val, new_val)
+
+ @reconn_reauth_adaptor
+ def cancel_watch(self, watch_id):
+ return self.client.cancel_watch(watch_id)
+
+ @reconn_reauth_adaptor
+ def reconnect(self) -> bool:
+ """
+ Reconnect to etcd. This method is used to recover from a connection failure.
+ :return: True if reconnected successfully. False otherwise.
+ """
+ log.warning("Reconnecting to etcd.")
+ try:
+ connected = self.connect()
+ except Exception as e:
+ log.error(f"Failed to reconnect to etcd. {e}")
+ connected = False
+ log.warning(f"Reconnected to etcd. {connected}")
+ return connected
+
+ def connect(self) -> bool:
+ addr = f"{self._host}:{self._port}"
+ times = 0
+ last_ex = None
+
+ while times < self.retry_times:
+ try:
+ self.client = etcd3.client(
+ host=self._host,
+ port=self._port,
+ user=self._creds.get("user") if self._creds else None,
+ password=self._creds.get("password") if self._creds else None,
+ )
+ self._cluster = [member._etcd_client for member in self.client.members]
+ break
+ except grpc.RpcError as e:
+ times += 1
+ last_ex = e
+ if e.code() in (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN):
+ log.error(
+ f"etcd3 connection failed. retrying after 1 sec, attempt # {times} of {self.retry_times}"
+ )
+ time.sleep(1)
+ continue
+ raise e
+ if times >= self.retry_times:
+ raise RuntimeFailToStart(
+ f"Initialize etcd client failed failed after {self.retry_times} times. Due to {last_ex}"
+ )
+ log.info(f'using etcd cluster from {addr} with namespace "{self.ns}"')
+ return True
+
+
+if __name__ == "__main__":
+ etcd_client = EtcdClient("localhost", 2379)
+ etcd_client.put("key", "Value XYZ")
+
+ kv = etcd_client.get("key")
+ print(etcd_client.get("key"))
+ # etcd_client.delete('key')
+ print(etcd_client.get("key"))
+
+ kv = {"key1": "Value 1", "key2": "Value 2", "key3": "Value 3"}
+
+ etcd_client.put_prefix("prefix", kv)
+
+ print(etcd_client.get_prefix("prefix"))
+
+ print("------ GET ALL ---------")
+ for kv in etcd_client.get_all():
+ v = kv[0].decode("utf8")
+ k = kv[1].key.decode("utf8")
+ print(k, v)
diff --git a/marie/serve/discovery/registry.py b/marie/serve/discovery/registry.py
new file mode 100644
index 00000000..4e9b218d
--- /dev/null
+++ b/marie/serve/discovery/registry.py
@@ -0,0 +1,311 @@
+import abc
+import asyncio
+import logging
+import threading
+import time
+from typing import Union
+
+import etcd3
+
+from marie.helper import get_or_reuse_loop
+from marie.serve.discovery.address import JsonAddress, PlainAddress
+from marie.serve.discovery.etcd_client import EtcdClient
+from marie.serve.discovery.util import form_service_key
+from marie.utils.timer import RepeatedTimer
+
+__all__ = ["EtcdServiceRegistry"]
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+log = logging.getLogger(__name__)
+
+
+class ServiceRegistry(abc.ABC):
+ """A service registry."""
+
+ @abc.abstractmethod
+ def register(self, service_names, service_addr, service_ttl):
+ """Register services with the same address."""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def heartbeat(self, service_addr=None):
+ """Service registry heartbeat."""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def unregister(self, service_names, service_addr):
+ """Unregister services with the same address."""
+ raise NotImplementedError
+
+
+class EtcdServiceRegistry(ServiceRegistry):
+ """service registry based on etcd."""
+
+ def __init__(
+ self,
+ etcd_host: str,
+ etcd_port: int,
+ etcd_client: EtcdClient = None,
+ heartbeat_time=-1,
+ ):
+ """Initialize etcd service registry.
+
+        :param etcd_host: (optional) etcd node host for :class:`client.EtcdClient`.
+        :param etcd_port: (optional) etcd node port for :class:`client.EtcdClient`.
+        :param etcd_client: (optional) A :class:`client.EtcdClient` object.
+        :param heartbeat_time: (optional) service registry heartbeat interval in seconds, default -1. If -1, no heartbeat is sent.
+
+ """
+ if etcd_host is None and etcd_client is None:
+ raise ValueError("etcd_host or etcd_client must be provided.")
+ self._client = etcd_client if etcd_client else EtcdClient(etcd_host, etcd_port)
+ self._leases = {}
+ self._services = {}
+ self._loop = get_or_reuse_loop()
+ self._heartbeat_time = heartbeat_time
+ self.setup_heartbeat_async()
+ # self.setup_heartbeat()
+
+ def get_lease(self, service_addr, service_ttl):
+ """Get a service lease from etcd.
+
+ :param service_addr: service address.
+ :param service_ttl: service lease ttl(seconds).
+        :rtype: etcd3.lease.Lease
+
+ """
+ lease = self._leases.get(service_addr)
+ if lease and lease.remaining_ttl > 0:
+ return lease
+
+ lease_id = hash(service_addr)
+ lease = self._client.lease(service_ttl, lease_id)
+ self._leases[service_addr] = lease
+ return lease
+
+ def register(
+ self,
+ service_names: Union[str, list[str]],
+ service_addr: str,
+ service_ttl: int,
+ addr_cls=None,
+ metadata=None,
+    ) -> "etcd3.lease.Lease":
+ """Register services with the same address.
+
+ :param service_names: A collection of service name.
+ :param service_addr: server address.
+ :param service_ttl: service ttl(seconds).
+ :param addr_cls: format class of service address.
+ :param metadata: extra meta data for JsonAddress.
+        :rtype: etcd3.lease.Lease
+ """
+ if isinstance(service_names, str):
+ service_names = [service_names]
+ lease = self.get_lease(service_addr, service_ttl)
+ addr_cls = addr_cls or PlainAddress
+
+ for service_name in service_names:
+ key = form_service_key(service_name, service_addr)
+ resolved = self._client.get(key)
+
+ if resolved:
+ log.warning(
+ f"Service already registered : {service_name}@{service_addr}"
+ )
+ continue
+
+ if addr_cls == JsonAddress:
+ addr_obj = addr_cls(service_addr, metadata=metadata)
+ else:
+ addr_obj = addr_cls(service_addr)
+
+ addr_val = addr_obj.add_value()
+ put_key, _ = self._client.put(key, addr_val, lease=lease)
+ log.warning(
+ f"Registering service : {service_name}@{service_addr} : {put_key}"
+ )
+ try:
+ self._services[service_addr].add(service_name)
+ except KeyError:
+ self._services[service_addr] = {service_name}
+ return lease
+
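+    # Example (illustrative): register(["gateway/marie"], "10.0.0.1:51000", 6, addr_cls=JsonAddress)
+    # writes one key per service name under a shared 6 second lease and returns that lease;
+    # re-registering an existing name/address pair is skipped with a warning.
+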
+ def heartbeat(self, service_addr=None, service_ttl=5):
+ """service heartbeat."""
+ log.info(f"Heartbeat service_addr : {service_addr}")
+ if service_addr:
+ lease = self.get_lease(service_addr, service_ttl)
+ leases = ((service_addr, lease),)
+ else:
+ leases = tuple(self._leases.items())
+
+ for service_addr, lease in leases:
+ registered = self._services.get(service_addr, None)
+ if not registered:
+ continue
+ try:
+ log.debug(
+ f"Refreshing lease for: {service_addr}, {lease.remaining_ttl}"
+ )
+ ret = lease.refresh()[0]
+ if ret.TTL == 0:
+ self.register(
+ self._services.get(service_addr, []),
+ service_addr,
+ lease.ttl,
+ )
+ except (ValueError, etcd3.exceptions.ConnectionFailedError) as e:
+ if (
+ isinstance(e, etcd3.exceptions.ConnectionFailedError)
+ or str(e) == "Trying to use a failed node"
+ ):
+ log.warning(
+ f"Trying to use a failed node, attempting to reconnect."
+ )
+ if self._client.reconnect():
+ log.info("Reconnected to etcd")
+ lease.etcd_client = self._client.client
+ except Exception as e:
+ raise e
+
+ def unregister(self, service_names, service_addr, addr_cls=None):
+ """Unregister services with the same address.
+
+ :param service_names: A collection of service name.
+ :param service_addr: server address.
+ :param addr_cls: format class of service address.
+ """
+
+ addr_cls = addr_cls or PlainAddress
+ etcd_delete = True
+ if addr_cls != PlainAddress:
+ etcd_delete = False
+
+ registered_services = self._services.get(service_addr, set())
+ for service_name in service_names:
+ log.info(f"Unregistering service : {service_name}@{service_addr}")
+ key = form_service_key(service_name, service_addr)
+ if etcd_delete:
+ self._client.delete(key)
+ else:
+ self._client.put(key, addr_cls(service_addr).delete_value())
+ registered_services.discard(service_name)
+
+ def setup_heartbeat_async(self):
+ """
+ Set up an asynchronous heartbeat process.
+
+ :return: None
+ """
+
+ async def _heartbeat_setup():
+ initial_delay = self._heartbeat_time
+ await asyncio.sleep(initial_delay)
+ while True:
+ try:
+ self.heartbeat()
+ await asyncio.sleep(self._heartbeat_time)
+ except Exception as e:
+ log.error(f"Error in heartbeat : {str(e)}")
+
+ def _polling_status():
+ task = self._loop.create_task(_heartbeat_setup())
+ self._loop.run_until_complete(task)
+
+ polling_status_thread = threading.Thread(
+ target=_polling_status,
+ daemon=True,
+ )
+ polling_status_thread.start()
+
+ def setup_heartbeat(self):
+ """
+ This method is used to set up a heartbeat for the etcd service registry.
+
+ :return: None
+ """
+ log.info(
+ f"Setting up heartbeat for etcd service registry : {self._heartbeat_time}",
+ )
+ if self._heartbeat_time > 0:
+ rt = RepeatedTimer(
+ self._heartbeat_time,
+ self.heartbeat,
+ )
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser(description="service discovery etcd cluster")
+ parser.add_argument(
+ "--host",
+ help="the etcd host, default = 127.0.0.1",
+ required=False,
+ default="127.0.0.1",
+ )
+ parser.add_argument(
+ "--port",
+ help="the etcd port, default = 2379",
+ required=False,
+ default=2379,
+ type=int,
+ )
+ parser.add_argument("--ca-cert", help="the etcd ca-cert", required=False)
+ parser.add_argument("--cert-key", help="the etcd cert key", required=False)
+ parser.add_argument("--cert-cert", help="the etcd cert", required=False)
+ parser.add_argument("--service-key", help="the service key", required=True)
+ parser.add_argument(
+ "--service-addr", help="the service address host:port ", required=True
+ )
+ parser.add_argument(
+ "--lease-ttl",
+ help="the lease ttl in seconds, default is 10",
+ required=False,
+ default=10,
+ type=int,
+ )
+ parser.add_argument("--my-id", help="my identifier", required=True)
+ parser.add_argument(
+ "--timeout",
+ help="the etcd operation timeout in seconds, default is 2",
+ required=False,
+ type=int,
+ default=2,
+ )
+ args = parser.parse_args()
+
+ params = {"host": args.host, "port": args.port, "timeout": args.timeout}
+ if args.ca_cert:
+ params["ca_cert"] = args.ca_cert
+ if args.cert_key:
+ params["cert_key"] = args.cert_key
+ if args.cert_cert:
+ params["cert_cert"] = args.cert_cert
+
+ log.info(f"args : {args}")
+
+ etcd_registry = EtcdServiceRegistry(args.host, args.port, heartbeat_time=5)
+ etcd_registry.register([args.service_key], args.service_addr, args.lease_ttl)
+
+ try:
+ while True:
+ time.sleep(2) # Keep the program running.
+ except KeyboardInterrupt:
+ etcd_registry.unregister([args.service_key], args.service_addr)
+ print("Service unregistered.")
+
+
+if __name__ == "__main__":
+ main()
+
+if __name__ == "__main__XXXX":
+ etcd_registry = EtcdServiceRegistry("127.0.0.1", 2379, heartbeat_time=5)
+ etcd_registry.register(["gateway/service_test"], "127.0.0.1:50011", 6)
+
+ print(etcd_registry._services)
+ print(etcd_registry._leases)
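
A minimal usage sketch of the registry added above, mirroring the example block at the end of the file (it assumes a local etcd on `127.0.0.1:2379`; the exact module path is an assumption based on this patch's file layout):

```python
# Sketch only: the module path marie.serve.discovery.registry is assumed.
from marie.serve.discovery.registry import EtcdServiceRegistry

# heartbeat_time=5 keeps the etcd lease refreshed every 5 seconds in a background thread.
registry = EtcdServiceRegistry("127.0.0.1", 2379, heartbeat_time=5)

# Register one address under a service key with a 6-second lease TTL,
# matching the values used in the example block above.
registry.register(["gateway/service_test"], "127.0.0.1:50011", 6)

# On shutdown, remove the etcd keys so resolvers stop returning this address.
registry.unregister(["gateway/service_test"], "127.0.0.1:50011")
```
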
diff --git a/marie/serve/discovery/resolver.py b/marie/serve/discovery/resolver.py
new file mode 100644
index 00000000..2d7789cd
--- /dev/null
+++ b/marie/serve/discovery/resolver.py
@@ -0,0 +1,290 @@
+import abc
+import logging
+import threading
+import time
+
+from marie.serve.discovery.address import PlainAddress
+from marie.serve.discovery.etcd_client import EtcdClient, Event
+
+__all__ = ["EtcdServiceResolver"]
+
+from marie.serve.discovery.util import form_service_key
+
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+log = logging.getLogger(__name__)
+
+
+class ServiceResolver(abc.ABC):
+ """gRPC service Resolver class."""
+
+ @abc.abstractmethod
+ def resolve(self, name):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def update(self, **kwargs):
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ def listen(self):
+ raise NotImplementedError
+
+
+class EtcdServiceResolver(ServiceResolver):
+ """gRPC service resolver based on Etcd."""
+
+ def __init__(
+ self,
+ etcd_host=None,
+ etcd_port=None,
+ etcd_client=None,
+ start_listener=True,
+ listen_timeout=5,
+ addr_cls=None,
+ namespace="marie",
+ ):
+ """Initialize etcd service resolver.
+
+ :param etcd_host: (optional) etcd node host for :class:`client.EtcdClient`.
+ :param etcd_port: (optional) etcd node port for :class:`client.EtcdClient`.
+ :param etcd_client: (optional) A :class:`client.EtcdClient` object.
+ :param start_listener: (optional) Whether to start the resolver listener thread.
+ :param listen_timeout: (optional) Resolver thread listen timeout.
+ :param addr_cls: (optional) address format class.
+ :param namespace: (optional) Etcd namespace.
+ """
+ super().__init__()
+
+ if etcd_host is None and etcd_client is None:
+ raise ValueError("etcd_host or etcd_client must be provided.")
+ self._listening = False
+ self._stopped = False
+ self._listen_thread = None
+ self._listen_timeout = listen_timeout
+ self._lock = threading.Lock()
+ self._client = (
+ etcd_client
+ if etcd_client
+ else EtcdClient(etcd_host, etcd_port, namespace=namespace)
+ )
+ self.watched_services = {}
+ self._names = {}
+ self._addr_cls = addr_cls or PlainAddress
+
+ if start_listener:
+ self.start_listener()
+
+ def resolve(self, name: str) -> list:
+ """Resolve gRPC service name.
+
+ :param name: gRPC service name.
+ :rtype list: A collection of gRPC server addresses.
+
+ """
+ with self._lock:
+ try:
+ return self._names[name]
+ except KeyError:
+ addrs = self.get(name)
+ self._names[name] = addrs
+ return addrs
+
+ def get(self, name: str):
+ """Get values from Etcd.
+
+ :param name: Etcd key prefix name.
+ :rtype list: A collection of Etcd values.
+
+ """
+ keys = self._client.get_prefix(name)
+ vals = []
+ plain = True
+ if self._addr_cls != PlainAddress:
+ plain = False
+
+ for val, metadata in keys.items():
+ if plain:
+ vals.append(self._addr_cls.from_value(val))
+ else:
+ add, addr = self._addr_cls.from_value(val)
+ if add:
+ vals.append(addr)
+
+ return vals
+
+ def update(self, **kwargs):
+ """Add or delete service address.
+
+ :param kwargs: Dictionary of ``'service_name': (add-addresses, delete-addresses)``.
+
+ """
+ with self._lock:
+ for name, (add, delete) in kwargs.items():
+ try:
+ self._names[name].extend(add)
+ except KeyError:
+ self._names[name] = add
+
+ for del_item in delete:
+ try:
+ self._names[name].remove(del_item)
+ except ValueError:
+ continue
+
+ def listen(self):
+ """Listen for change about service address."""
+ while not self._stopped:
+ for name in self._names:
+ try:
+ vals = self.get(name)
+ except Exception:
+ continue
+ else:
+ with self._lock:
+ self._names[name] = vals
+
+ time.sleep(self._listen_timeout)
+
+ def watch_service(
+ self, service_name: str, event_callback: callable, notify_on_start=True
+ ):
+ """Watch service event."""
+ log.info(f"Watching service : {service_name} for changes.")
+ log.info(f"Notify on start : {notify_on_start}")
+ watch_id = self._client.add_watch_prefix_callback(service_name, event_callback)
+ self.watched_services[service_name] = watch_id
+ if notify_on_start:
+ resolved = self._client.get_prefix(service_name)
+ for val, metadata in resolved.items():
+ log.info(
+ f"Resolved service: {service_name}, {val}, {metadata}",
+ )
+ key = form_service_key(service_name, val)
+ event = Event(
+ service_name,
+ "put",
+ metadata,
+ )
+ event_callback(service_name, event)
+
+ def stop_watch_service(self, service_name: str = None) -> None:
+ """Stop watching services."""
+
+ if service_name:
+ watch_id = self.watched_services.pop(service_name, None)
+ if watch_id:
+ self._client.cancel_watch(watch_id)
+ log.info(
+ f"Stop watching service: {service_name}, {watch_id}",
+ )
+ else:
+ for service_name, watch_id in self.watched_services.items():
+ self._client.cancel_watch(watch_id)
+ log.info(f"Stop watching service: {service_name}, {watch_id}")
+ self.watched_services.clear()
+
+ def start_listener(self, daemon=True):
+ """Start listen thread.
+
+ :param daemon: Whether to start the thread as a daemon.
+
+ """
+ if self._listening:
+ return
+
+ thread_name = "Thread-resolver-listener"
+ self._listen_thread = threading.Thread(target=self.listen, name=thread_name)
+ self._listen_thread.daemon = daemon
+ self._listen_thread.start()
+ self._listening = True
+
+ def stop(self):
+ """Stop service resolver."""
+ if self._stopped:
+ return
+
+ self._stopped = True
+
+ def __del__(self):
+ self.stop()
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser(description="service discovery etcd cluster")
+ parser.add_argument(
+ "--host",
+ help="the etcd host, default = 127.0.0.1",
+ required=False,
+ default="127.0.0.1",
+ )
+ parser.add_argument(
+ "--port",
+ help="the etcd port, default = 2379",
+ required=False,
+ default=2379,
+ type=int,
+ )
+ parser.add_argument("--ca-cert", help="the etcd ca-cert", required=False)
+ parser.add_argument("--cert-key", help="the etcd cert key", required=False)
+ parser.add_argument("--cert-cert", help="the etcd cert", required=False)
+ parser.add_argument("--service-key", help="the service key", required=True)
+ parser.add_argument(
+ "--timeout",
+ help="the etcd operation timeout in seconds, default is 2",
+ required=False,
+ type=int,
+ default=2,
+ )
+ args = parser.parse_args()
+
+ params = {"host": args.host, "port": args.port, "timeout": args.timeout}
+ if args.ca_cert:
+ params["ca_cert"] = args.ca_cert
+ if args.cert_key:
+ params["cert_key"] = args.cert_key
+ if args.cert_cert:
+ params["cert_cert"] = args.cert_cert
+
+ log.info(f"args : {args}")
+
+ resolver = EtcdServiceResolver(
+ args.host, args.port, namespace="marie", start_listener=False, listen_timeout=5
+ )
+
+ log.info(f"Resolved : {resolver.resolve(args.service_key)}")
+
+ resolver.watch_service(
+ args.service_key,
+ lambda service, event: log.info(f"Event from service : {service}, {event}"),
+ )
+
+ try:
+ while True:
+ log.info(f"Checking service address...")
+ time.sleep(2)
+ except KeyboardInterrupt:
+ print("Service stopped.")
+
+
+if __name__ == "__main__":
+ main()
+
+if __name__ == "__main__XXX":
+ resolver = EtcdServiceResolver(
+ "127.0.0.1", 2379, namespace="marie", start_listener=False, listen_timeout=5
+ )
+ print(resolver.resolve("gateway/service_test"))
+
+ resolver.watch_service(
+ "gateway/service_test",
+ lambda service, event: print("Event from service : ", service, event),
+ )
+
+ while True:
+ print("Checking service address...")
+ # print(resolver.resolve('service_test'))
+ time.sleep(2)
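
The resolver above also supports push-style updates through `watch_service`. A small sketch of wiring a callback that refreshes a local address list whenever etcd reports a change (the callback re-queries the prefix via `get()` so it does not depend on the exact shape of `Event`):

```python
from marie.serve.discovery.resolver import EtcdServiceResolver

resolver = EtcdServiceResolver(
    "127.0.0.1", 2379, namespace="marie", start_listener=False
)

known_addresses: list = []


def on_service_event(service_name, event):
    # Re-query etcd for the prefix instead of parsing the event payload;
    # resolve() caches results, so get() is used here to always hit etcd.
    known_addresses[:] = resolver.get(service_name)
    print(f"{service_name} now resolves to {known_addresses}")


# notify_on_start=True replays the currently registered addresses as synthetic "put" events.
resolver.watch_service("gateway/service_test", on_service_event, notify_on_start=True)
```
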
diff --git a/marie/serve/discovery/util.py b/marie/serve/discovery/util.py
new file mode 100644
index 00000000..2e26cff9
--- /dev/null
+++ b/marie/serve/discovery/util.py
@@ -0,0 +1,10 @@
+import re
+
+
+def form_service_key(service_name: str, service_addr: str):
+ """Return service's key in etcd."""
+ # validate service_addr format meets the requirement of host:port or ip:port or scheme://host:port
+ # if not re.match(r'^[a-zA-Z]+://[a-zA-Z0-9.]+:\d+$', service_addr):
+ # raise ValueError(f"Invalid service address: {service_addr}")
+
+ return '/'.join((service_name, service_addr))
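
For reference, the key layout this helper produces (a sketch; any namespace prefixing is left to `EtcdClient`, not done here):

```python
from marie.serve.discovery.util import form_service_key

# "<service_name>/<service_addr>" is the raw key stored (and prefix-scanned) in etcd.
form_service_key("gateway/service_test", "127.0.0.1:50011")
# -> 'gateway/service_test/127.0.0.1:50011'
```
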
diff --git a/marie/serve/executors/__init__.py b/marie/serve/executors/__init__.py
index 60fbaee0..1690f73e 100644
--- a/marie/serve/executors/__init__.py
+++ b/marie/serve/executors/__init__.py
@@ -402,10 +402,10 @@ def __init__(
self._init_workspace = workspace
if __dry_run_endpoint__ not in self.requests:
- self.requests[
- __dry_run_endpoint__
- ] = _FunctionWithSchema.get_function_with_schema(
- self.__class__._dry_run_func
+ self.requests[__dry_run_endpoint__] = (
+ _FunctionWithSchema.get_function_with_schema(
+ self.__class__._dry_run_func
+ )
)
else:
self.logger.warning(
@@ -413,10 +413,10 @@ def __init__(
f' So it is recommended not to expose this endpoint. '
)
if type(self) == BaseExecutor:
- self.requests[
- __default_endpoint__
- ] = _FunctionWithSchema.get_function_with_schema(
- self.__class__._dry_run_func
+ self.requests[__default_endpoint__] = (
+ _FunctionWithSchema.get_function_with_schema(
+ self.__class__._dry_run_func
+ )
)
self._lock = contextlib.AsyncExitStack()
@@ -596,14 +596,14 @@ def _add_requests(self, _requests: Optional[Dict]):
_func = getattr(self.__class__, func)
if callable(_func):
# the target function is not decorated with `@requests` yet
- self.requests[
- endpoint
- ] = _FunctionWithSchema.get_function_with_schema(_func)
+ self.requests[endpoint] = (
+ _FunctionWithSchema.get_function_with_schema(_func)
+ )
elif typename(_func) == 'marie.executors.decorators.FunctionMapper':
# the target function is already decorated with `@requests`, need unwrap with `.fn`
- self.requests[
- endpoint
- ] = _FunctionWithSchema.get_function_with_schema(_func.fn)
+ self.requests[endpoint] = (
+ _FunctionWithSchema.get_function_with_schema(_func.fn)
+ )
else:
raise TypeError(
f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}'
@@ -620,7 +620,14 @@ def _validate_sagemaker(self):
):
return
+ remove_keys = set()
+ for k in self.requests.keys():
+ if k != '/invocations':
+ remove_keys.add(k)
+
if '/invocations' in self.requests:
+ for k in remove_keys:
+ self.requests.pop(k)
return
if (
@@ -633,12 +640,16 @@ def _validate_sagemaker(self):
f'Using "{endpoint_to_use}" as "/invocations" route'
)
self.requests['/invocations'] = self.requests[endpoint_to_use]
+ for k in remove_keys:
+ self.requests.pop(k)
return
if len(self.requests) == 1:
route = list(self.requests.keys())[0]
self.logger.warning(f'Using "{route}" as "/invocations" route')
self.requests['/invocations'] = self.requests[route]
+ for k in remove_keys:
+ self.requests.pop(k)
return
raise ValueError('Cannot identify the endpoint to use for "/invocations"')
@@ -1103,7 +1114,7 @@ def serve(
:param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535]
:param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64")
:param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET'].
- :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER'].
+ :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER', 'AZURE'].
:param provider_endpoint: If set, Executor endpoint will be explicitly chosen and used in the custom container operated by the provider.
:param py_modules: The customized python modules need to be imported before loading the executor
diff --git a/marie/serve/executors/run.py b/marie/serve/executors/run.py
index 9c4f65d3..2a0c9927 100644
--- a/marie/serve/executors/run.py
+++ b/marie/serve/executors/run.py
@@ -66,7 +66,7 @@ def pascal_case_dict(d):
raft_id,
raft_dir,
args.name,
- score_threshold=executor_target,
+ executor_target,
**raft_configuration,
)
@@ -149,10 +149,12 @@ def _set_envs():
)
except Exception as ex:
logger.error(
- f'{ex!r} during {runtime_cls!r} initialization'
- + f'\n add "--quiet-error" to suppress the exception details'
- if not args.quiet_error
- else '',
+ (
+ f'{ex!r} during {runtime_cls!r} initialization'
+ + f'\n add "--quiet-error" to suppress the exception details'
+ if not args.quiet_error
+ else ''
+ ),
exc_info=not args.quiet_error,
)
else:
diff --git a/marie/serve/networking/__init__.py b/marie/serve/networking/__init__.py
index 2ae434ec..82f0f5d7 100644
--- a/marie/serve/networking/__init__.py
+++ b/marie/serve/networking/__init__.py
@@ -22,6 +22,7 @@
from marie.logging.logger import MarieLogger
from marie.proto import jina_pb2
from marie.serve.helper import format_grpc_error
+from marie.serve.networking.balancer.load_balancer import LoadBalancer
from marie.serve.networking.connection_pool_map import _ConnectionPoolMap
from marie.serve.networking.connection_stub import create_async_channel_stub
from marie.serve.networking.instrumentation import (
@@ -63,12 +64,13 @@ def __init__(
runtime_name,
logger: Optional[MarieLogger] = None,
compression: Optional[str] = None,
- metrics_registry: Optional['CollectorRegistry'] = None,
- meter: Optional['Meter'] = None,
- aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
- tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
+ metrics_registry: Optional["CollectorRegistry"] = None,
+ meter: Optional["Meter"] = None,
+ aio_tracing_client_interceptors: Optional[Sequence["ClientInterceptor"]] = None,
+ tracing_client_interceptor: Optional["OpenTelemetryClientInterceptor"] = None,
channel_options: Optional[list] = None,
- load_balancer_type: Optional[str] = 'round_robin',
+ load_balancer_type: Optional[str] = "round_robin",
+ load_balancer: Optional[LoadBalancer] = None,
):
self._logger = logger or MarieLogger(self.__class__.__name__)
self.channel_options = channel_options
@@ -82,32 +84,32 @@ def __init__(
if metrics_registry:
with ImportExtensions(
required=True,
- help_text='You need to install the `prometheus_client` to use the monitoring functionality of marie',
+ help_text="You need to install the `prometheus_client` to use the monitoring functionality of marie",
):
from prometheus_client import Summary
sending_requests_time_metrics = Summary(
- 'sending_request_seconds',
- 'Time spent between sending a request to the Executor/Head and receiving the response',
+ "sending_request_seconds",
+ "Time spent between sending a request to the Executor/Head and receiving the response",
registry=metrics_registry,
- namespace='marie',
- labelnames=('runtime_name',),
+ namespace="marie",
+ labelnames=("runtime_name",),
).labels(runtime_name)
received_response_bytes = Summary(
- 'received_response_bytes',
- 'Size in bytes of the response returned from the Head/Executor',
+ "received_response_bytes",
+ "Size in bytes of the response returned from the Head/Executor",
registry=metrics_registry,
- namespace='marie',
- labelnames=('runtime_name',),
+ namespace="marie",
+ labelnames=("runtime_name",),
).labels(runtime_name)
send_requests_bytes_metrics = Summary(
- 'sent_request_bytes',
- 'Size in bytes of the request sent to the Head/Executor',
+ "sent_request_bytes",
+ "Size in bytes of the request sent to the Head/Executor",
registry=metrics_registry,
- namespace='marie',
- labelnames=('runtime_name',),
+ namespace="marie",
+ labelnames=("runtime_name",),
).labels(runtime_name)
else:
sending_requests_time_metrics = None
@@ -123,21 +125,21 @@ def __init__(
if meter:
self._histograms = _NetworkingHistograms(
sending_requests_time_metrics=meter.create_histogram(
- name='marie_sending_request_seconds',
- unit='s',
- description='Time spent between sending a request to the Executor/Head and receiving the response',
+ name="marie_sending_request_seconds",
+ unit="s",
+ description="Time spent between sending a request to the Executor/Head and receiving the response",
),
received_response_bytes=meter.create_histogram(
- name='marie_received_response_bytes',
- unit='By',
- description='Size in bytes of the response returned from the Head/Executor',
+ name="marie_received_response_bytes",
+ unit="By",
+ description="Size in bytes of the response returned from the Head/Executor",
),
send_requests_bytes_metrics=meter.create_histogram(
- name='marie_sent_request_bytes',
- unit='By',
- description='Size in bytes of the request sent to the Head/Executor',
+ name="marie_sent_request_bytes",
+ unit="By",
+ description="Size in bytes of the request sent to the Head/Executor",
),
- histogram_metric_labels={'runtime_name': runtime_name},
+ histogram_metric_labels={"runtime_name": runtime_name},
)
else:
self._histograms = _NetworkingHistograms()
@@ -153,6 +155,7 @@ def __init__(
tracing_client_interceptor=self.tracing_client_interceptor,
channel_options=self.channel_options,
load_balancer_type=load_balancer_type,
+ load_balancer=load_balancer,
)
self._deployment_address_map = {}
@@ -192,7 +195,7 @@ def send_requests(
for replica_list in shard_replica_lists:
connections.append(replica_list)
else:
- raise ValueError(f'Unsupported polling type {polling_type}')
+ raise ValueError(f"Unsupported polling type {polling_type}")
for replica_list in connections:
task = self._send_requests(
@@ -233,7 +236,7 @@ def send_discover_endpoint(
)
else:
self._logger.debug(
- f'no available connections for deployment {deployment} and shard {shard_id}'
+ f"no available connections for deployment {deployment} and shard {shard_id}"
)
return None
@@ -273,7 +276,7 @@ def send_requests_once(
return result
else:
self._logger.debug(
- f'no available connections for deployment {deployment} and shard {shard_id}'
+ f"no available connections for deployment {deployment} and shard {shard_id}"
)
return None
@@ -310,7 +313,7 @@ def send_single_document_request(
)
return result_async_generator
else:
- self._logger.debug(f'no available connections for deployment {deployment}')
+ self._logger.debug(f"no available connections for deployment {deployment}")
return None
def add_connection(
@@ -369,16 +372,16 @@ async def _handle_aiorpcerror(
self,
error: AioRpcError,
retry_i: int = 0,
- request_id: str = '',
+ request_id: str = "",
tried_addresses: Set[str] = {
- ''
+ ""
}, # same deployment can have multiple addresses (replicas)
total_num_tries: int = 1, # number of retries + 1
- current_address: str = '', # the specific address that was contacted during this attempt
- current_deployment: str = '', # the specific deployment that was contacted during this attempt
+ current_address: str = "", # the specific address that was contacted during this attempt
+ current_deployment: str = "", # the specific deployment that was contacted during this attempt
connection_list: Optional[_ReplicaList] = None,
- task_type: str = 'DataRequest',
- ) -> 'Optional[Union[AioRpcError, InternalNetworkError]]':
+ task_type: str = "DataRequest",
+ ) -> "Optional[Union[AioRpcError, InternalNetworkError]]":
# connection failures, cancelled requests, and timed out requests should be retried
# all other cases should not be retried and will be raised immediately
# connection failures have the code grpc.StatusCode.UNAVAILABLE
@@ -390,15 +393,15 @@ async def _handle_aiorpcerror(
skip_resetting = False
if (
error.code() == grpc.StatusCode.UNAVAILABLE
- and 'not the leader' in error.details()
+ and "not the leader" in error.details()
):
self._logger.debug(
- f'RAFT node of {current_deployment} is not the leader. Trying next replica, if available.'
+ f"RAFT node of {current_deployment} is not the leader. Trying next replica, if available."
)
skip_resetting = True # no need to reset, no problem with channel
else:
self._logger.debug(
- f'gRPC call to {current_deployment} for {task_type} errored, with error {format_grpc_error(error)} and for the {retry_i + 1}th time.'
+ f"gRPC call to {current_deployment} for {task_type} errored, with error {format_grpc_error(error)} and for the {retry_i + 1}th time."
)
errors_to_retry = [
grpc.StatusCode.UNAVAILABLE,
@@ -415,7 +418,7 @@ async def _handle_aiorpcerror(
return error
elif error.code() in errors_to_retry and retry_i >= total_num_tries - 1:
self._logger.debug(
- f'gRPC call for {current_deployment} failed, retries exhausted'
+ f"gRPC call for {current_deployment} failed, retries exhausted"
)
from marie.excepts import InternalNetworkError
@@ -447,13 +450,13 @@ def _send_single_doc_request(
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
- ) -> 'asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]':
+ ) -> "asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]":
# this wraps the awaitable object from grpc as a coroutine so it can be used as a task
# the grpc call function is not a coroutine but some _AioCall
if endpoint:
metadata = metadata or {}
- metadata['endpoint'] = endpoint
+ metadata["endpoint"] = endpoint
if metadata:
metadata = tuple(metadata.items())
@@ -483,7 +486,10 @@ async def async_generator_wrapper():
break
tried_addresses.add(current_connection.address)
try:
- async for resp, metadata_resp in current_connection.send_single_doc_request(
+ async for (
+ resp,
+ metadata_resp,
+ ) in current_connection.send_single_doc_request(
request=request,
metadata=metadata,
compression=self.compression,
@@ -501,7 +507,7 @@ async def async_generator_wrapper():
current_address=current_connection.address,
current_deployment=current_connection.deployment_name,
connection_list=connections,
- task_type='SingleDocumentRequest',
+ task_type="SingleDocumentRequest",
)
if error:
yield error, None
@@ -520,13 +526,13 @@ def _send_requests(
metadata: Optional[Dict[str, str]] = None,
timeout: Optional[float] = None,
retries: Optional[int] = -1,
- ) -> 'asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]':
+ ) -> "asyncio.Task[Union[Tuple, AioRpcError, InternalNetworkError]]":
# this wraps the awaitable object from grpc as a coroutine so it can be used as a task
# the grpc call function is not a coroutine but some _AioCall
if endpoint:
metadata = metadata or {}
- metadata['endpoint'] = endpoint
+ metadata["endpoint"] = endpoint
if metadata:
metadata = tuple(metadata.items())
@@ -573,7 +579,7 @@ async def task_wrapper():
current_address=current_connection.address,
current_deployment=current_connection.deployment_name,
connection_list=connections,
- task_type='DataRequest',
+ task_type="DataRequest",
)
if error:
return error
@@ -623,7 +629,7 @@ async def task_coroutine():
current_deployment=connection.deployment_name,
connection_list=connection_list,
total_num_tries=total_num_tries,
- task_type='EndpointDiscovery',
+ task_type="EndpointDiscovery",
)
if error:
raise error
diff --git a/marie/serve/networking/balancer/interceptor.py b/marie/serve/networking/balancer/interceptor.py
new file mode 100644
index 00000000..1625ad53
--- /dev/null
+++ b/marie/serve/networking/balancer/interceptor.py
@@ -0,0 +1,38 @@
+import abc
+
+
+class LoadBalancerInterceptor(abc.ABC):
+ """
+ Base class for load balancer interceptors that can be used to intercept the connection acquisition process
+ and provide callbacks.
+
+ """
+
+ @abc.abstractmethod
+ def on_connection_acquired(self, connection):
+ """
+ Called when a connection is acquired from the load balancer.
+ """
+ pass
+
+ @abc.abstractmethod
+ def on_connection_released(self, connection):
+ """
+ Called when a connection is released back to the load balancer.
+ """
+ pass
+
+ @abc.abstractmethod
+ def on_connection_failed(self, connection, exception):
+ """
+ Called when a connection attempt fails.
+ """
+ pass
+
+ @abc.abstractmethod
+ def on_connections_updated(self, connections):
+ """
+ Called when the set of connections has been updated.
+ This can be used to update the internal state of the interceptor.
+ """
+ pass
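
A sketch of a concrete interceptor, assuming only what the balancer code in this patch already relies on (connections expose an `address` attribute, and `on_connections_updated` receives the full connection list):

```python
import logging

from marie.serve.networking.balancer.interceptor import LoadBalancerInterceptor

log = logging.getLogger(__name__)


class CountingInterceptor(LoadBalancerInterceptor):
    """Counts how often each address is handed out by the balancer."""

    def __init__(self):
        self.acquired_per_address = {}

    def on_connection_acquired(self, connection):
        addr = connection.address
        self.acquired_per_address[addr] = self.acquired_per_address.get(addr, 0) + 1

    def on_connection_released(self, connection):
        log.debug("released connection to %s", connection.address)

    def on_connection_failed(self, connection, exception):
        log.warning("connection to %s failed: %s", connection.address, exception)

    def on_connections_updated(self, connections):
        # Forget counters for addresses that left the pool.
        live = {c.address for c in connections}
        self.acquired_per_address = {
            a: n for a, n in self.acquired_per_address.items() if a in live
        }
```

An instance can be handed to a balancer through the new `tracing_interceptors` argument, e.g. `RoundRobinLoadBalancer("executor0", tracing_interceptors=[CountingInterceptor()])`.
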
diff --git a/marie/serve/networking/balancer/least_connection_balancer.py b/marie/serve/networking/balancer/least_connection_balancer.py
index 4844da69..e7888abd 100644
--- a/marie/serve/networking/balancer/least_connection_balancer.py
+++ b/marie/serve/networking/balancer/least_connection_balancer.py
@@ -16,9 +16,6 @@ def __init__(self, deployment_name: str, logger: MarieLogger):
super().__init__(deployment_name, logger)
self._rr_counter = 0 # round robin counter
- async def get_next_connection(self, num_retries=3):
- return await self._get_next_connection(num_retries=num_retries)
-
async def _get_next_connection(self, num_retries=3):
# Find the connection with the least active connections
min_active_connections = int(1e9)
@@ -36,9 +33,9 @@ async def _get_next_connection(self, num_retries=3):
self._rr_counter = 0
self._logger.debug(
- f'least_connection_balancer.py: min_use_connections: {min_use_connections}'
- f' min_active_connections: {min_active_connections}'
- f' self._rr_counter: {self._rr_counter}'
+ f"least_connection_balancer.py: min_use_connections: {min_use_connections}"
+ f" min_active_connections: {min_active_connections}"
+ f" self._rr_counter: {self._rr_counter}"
)
# Round robin between the connections with the least active connections
@@ -54,12 +51,12 @@ async def _get_next_connection(self, num_retries=3):
if all_connections_unavailable:
if num_retries <= 0:
raise EstablishGrpcConnectionError(
- f'Error while resetting connections {self._connections} for {self._deployment_name}. Connections cannot be used.'
+ f"Error while resetting connections {self._connections} for {self._deployment_name}. Connections cannot be used."
)
elif connection is None:
# give control back to async event loop so connection resetting can be completed; then retry
self._logger.debug(
- f' No valid connection found for {self._deployment_name}, give chance for potential resetting of connection'
+ f" No valid connection found for {self._deployment_name}, give chance for potential resetting of connection"
)
return await self._get_next_connection(num_retries=num_retries - 1)
except IndexError:
diff --git a/marie/serve/networking/balancer/load_balancer.py b/marie/serve/networking/balancer/load_balancer.py
index dd8b5bf7..e503868d 100644
--- a/marie/serve/networking/balancer/load_balancer.py
+++ b/marie/serve/networking/balancer/load_balancer.py
@@ -1,8 +1,10 @@
import abc
from enum import Enum
-from typing import Union
+from typing import Optional, Sequence, Union
+from marie.excepts import EstablishGrpcConnectionError
from marie.logging.logger import MarieLogger
+from marie.serve.networking.balancer.interceptor import LoadBalancerInterceptor
class LoadBalancerType(Enum):
@@ -30,24 +32,48 @@ def from_value(value: str):
return LoadBalancerType.ROUND_ROBIN
-class LoadBalancer(metaclass=abc.ABCMeta):
+class LoadBalancer(abc.ABC):
"""Base class for load balancers."""
- def __init__(self, deployment_name: str, logger):
+ def __init__(
+ self,
+ deployment_name: str,
+ logger: Optional[MarieLogger] = None,
+ tracing_interceptors: Optional[Sequence[LoadBalancerInterceptor]] = None,
+ ):
self._connections = []
self._deployment_name = deployment_name
- self._logger = logger
+ self._logger = logger or MarieLogger(self.__class__.__name__)
self.active_counter = {}
+ self.debug_loging_enabled = False
+ self.tracing_interceptors = tracing_interceptors or []
- @abc.abstractmethod
async def get_next_connection(self, num_retries=3):
"""
Returns the next connection to be used based on the load balancing algorithm.
:param num_retries: Number of times to retry if the connection is not available.
"""
- ...
+ connection = await self._get_next_connection(num_retries=num_retries)
+
+ if connection is None:
+ raise EstablishGrpcConnectionError(
+ f"Error while acquiring connection {self._deployment_name}. Connection cannot be used."
+ )
+
+ for interceptor in self.tracing_interceptors:
+ interceptor.on_connection_acquired(connection)
+
+ return connection
+
+ @abc.abstractmethod
+ async def _get_next_connection(self, num_retries=3):
+ """
+ Implementation that returns the next connection to be used based on the load balancing algorithm.
+ :param num_retries: Number of times to retry if the connection is not available.
+ """
+ raise NotImplementedError
- def update_connections(self, connections):
+ def update_connections(self, connections: list):
"""
Rebalance the connections.
:param connections: List of connections to be used for load balancing.
@@ -57,6 +83,14 @@ def update_connections(self, connections):
if connection.address not in self.active_counter:
self.active_counter[connection.address] = 0
+ if self.debug_loging_enabled:
+ self._logger.debug(
+ f"update_connections: self._connections: {self._connections}"
+ )
+
+ for interceptor in self.tracing_interceptors:
+ interceptor.on_connections_updated(self._connections)
+
@staticmethod
def get_load_balancer(
load_balancer_type: Union[LoadBalancerType, str],
@@ -65,7 +99,7 @@ def get_load_balancer(
) -> "LoadBalancer":
"""
Get the load balancer based on the type.
- :param logger:
+ :param logger: Logger to be used.
:param load_balancer_type: Type of load balancer.
:param deployment_name: Name of the deployment.
:return:
@@ -102,7 +136,7 @@ def incr_usage(self, address: str) -> int:
self._logger.debug(f"Incrementing usage for address : {address}")
self.active_counter[address] = self.active_counter.get(address, 0) + 1
- self._logger.debug(f'incr_usage: self.active_counter: {self.active_counter}')
+ self._logger.debug(f"incr_usage: self.active_counter: {self.active_counter}")
return self.active_counter[address]
@@ -114,12 +148,16 @@ def decr_usage(self, address: str) -> int:
self._logger.debug(f"Decrementing usage for address: {address}")
self.active_counter[address] = max(0, self.active_counter.get(address, 0) - 1)
- self._logger.debug(f'decr_usage: self.active_counter: {self.active_counter}')
+ self._logger.debug(f"decr_usage: self.active_counter: {self.active_counter}")
return self.active_counter[address]
def get_active_count(self, address: str) -> int:
"""Get the number of active requests for a given address"""
return self.active_counter.get(address, 0)
- def get_active_counter(self):
+ def get_active_counter(self) -> dict:
+ """
+ Get the active counter for all the connections.
+ :return: Mapping of connection address to the number of active requests.
+ """
return self.active_counter
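
With `get_next_connection` now living in the base class, a custom strategy only needs to implement `_get_next_connection`. A sketch of a trivial random balancer that can be injected through the new `load_balancer` parameter threaded through the connection pool in this patch:

```python
import random

from marie.serve.networking.balancer.load_balancer import LoadBalancer


class RandomLoadBalancer(LoadBalancer):
    """Sketch: pick any connection that is not currently being reset."""

    async def _get_next_connection(self, num_retries=3):
        candidates = [c for c in self._connections if c is not None]
        if not candidates:
            # Returning None makes the base class raise EstablishGrpcConnectionError.
            return None
        return random.choice(candidates)
```

Passing such an instance (e.g. `load_balancer=RandomLoadBalancer("executor0")`) makes `_ReplicaList` use it directly instead of building one from `load_balancer_type`, per the `replica_list.py` hunk below.
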
diff --git a/marie/serve/networking/balancer/round_robin_balancer.py b/marie/serve/networking/balancer/round_robin_balancer.py
index 60aa0ad3..5aac8bc2 100644
--- a/marie/serve/networking/balancer/round_robin_balancer.py
+++ b/marie/serve/networking/balancer/round_robin_balancer.py
@@ -1,35 +1,44 @@
+from typing import Optional, Sequence
+
from marie.excepts import EstablishGrpcConnectionError
from marie.logging.logger import MarieLogger
+from marie.serve.networking.balancer.interceptor import LoadBalancerInterceptor
from marie.serve.networking.balancer.load_balancer import LoadBalancer
class RoundRobinLoadBalancer(LoadBalancer):
"""
- Round robin load balancer.
+ Round-robin load balancer.
"""
- def __init__(self, deployment_name: str, logger: MarieLogger):
- super().__init__(deployment_name, logger)
- self._rr_counter = 0 # round robin counter
-
- async def get_next_connection(self, num_retries=3):
- return await self._get_next_connection(num_retries=num_retries)
+ def __init__(
+ self,
+ deployment_name: str,
+ logger: Optional[MarieLogger] = None,
+ tracing_interceptors: Optional[Sequence[LoadBalancerInterceptor]] = None,
+ ):
+ super().__init__(
+ deployment_name, logger, tracing_interceptors=tracing_interceptors
+ )
+ self._rr_counter = 0 # round-robin counter
async def _get_next_connection(self, num_retries=3):
"""
:param num_retries: how many retries should be performed when all connections are currently unavailable
:returns: A connection from the pool
"""
- self._logger.debug(
- f'round_robin_balancer.py: self._connections: {self._connections} , {num_retries}'
- )
+ if self.debug_loging_enabled:
+ self._logger.debug(
+ f"round_robin_balancer.py: self._connections: {self._connections} , {num_retries}"
+ )
try:
connection = None
for i in range(len(self._connections)):
internal_rr_counter = (self._rr_counter + i) % len(self._connections)
- self._logger.debug(
- f'round_robin_balancer.py: internal_rr_counter: {internal_rr_counter}'
- )
+ if self.debug_loging_enabled:
+ self._logger.debug(
+ f"round_robin_balancer.py: internal_rr_counter: {internal_rr_counter}"
+ )
connection = self._connections[internal_rr_counter]
# connection is None if it is currently being reset. In that case, try different connection
if connection is not None:
@@ -38,13 +47,14 @@ async def _get_next_connection(self, num_retries=3):
if all_connections_unavailable:
if num_retries <= 0:
raise EstablishGrpcConnectionError(
- f'Error while resetting connections {self._connections} for {self._deployment_name}. Connections cannot be used.'
+ f"Error while resetting connections {self._connections} for {self._deployment_name}. Connections cannot be used."
)
elif connection is None:
# give control back to async event loop so connection resetting can be completed; then retry
- self._logger.debug(
- f' No valid connection found for {self._deployment_name}, give chance for potential resetting of connection'
- )
+ if self.debug_loging_enabled:
+ self._logger.debug(
+ f" No valid connection found for {self._deployment_name}, give chance for potential resetting of connection"
+ )
return await self._get_next_connection(num_retries=num_retries - 1)
except IndexError:
# This can happen as a race condition while _removing_ connections
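
To make the probe arithmetic above concrete, a standalone sketch (plain Python, no Marie imports) of how the counter wraps around and skips connections that are mid-reset:

```python
connections = ["addr-a", None, "addr-c"]  # None mimics a connection being reset
rr_counter = 1                            # where the previous request left off

chosen = None
for i in range(len(connections)):
    idx = (rr_counter + i) % len(connections)
    if connections[idx] is not None:
        chosen = connections[idx]
        break

print(chosen)  # -> 'addr-c': slot 1 is skipped because it is unavailable
```
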
diff --git a/marie/serve/networking/connection_pool_map.py b/marie/serve/networking/connection_pool_map.py
index 92c298fb..c001dfd3 100644
--- a/marie/serve/networking/connection_pool_map.py
+++ b/marie/serve/networking/connection_pool_map.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
from marie.logging.logger import MarieLogger
+from marie.serve.networking import LoadBalancer
from marie.serve.networking.instrumentation import (
_NetworkingHistograms,
_NetworkingMetrics,
@@ -23,10 +24,11 @@ def __init__(
logger: Optional[MarieLogger],
metrics: _NetworkingMetrics,
histograms: _NetworkingHistograms,
- aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
- tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
+ aio_tracing_client_interceptors: Optional[Sequence["ClientInterceptor"]] = None,
+ tracing_client_interceptor: Optional["OpenTelemetryClientInterceptor"] = None,
channel_options: Optional[list] = None,
- load_balancer_type: Optional[str] = 'round_robin',
+ load_balancer_type: Optional[str] = "round_robin",
+ load_balancer: Optional[LoadBalancer] = None,
):
self._logger = logger
# this maps deployments to shards or heads
@@ -36,21 +38,22 @@ def __init__(
self._metrics = metrics
self._histograms = histograms
self.runtime_name = runtime_name
- if os.name != 'nt':
- os.unsetenv('http_proxy')
- os.unsetenv('https_proxy')
+ if os.name != "nt":
+ os.unsetenv("http_proxy")
+ os.unsetenv("https_proxy")
self.aio_tracing_client_interceptors = aio_tracing_client_interceptors
self.tracing_client_interceptor = tracing_client_interceptor
self.channel_options = channel_options
self.load_balancer_type = load_balancer_type
+ self.load_balancer = load_balancer
def add_replica(self, deployment: str, shard_id: int, address: str):
- self._add_connection(deployment, shard_id, address, 'shards')
+ self._add_connection(deployment, shard_id, address, "shards")
def add_head(
self, deployment: str, address: str, head_id: Optional[int] = 0
): # the head_id is always 0 for now, this will change when scaling the head
- self._add_connection(deployment, head_id, address, 'heads')
+ self._add_connection(deployment, head_id, address, "heads")
def get_replicas(
self,
@@ -61,7 +64,7 @@ def get_replicas(
) -> Optional[_ReplicaList]:
# returns all replicas of a given deployment, using a given shard
if deployment in self._deployments:
- type_ = 'heads' if head else 'shards'
+ type_ = "heads" if head else "shards"
if entity_id is None and head:
entity_id = 0
return self._get_connection_list(
@@ -69,7 +72,7 @@ def get_replicas(
)
else:
self._logger.debug(
- f'Unknown deployment {deployment}, no replicas available'
+ f"Unknown deployment {deployment}, no replicas available"
)
return None
@@ -78,9 +81,9 @@ def get_replicas_all_shards(self, deployment: str) -> List[_ReplicaList]:
# result is a list of 'shape' (num_shards, num_replicas), containing all replicas for all shards
replicas = []
if deployment in self._deployments:
- for shard_id in self._deployments[deployment]['shards']:
+ for shard_id in self._deployments[deployment]["shards"]:
replicas.append(
- self._get_connection_list(deployment, 'shards', shard_id)
+ self._get_connection_list(deployment, "shards", shard_id)
)
return replicas
@@ -125,7 +128,7 @@ def _get_connection_list(
def _add_deployment(self, deployment: str):
if deployment not in self._deployments:
- self._deployments[deployment] = {'shards': {}, 'heads': {}}
+ self._deployments[deployment] = {"shards": {}, "heads": {}}
self._access_count[deployment] = 0
def _add_connection(
@@ -147,29 +150,30 @@ def _add_connection(
deployment_name=deployment,
channel_options=self.channel_options,
load_balancer_type=self.load_balancer_type,
+ load_balancer=self.load_balancer,
)
self._deployments[deployment][type][entity_id] = connection_list
if not self._deployments[deployment][type][entity_id].has_connection(address):
self._logger.debug(
- f'adding connection for deployment {deployment}/{type}/{entity_id} to {address}'
+ f"adding connection for deployment {deployment}/{type}/{entity_id} to {address}"
)
self._deployments[deployment][type][entity_id].add_connection(
address, deployment_name=deployment
)
self._logger.debug(
- f'connection for deployment {deployment}/{type}/{entity_id} to {address} added'
+ f"connection for deployment {deployment}/{type}/{entity_id} to {address} added"
)
else:
self._logger.debug(
- f'ignoring activation of pod for deployment {deployment}, {address} already known'
+ f"ignoring activation of pod for deployment {deployment}, {address} already known"
)
async def remove_head(self, deployment, address, head_id: Optional[int] = 0):
- return await self._remove_connection(deployment, head_id, address, 'heads')
+ return await self._remove_connection(deployment, head_id, address, "heads")
async def remove_replica(self, deployment, address, shard_id: Optional[int] = 0):
- return await self._remove_connection(deployment, shard_id, address, 'shards')
+ return await self._remove_connection(deployment, shard_id, address, "shards")
async def _remove_connection(self, deployment, entity_id, address, type):
if (
@@ -177,7 +181,7 @@ async def _remove_connection(self, deployment, entity_id, address, type):
and entity_id in self._deployments[deployment][type]
):
self._logger.debug(
- f'removing connection for deployment {deployment}/{type}/{entity_id} to {address}'
+ f"removing connection for deployment {deployment}/{type}/{entity_id} to {address}"
)
await self._deployments[deployment][type][entity_id].remove_connection(
address
diff --git a/marie/serve/networking/replica_list.py b/marie/serve/networking/replica_list.py
index ff5040fd..eefe0b61 100644
--- a/marie/serve/networking/replica_list.py
+++ b/marie/serve/networking/replica_list.py
@@ -30,13 +30,14 @@ def __init__(
histograms: _NetworkingHistograms,
logger,
runtime_name: str,
- aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
- tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
- deployment_name: str = '',
+ aio_tracing_client_interceptors: Optional[Sequence["ClientInterceptor"]] = None,
+ tracing_client_interceptor: Optional["OpenTelemetryClientInterceptor"] = None,
+ deployment_name: str = "",
channel_options: Optional[Union[list, Dict[str, Any]]] = None,
load_balancer_type: Optional[
Union[LoadBalancerType, str]
] = LoadBalancerType.ROUND_ROBIN,
+ load_balancer: Optional[LoadBalancer] = None,
):
self.runtime_name = runtime_name
self._connections = []
@@ -49,11 +50,15 @@ def __init__(
self.tracing_client_interceptors = tracing_client_interceptor
self._deployment_name = deployment_name
self.channel_options = channel_options
+
# a set containing all the ConnectionStubs that will be created using add_connection
# this set is not updated in reset_connection and remove_connection
- self.load_balancer = LoadBalancer.get_load_balancer(
- load_balancer_type, deployment_name, logger
- )
+ if load_balancer is not None:
+ self.load_balancer = load_balancer
+ else:
+ self.load_balancer = LoadBalancer.get_load_balancer(
+ load_balancer_type, deployment_name, logger
+ )
async def reset_connection(self, address: str, deployment_name: str):
"""
@@ -64,7 +69,7 @@ async def reset_connection(self, address: str, deployment_name: str):
:param address: Target address of this connection
:param deployment_name: Target deployment of this connection
"""
- self._logger.debug(f'resetting connection for {deployment_name} to {address}')
+ self._logger.debug(f"resetting connection for {deployment_name} to {address}")
parsed_address = urlparse(address)
resolved_address = parsed_address.netloc if parsed_address.netloc else address
if (
@@ -132,7 +137,7 @@ async def remove_connection(self, address: str):
def _create_connection(self, address, deployment_name: str):
self._logger.debug(
- f'create_connection connection for {deployment_name} to {address}'
+ f"create_connection connection for {deployment_name} to {address}"
)
parsed_address = urlparse(address)
address = parsed_address.netloc if parsed_address.netloc else address
diff --git a/marie/serve/runtimes/asyncio.py b/marie/serve/runtimes/asyncio.py
index 853567bc..11f4bf43 100644
--- a/marie/serve/runtimes/asyncio.py
+++ b/marie/serve/runtimes/asyncio.py
@@ -206,6 +206,22 @@ def _get_server(self):
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
+ elif (
+ hasattr(self.args, 'provider') and self.args.provider == ProviderType.AZURE
+ ):
+ from marie.serve.runtimes.servers.http import AzureHTTPServer
+
+ return AzureHTTPServer(
+ name=self.args.name,
+ runtime_args=self.args,
+ req_handler_cls=self.req_handler_cls,
+ proxy=getattr(self.args, 'proxy', None),
+ uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
+ ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
+ ssl_certfile=getattr(self.args, 'ssl_certfile', None),
+ cors=getattr(self.args, 'cors', None),
+ is_cancel=self.is_cancel,
+ )
elif not hasattr(self.args, 'protocol') or (
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
):
@@ -338,10 +354,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.logger.debug(f'{self!r} is interrupted by user')
elif exc_type and issubclass(exc_type, Exception):
self.logger.error(
- f'{exc_val!r} during {self.run_forever!r}'
- + f'\n add "--quiet-error" to suppress the exception details'
- if not self.args.quiet_error
- else '',
+ (
+ f'{exc_val!r} during {self.run_forever!r}'
+ + f'\n add "--quiet-error" to suppress the exception details'
+ if not self.args.quiet_error
+ else ''
+ ),
exc_info=not self.args.quiet_error,
)
try:
@@ -351,10 +369,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass
except Exception as ex:
self.logger.error(
- f'{ex!r} during {self.teardown!r}'
- + f'\n add "--quiet-error" to suppress the exception details'
- if not self.args.quiet_error
- else '',
+ (
+ f'{ex!r} during {self.teardown!r}'
+ + f'\n add "--quiet-error" to suppress the exception details'
+ if not self.args.quiet_error
+ else ''
+ ),
exc_info=not self.args.quiet_error,
)
diff --git a/marie/serve/runtimes/gateway/async_request_response_handling.py b/marie/serve/runtimes/gateway/async_request_response_handling.py
index c2ea46b6..83150092 100644
--- a/marie/serve/runtimes/gateway/async_request_response_handling.py
+++ b/marie/serve/runtimes/gateway/async_request_response_handling.py
@@ -201,9 +201,11 @@ async def _process_results_at_end_gateway(
asyncio.ensure_future(
_process_results_at_end_gateway(responding_tasks, request_graph)
),
- asyncio.ensure_future(asyncio.gather(*floating_tasks))
- if len(floating_tasks) > 0
- else None,
+ (
+ asyncio.ensure_future(asyncio.gather(*floating_tasks))
+ if len(floating_tasks) > 0
+ else None
+ ),
)
return _handle_request
diff --git a/marie/serve/runtimes/gateway/graph/topology_graph.py b/marie/serve/runtimes/gateway/graph/topology_graph.py
index 6b6e7f99..48d81321 100644
--- a/marie/serve/runtimes/gateway/graph/topology_graph.py
+++ b/marie/serve/runtimes/gateway/graph/topology_graph.py
@@ -224,15 +224,15 @@ async def task():
self._pydantic_models_by_endpoint = {}
models_created_by_name = {}
for endpoint, inner_dict in schemas.items():
- input_model_name = inner_dict["input"]["name"]
- input_model_schema = inner_dict["input"]["model"]
+ input_model_name = inner_dict['input']['name']
+ input_model_schema = inner_dict['input']['model']
if input_model_schema in models_schema_list:
input_model = models_list[
models_schema_list.index(input_model_schema)
]
- models_created_by_name[
- input_model_name
- ] = input_model
+ models_created_by_name[input_model_name] = (
+ input_model
+ )
else:
if input_model_name not in models_created_by_name:
if input_model_schema == legacy_doc_schema:
@@ -245,24 +245,24 @@ async def task():
models_created_by_name,
)
)
- models_created_by_name[
- input_model_name
- ] = input_model
+ models_created_by_name[input_model_name] = (
+ input_model
+ )
input_model = models_created_by_name[
input_model_name
]
models_schema_list.append(input_model_schema)
models_list.append(input_model)
- output_model_name = inner_dict["output"]["name"]
- output_model_schema = inner_dict["output"]["model"]
+ output_model_name = inner_dict['output']['name']
+ output_model_schema = inner_dict['output']['model']
if output_model_schema in models_schema_list:
output_model = models_list[
models_schema_list.index(output_model_schema)
]
- models_created_by_name[
- output_model_name
- ] = output_model
+ models_created_by_name[output_model_name] = (
+ output_model
+ )
else:
if output_model_name not in models_created_by_name:
if output_model_name == legacy_doc_schema:
@@ -275,18 +275,18 @@ async def task():
models_created_by_name,
)
)
- models_created_by_name[
- output_model_name
- ] = output_model
+ models_created_by_name[output_model_name] = (
+ output_model
+ )
output_model = models_created_by_name[
output_model_name
]
models_schema_list.append(output_model)
models_list.append(output_model)
- parameters_model_name = inner_dict["parameters"]["name"]
- parameters_model_schema = inner_dict["parameters"][
- "model"
+ parameters_model_name = inner_dict['parameters']['name']
+ parameters_model_schema = inner_dict['parameters'][
+ 'model'
]
if parameters_model_schema is not None:
if parameters_model_schema in models_schema_list:
@@ -600,17 +600,17 @@ def _get_leaf_input_output_model(
list_of_outputs = []
for outgoing_node in self.outgoing_nodes:
list_of_maps = outgoing_node._get_leaf_input_output_model(
- previous_input=new_map["input"] if new_map is not None else None,
- previous_output=new_map["output"] if new_map is not None else None,
- previous_is_generator=new_map["is_generator"]
- if new_map is not None
- else None,
- previous_is_singleton_doc=new_map["is_singleton_doc"]
- if new_map is not None
- else None,
- previous_parameters=new_map["parameters"]
- if new_map is not None
- else None,
+ previous_input=new_map['input'] if new_map is not None else None,
+ previous_output=new_map['output'] if new_map is not None else None,
+ previous_is_generator=(
+ new_map['is_generator'] if new_map is not None else None
+ ),
+ previous_is_singleton_doc=(
+ new_map['is_singleton_doc'] if new_map is not None else None
+ ),
+ previous_parameters=(
+ new_map['parameters'] if new_map is not None else None
+ ),
endpoint=endpoint,
)
# We are interested in the last one, that will be the task that awaits all the previous
@@ -732,7 +732,6 @@ def _find_route(request):
return request
class _EndGatewayNode(_ReqReplyNode):
-
"""
Dummy node to be added before the gateway. This is to solve a problem we had when implementing `floating Executors`.
If we do not add this at the end, this structure does not work:
@@ -809,9 +808,11 @@ def __init__(
metadata = deployments_metadata.get(node_name, None)
nodes[node_name] = self._ReqReplyNode(
name=node_name,
- number_of_parts=num_parts_per_node[node_name]
- if num_parts_per_node[node_name] > 0
- else 1,
+ number_of_parts=(
+ num_parts_per_node[node_name]
+ if num_parts_per_node[node_name] > 0
+ else 1
+ ),
floating=node_name in floating_deployment_set,
filter_condition=condition,
metadata=metadata,
diff --git a/marie/serve/runtimes/gateway/health_model.py b/marie/serve/runtimes/gateway/health_model.py
index 99bb6162..2ed7c355 100644
--- a/marie/serve/runtimes/gateway/health_model.py
+++ b/marie/serve/runtimes/gateway/health_model.py
@@ -1,13 +1,13 @@
-from typing import Dict
+from typing import Dict, Optional
from pydantic import BaseModel
def _to_camel_case(snake_str: str) -> str:
- components = snake_str.split('_')
+ components = snake_str.split("_")
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
- return components[0] + ''.join(x.title() for x in components[1:])
+ return components[0] + "".join(x.title() for x in components[1:])
class JinaHealthModel(BaseModel):
@@ -19,7 +19,8 @@ class JinaHealthModel(BaseModel):
class JinaInfoModel(BaseModel):
"""Pydantic BaseModel for Jina status, used as the response model in REST app."""
- jina: Dict
+ jina: Optional[Dict]
+ marie: Optional[Dict]
envs: Dict
class Config:
diff --git a/marie/serve/runtimes/gateway/http_fastapi_app.py b/marie/serve/runtimes/gateway/http_fastapi_app.py
index 3c1285e6..d39af4d8 100644
--- a/marie/serve/runtimes/gateway/http_fastapi_app.py
+++ b/marie/serve/runtimes/gateway/http_fastapi_app.py
@@ -325,9 +325,9 @@ async def foo_no_post(body: JinaRequestModel):
}
for k, v in crud.items():
v['tags'] = ['CRUD']
- v[
- 'description'
- ] = f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
+ v['description'] = (
+ f'Post data requests to the Flow. Executors with `@requests(on="{k}")` will respond.'
+ )
expose_executor_endpoint(exec_endpoint=k, **v)
if openapi_tags:
diff --git a/marie/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/marie/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
index 91550141..b80bfe42 100644
--- a/marie/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
+++ b/marie/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py
@@ -268,12 +268,27 @@ async def event_generator():
return EventSourceResponse(event_generator())
for endpoint, input_output_map in request_models_map.items():
- if endpoint != "_jina_dry_run_":
- input_doc_model = input_output_map["input"]
- output_doc_model = input_output_map["output"]
- is_generator = input_output_map["is_generator"]
- parameters_model = input_output_map["parameters"] or Optional[Dict]
- default_parameters = ... if input_output_map["parameters"] else None
+ if endpoint != '_jina_dry_run_':
+ input_doc_model = input_output_map['input']
+ output_doc_model = input_output_map['output']
+ is_generator = input_output_map['is_generator']
+ parameters_model = input_output_map['parameters']
+ parameters_model_needed = parameters_model is not None
+ if parameters_model_needed:
+ try:
+ _ = parameters_model()
+ parameters_model_needed = False
+ except Exception:
+ parameters_model_needed = True
+ parameters_model = (
+ parameters_model
+ if parameters_model_needed
+ else Optional[parameters_model]
+ )
+ default_parameters = ... if parameters_model_needed else None
+ else:
+ parameters_model = Optional[Dict]
+ default_parameters = None
_config = inherit_config(InnerConfig, BaseDoc.__config__)
diff --git a/marie/serve/runtimes/gateway/models.py b/marie/serve/runtimes/gateway/models.py
index 29cda239..39b70bfd 100644
--- a/marie/serve/runtimes/gateway/models.py
+++ b/marie/serve/runtimes/gateway/models.py
@@ -183,9 +183,11 @@ def protobuf_to_pydantic_model(
all_fields[field_name] = (
field_type,
- Field(default_factory=default_factory)
- if default_factory
- else Field(default=default_value),
+ (
+ Field(default_factory=default_factory)
+ if default_factory
+ else Field(default=default_value)
+ ),
)
# Post-processing (Handle oneof fields)
diff --git a/marie/serve/runtimes/gateway/request_handling.py b/marie/serve/runtimes/gateway/request_handling.py
index 5af16541..3696d42e 100644
--- a/marie/serve/runtimes/gateway/request_handling.py
+++ b/marie/serve/runtimes/gateway/request_handling.py
@@ -70,9 +70,11 @@ def __init__(
meter=meter,
aio_tracing_client_interceptors=aio_tracing_client_interceptors,
tracing_client_interceptor=tracing_client_interceptor,
- grpc_channel_options=self.runtime_args.grpc_channel_options
- if hasattr(self.runtime_args, 'grpc_channel_options')
- else None,
+ grpc_channel_options=(
+ self.runtime_args.grpc_channel_options
+ if hasattr(self.runtime_args, 'grpc_channel_options')
+ else None
+ ),
)
GatewayStreamer._set_env_streamer_args(
diff --git a/marie/serve/runtimes/gateway/streamer.py b/marie/serve/runtimes/gateway/streamer.py
index 21bdd289..9b80a628 100644
--- a/marie/serve/runtimes/gateway/streamer.py
+++ b/marie/serve/runtimes/gateway/streamer.py
@@ -18,6 +18,7 @@
from marie.logging.logger import MarieLogger
from marie.proto import jina_pb2
from marie.serve.networking import GrpcConnectionPool
+from marie.serve.networking.balancer.load_balancer import LoadBalancer
from marie.serve.runtimes.gateway.async_request_response_handling import (
AsyncRequestResponseHandler,
)
@@ -29,7 +30,7 @@
if docarray_v2:
from docarray import DocList
-__all__ = ['GatewayStreamer']
+__all__ = ["GatewayStreamer"]
if TYPE_CHECKING: # pragma: no cover
from grpc.aio._interceptor import ClientInterceptor
@@ -55,15 +56,16 @@ def __init__(
timeout_send: Optional[float] = None,
retries: int = 0,
compression: Optional[str] = None,
- runtime_name: str = 'custom gateway',
+ runtime_name: str = "custom gateway",
prefetch: int = 0,
- logger: Optional['MarieLogger'] = None,
- metrics_registry: Optional['CollectorRegistry'] = None,
- meter: Optional['Meter'] = None,
- aio_tracing_client_interceptors: Optional[Sequence['ClientInterceptor']] = None,
- tracing_client_interceptor: Optional['OpenTelemetryClientInterceptor'] = None,
+ logger: Optional["MarieLogger"] = None,
+ metrics_registry: Optional["CollectorRegistry"] = None,
+ meter: Optional["Meter"] = None,
+ aio_tracing_client_interceptors: Optional[Sequence["ClientInterceptor"]] = None,
+ tracing_client_interceptor: Optional["OpenTelemetryClientInterceptor"] = None,
grpc_channel_options: Optional[list] = None,
- load_balancer_type: Optional[str] = 'round_robin',
+ load_balancer_type: Optional[str] = "round_robin",
+ load_balancer: Optional[LoadBalancer] = None,
):
"""
:param graph_representation: A dictionary describing the topology of the Deployments. 2 special nodes are expected, the name `start-gateway` and `end-gateway` to
@@ -85,7 +87,7 @@ def __init__(
:param aio_tracing_client_interceptors: Optional list of aio grpc tracing server interceptors.
:param tracing_client_interceptor: Optional gprc tracing server interceptor.
:param grpc_channel_options: Optional gprc channel options.
- :param load_balancer_type: Optional load balancer type. Default is round robin.
+ :param load_balancer_type: Optional load balancer type. Default is round-robin.
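+        :param load_balancer: Optional pre-constructed load balancer instance, passed through to the connection pool alongside load_balancer_type.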
"""
self.logger = logger or MarieLogger(self.__class__.__name__)
self.topology_graph = TopologyGraph(
@@ -113,6 +115,7 @@ def __init__(
tracing_client_interceptor,
grpc_channel_options,
load_balancer_type,
+ load_balancer,
)
request_handler = AsyncRequestResponseHandler(
metrics_registry, meter, runtime_name, logger
@@ -143,7 +146,8 @@ def _create_connection_pool(
aio_tracing_client_interceptors,
tracing_client_interceptor,
grpc_channel_options=None,
- load_balancer_type='round_robin',
+ load_balancer_type=None,
+ load_balancer=None,
):
# add the connections needed
connection_pool = GrpcConnectionPool(
@@ -156,6 +160,7 @@ def _create_connection_pool(
tracing_client_interceptor=tracing_client_interceptor,
channel_options=grpc_channel_options,
load_balancer_type=load_balancer_type,
+ load_balancer=load_balancer,
)
for deployment_name, addresses in deployments_addresses.items():
for address in addresses:
@@ -216,7 +221,7 @@ async def stream(
parameters: Optional[Dict] = None,
results_in_order: bool = False,
return_type: Type[DocumentArray] = DocumentArray,
- ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]:
+ ) -> AsyncIterator[Tuple[Union[DocumentArray, "Request"], "ExecutorError"]]:
"""
stream Documents and yield Documents or Responses and unpacked Executor error if any.
@@ -256,14 +261,14 @@ async def stream(
async def stream_doc(
self,
- doc: 'Document',
+ doc: "Document",
return_results: bool = False,
exec_endpoint: Optional[str] = None,
target_executor: Optional[str] = None,
parameters: Optional[Dict] = None,
request_id: Optional[str] = None,
return_type: Type[DocumentArray] = DocumentArray,
- ) -> AsyncIterator[Tuple[Union[DocumentArray, 'Request'], 'ExecutorError']]:
+ ) -> AsyncIterator[Tuple[Union[DocumentArray, "Request"], "ExecutorError"]]:
"""
stream Documents and yield Documents or Responses and unpacked Executor error if any.
@@ -421,15 +426,15 @@ def get_streamer():
:return: Returns an instance of `GatewayStreamer`
"""
- if 'JINA_STREAMER_ARGS' in os.environ:
- args_dict = json.loads(os.environ['JINA_STREAMER_ARGS'])
+ if "JINA_STREAMER_ARGS" in os.environ:
+ args_dict = json.loads(os.environ["JINA_STREAMER_ARGS"])
return GatewayStreamer(**args_dict)
else:
- raise OSError('JINA_STREAMER_ARGS environment variable is not set')
+ raise OSError("JINA_STREAMER_ARGS environment variable is not set")
@staticmethod
def _set_env_streamer_args(**kwargs):
- os.environ['JINA_STREAMER_ARGS'] = json.dumps(kwargs)
+ os.environ["JINA_STREAMER_ARGS"] = json.dumps(kwargs)
class _ExecutorStreamer:
@@ -507,7 +512,7 @@ def batch(iterable, n=1):
async def stream_doc(
self,
- inputs: 'Document',
+ inputs: "Document",
on: Optional[str] = None,
parameters: Optional[Dict] = None,
**kwargs,
diff --git a/marie/serve/runtimes/head/request_handling.py b/marie/serve/runtimes/head/request_handling.py
index 12e190ad..ab9e9d38 100644
--- a/marie/serve/runtimes/head/request_handling.py
+++ b/marie/serve/runtimes/head/request_handling.py
@@ -358,9 +358,9 @@ async def task():
output_model_schema = inner_dict['output']['model']
if input_model_schema == legacy_doc_schema:
- models_created_by_name[
- input_model_name
- ] = LegacyDocument
+ models_created_by_name[input_model_name] = (
+ LegacyDocument
+ )
elif input_model_name not in models_created_by_name:
input_model = _create_pydantic_model_from_schema(
input_model_schema, input_model_name, {}
@@ -368,9 +368,9 @@ async def task():
models_created_by_name[input_model_name] = input_model
if output_model_name == legacy_doc_schema:
- models_created_by_name[
- output_model_name
- ] = LegacyDocument
+ models_created_by_name[output_model_name] = (
+ LegacyDocument
+ )
elif output_model_name not in models_created_by_name:
output_model = _create_pydantic_model_from_schema(
output_model_schema, output_model_name, {}
@@ -471,7 +471,9 @@ async def process_data(self, requests: List[DataRequest], context) -> DataReques
)
context.set_trailing_metadata(metadata.items())
return response
- except InternalNetworkError as err: # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
+ except (
+ InternalNetworkError
+ ) as err: # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
return self._handle_internalnetworkerror(
err=err, context=context, response=Response()
)
@@ -480,9 +482,12 @@ async def process_data(self, requests: List[DataRequest], context) -> DataReques
Exception,
) as ex: # some other error, keep streaming going just add error info
self.logger.error(
- f'{ex!r}' + f'\n add "--quiet-error" to suppress the exception details'
- if not self.args.quiet_error
- else '',
+ (
+ f'{ex!r}'
+ + f'\n add "--quiet-error" to suppress the exception details'
+ if not self.args.quiet_error
+ else ''
+ ),
exc_info=not self.args.quiet_error,
)
requests[0].add_exception(ex, executor=None)
@@ -522,7 +527,9 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto:
)
response.endpoints.extend(worker_response.endpoints)
response.schemas.update(worker_response.schemas)
- except InternalNetworkError as err: # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
+ except (
+ InternalNetworkError
+ ) as err: # can't connect, Flow broken, interrupt the streaming through gRPC error mechanism
return self._handle_internalnetworkerror(
err=err, context=context, response=response
)
diff --git a/marie/serve/runtimes/servers/grpc.py b/marie/serve/runtimes/servers/grpc.py
index 93157c32..2c7e2375 100644
--- a/marie/serve/runtimes/servers/grpc.py
+++ b/marie/serve/runtimes/servers/grpc.py
@@ -157,7 +157,7 @@ async def setup_server(self):
await self.health_servicer.set(
service, health_pb2.HealthCheckResponse.SERVING
)
- self.logger.debug(f'GRPC server setup successful')
+ self.logger.debug(f'GRPC server setup successful : {bind_addr}')
async def shutdown(self):
"""Free other resources allocated with the server, e.g, gateway object, ..."""
diff --git a/marie/serve/runtimes/servers/http.py b/marie/serve/runtimes/servers/http.py
index 512860a4..9f369b82 100644
--- a/marie/serve/runtimes/servers/http.py
+++ b/marie/serve/runtimes/servers/http.py
@@ -285,7 +285,44 @@ def app(self):
"""Get the sagemaker fastapi app
:return: Return a FastAPI app for the sagemaker container
"""
- return self._request_handler._http_fastapi_sagemaker_app(
+ return self._request_handler._http_fastapi_csp_app(
+ title=self.title,
+ description=self.description,
+ no_crud_endpoints=self.no_crud_endpoints,
+ no_debug_endpoints=self.no_debug_endpoints,
+ expose_endpoints=self.expose_endpoints,
+ expose_graphql_endpoint=self.expose_graphql_endpoint,
+ tracing=self.tracing,
+ tracer_provider=self.tracer_provider,
+ cors=self.cors,
+ logger=self.logger,
+ )
+
+
+class AzureHTTPServer(FastAPIBaseServer):
+ """
+ :class:`AzureHTTPServer` is a FastAPIBaseServer that uses a custom FastAPI app for azure endpoints
+
+ """
+
+ @property
+ def port(self):
+ """Get the port for the azure server
+ :return: Return the port for the azure server, always 8080"""
+ return 8080
+
+ @property
+ def ports(self):
+ """Get the port for the azure server
+ :return: Return the port for the azure server, always 8080"""
+ return [8080]
+
+ @property
+ def app(self):
+ """Get the azure fastapi app
+ :return: Return a FastAPI app for the azure container
+ """
+ return self._request_handler._http_fastapi_csp_app(
title=self.title,
description=self.description,
no_crud_endpoints=self.no_crud_endpoints,
diff --git a/marie/serve/runtimes/worker/http_sagemaker_app.py b/marie/serve/runtimes/worker/http_csp_app.py
similarity index 92%
rename from marie/serve/runtimes/worker/http_sagemaker_app.py
rename to marie/serve/runtimes/worker/http_csp_app.py
index f347591f..6fc3743b 100644
--- a/marie/serve/runtimes/worker/http_sagemaker_app.py
+++ b/marie/serve/runtimes/worker/http_csp_app.py
@@ -41,7 +41,7 @@ def get_fastapi_app(
from marie.serve.runtimes.gateway.models import _to_camel_case
if not docarray_v2:
- logger.warning('Only docarray v2 is supported with Sagemaker. ')
+ logger.warning('Only docarray v2 is supported with CSP. ')
return
class Header(BaseModel):
@@ -117,7 +117,7 @@ async def process(body) -> output_model:
if body.parameters is not None:
req.parameters = body.parameters
req.header.exec_endpoint = endpoint_path
- req.document_array_cls = DocList[input_doc_model]
+ req.document_array_cls = DocList[input_doc_list_model]
data = body.data
if isinstance(data, list):
@@ -230,10 +230,23 @@ def construct_model_from_line(
if endpoint != '_jina_dry_run_':
input_doc_model = input_output_map['input']['model']
output_doc_model = input_output_map['output']['model']
- parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
- default_parameters = (
- ... if input_output_map['parameters']['model'] else None
- )
+ parameters_model = input_output_map['parameters']['model']
+ parameters_model_needed = parameters_model is not None
+ if parameters_model_needed:
+ try:
+ _ = parameters_model()
+ parameters_model_needed = False
+ except:
+ parameters_model_needed = True
+ parameters_model = (
+ parameters_model
+ if parameters_model_needed
+ else Optional[parameters_model]
+ )
+ default_parameters = ... if parameters_model_needed else None
+ else:
+ parameters_model = Optional[Dict]
+ default_parameters = None
_config = inherit_config(InnerConfig, BaseDoc.__config__)
endpoint_input_model = pydantic.create_model(
diff --git a/marie/serve/runtimes/worker/http_fastapi_app.py b/marie/serve/runtimes/worker/http_fastapi_app.py
index 15ac63ce..5ff41392 100644
--- a/marie/serve/runtimes/worker/http_fastapi_app.py
+++ b/marie/serve/runtimes/worker/http_fastapi_app.py
@@ -161,10 +161,23 @@ async def streaming_get(request: Request = None, body: input_doc_model = None):
input_doc_model = input_output_map['input']['model']
output_doc_model = input_output_map['output']['model']
is_generator = input_output_map['is_generator']
- parameters_model = input_output_map['parameters']['model'] or Optional[Dict]
- default_parameters = (
- ... if input_output_map['parameters']['model'] else None
- )
+ parameters_model = input_output_map['parameters']['model']
+ parameters_model_needed = parameters_model is not None
+ if parameters_model_needed:
+ try:
+ _ = parameters_model()
+ parameters_model_needed = False
+ except:
+ parameters_model_needed = True
+ parameters_model = (
+ parameters_model
+ if parameters_model_needed
+ else Optional[parameters_model]
+ )
+ default_parameters = ... if parameters_model_needed else None
+ else:
+ parameters_model = Optional[Dict]
+ default_parameters = None
if docarray_v2:
_config = inherit_config(InnerConfig, BaseDoc.__config__)
diff --git a/marie/serve/runtimes/worker/request_handling.py b/marie/serve/runtimes/worker/request_handling.py
index 2896cc54..acdb094c 100644
--- a/marie/serve/runtimes/worker/request_handling.py
+++ b/marie/serve/runtimes/worker/request_handling.py
@@ -192,8 +192,8 @@ async def _shutdown():
return extend_rest_interface(app)
- def _http_fastapi_sagemaker_app(self, **kwargs):
- from marie.serve.runtimes.worker.http_sagemaker_app import get_fastapi_app
+ def _http_fastapi_csp_app(self, **kwargs):
+        from marie.serve.runtimes.worker.http_csp_app import get_fastapi_app
request_models_map = self._executor._get_endpoint_models_dict()
@@ -1167,7 +1167,7 @@ async def stream(
:param kwargs: keyword arguments
:yield: responses to the request
"""
- self.logger.debug("recv a stream request")
+ self.logger.debug("recv a stream request from client")
async for request in request_iterator:
yield await self.process_data([request], context)
diff --git a/marie/serve/stream/__init__.py b/marie/serve/stream/__init__.py
index 29b2e4a5..10efa533 100644
--- a/marie/serve/stream/__init__.py
+++ b/marie/serve/stream/__init__.py
@@ -16,7 +16,7 @@
from marie.serve.stream.helper import AsyncRequestsIterator, _RequestsCounter
from marie.types.request.data import DataRequest
-__all__ = ['RequestStreamer']
+__all__ = ["RequestStreamer"]
from marie._docarray import DocumentArray
from marie.types.request.data import Response
@@ -36,13 +36,13 @@ class _EndOfStreaming:
def __init__(
self,
request_handler: Callable[
- ['Request'], Tuple[Awaitable['Request'], Optional[Awaitable['Request']]]
+ ["Request"], Tuple[Awaitable["Request"], Optional[Awaitable["Request"]]]
],
- result_handler: Callable[['Request'], Optional['Request']],
+ result_handler: Callable[["Request"], Optional["Request"]],
prefetch: int = 0,
iterate_sync_in_thread: bool = True,
end_of_iter_handler: Optional[Callable[[], None]] = None,
- logger: Optional['MarieLogger'] = None,
+ logger: Optional["MarieLogger"] = None,
**logger_kwargs,
):
"""
@@ -80,11 +80,11 @@ async def _get_endpoints_input_output_models(
# Flow.
# create loop and get from topology_graph
_endpoints_models_map = {}
- self.logger.debug(f'Get all endpoints from TopologyGraph')
+ self.logger.debug(f"Get all endpoints from TopologyGraph")
endpoints = await topology_graph._get_all_endpoints(
connection_pool, retry_forever=True, is_cancel=is_cancel
)
- self.logger.debug(f'Got all endpoints from TopologyGraph {endpoints}')
+ self.logger.debug(f"Got all endpoints from TopologyGraph {endpoints}")
if endpoints is not None:
for endp in endpoints:
for origin_node in topology_graph.origin_nodes:
@@ -103,14 +103,14 @@ async def _get_endpoints_input_output_models(
_endpoints_models_map[endp] = leaf_input_output_model[0]
cached_models = {}
for k, v in _endpoints_models_map.items():
- if v['input'].__name__ not in cached_models:
- cached_models[v['input'].__name__] = v['input']
+ if v["input"].__name__ not in cached_models:
+ cached_models[v["input"].__name__] = v["input"]
else:
- v['input'] = cached_models[v['input'].__name__]
- if v['output'].__name__ not in cached_models:
- cached_models[v['output'].__name__] = v['output']
+ v["input"] = cached_models[v["input"].__name__]
+ if v["output"].__name__ not in cached_models:
+ cached_models[v["output"].__name__] = v["output"]
else:
- v['output'] = cached_models[v['output'].__name__]
+ v["output"] = cached_models[v["output"].__name__]
return _endpoints_models_map
async def stream_doc(
@@ -118,7 +118,7 @@ async def stream_doc(
request,
context=None,
*args,
- ) -> AsyncIterator['Request']:
+ ) -> AsyncIterator["Request"]:
"""
stream requests from client iterator and stream responses back.
@@ -141,7 +141,7 @@ async def stream_doc(
context.set_code(err.code())
context.set_trailing_metadata(err.trailing_metadata())
self.logger.error(
- f'Error while getting responses from deployments: {err.details()}'
+ f"Error while getting responses from deployments: {err.details()}"
)
r = Response()
if err.request_id:
@@ -149,13 +149,13 @@ async def stream_doc(
yield r
else: # HTTP and WS need different treatment further up the stack
self.logger.error(
- f'Error while getting responses from deployments: {err.details()}'
+ f"Error while getting responses from deployments: {err.details()}"
)
raise
except (
Exception
) as err: # HTTP and WS need different treatment further up the stack
- self.logger.error(f'Error while getting responses from deployments: {err}')
+ self.logger.error(f"Error while getting responses from deployments: {err}")
raise err
async def stream(
@@ -166,7 +166,7 @@ async def stream(
prefetch: Optional[int] = None,
return_type: Type[DocumentArray] = DocumentArray,
*args,
- ) -> AsyncIterator['Request']:
+ ) -> AsyncIterator["Request"]:
"""
stream requests from client iterator and stream responses back.
@@ -181,13 +181,13 @@ async def stream(
prefetch = prefetch or self._prefetch
if context is not None:
for metadatum in context.invocation_metadata():
- if metadatum.key == '__results_in_order__':
- results_in_order = metadatum.value == 'true'
- if metadatum.key == '__prefetch__':
+ if metadatum.key == "__results_in_order__":
+ results_in_order = metadatum.value == "true"
+ if metadatum.key == "__prefetch__":
try:
prefetch = int(metadatum.value)
except:
- self.logger.debug(f'Couldn\'t parse prefetch to int value!')
+ self.logger.debug(f"Couldn't parse prefetch to int value!")
try:
async_iter: AsyncIterator = self._stream_requests(
@@ -206,7 +206,7 @@ async def stream(
context.set_code(err.code())
context.set_trailing_metadata(err.trailing_metadata())
self.logger.error(
- f'Error while getting responses from deployments: {err.details()}'
+ f"Error while getting responses from deployments: {err.details()}"
)
r = Response()
if err.request_id:
@@ -214,13 +214,13 @@ async def stream(
yield r
else: # HTTP and WS need different treatment further up the stack
self.logger.error(
- f'Error while getting responses from deployments: {err.details()}'
+ f"Error while getting responses from deployments: {err.details()}"
)
raise
except (
Exception
) as err: # HTTP and WS need different treatment further up the stack
- self.logger.error(f'Error while getting responses from deployments: {err}')
+ self.logger.error(f"Error while getting responses from deployments: {err}")
raise err
async def _stream_requests(
@@ -257,7 +257,7 @@ async def end_future():
async def exception_raise(exception):
raise exception
- def callback(future: 'asyncio.Future'):
+ def callback(future: "asyncio.Future"):
"""callback to be run after future is completed.
1. Put the future in the result queue.
2. Remove the future from futures when future is completed.
@@ -269,7 +269,7 @@ def callback(future: 'asyncio.Future'):
"""
result_queue.put_nowait(future)
- def hanging_callback(future: 'asyncio.Future'):
+ def hanging_callback(future: "asyncio.Future"):
floating_results_queue.put_nowait(future)
async def iterate_requests() -> None:
diff --git a/marie/storage/database/postgres.py b/marie/storage/database/postgres.py
index d093ab59..77afc102 100644
--- a/marie/storage/database/postgres.py
+++ b/marie/storage/database/postgres.py
@@ -10,16 +10,22 @@
class PostgresqlMixin:
"""Bind PostgreSQL database provider."""
- provider = 'postgres'
-
- def _setup_storage(self, config: Dict[str, Any], create_table_callback: Optional[Callable] = None,
- reset_table_callback: Optional[Callable] = None) -> None:
+ provider = "postgres"
+
+ def _setup_storage(
+ self,
+ config: Dict[str, Any],
+ create_table_callback: Optional[Callable] = None,
+ reset_table_callback: Optional[Callable] = None,
+ connection_only=False,
+ ) -> None:
"""
Setup PostgreSQL connection pool.
@param config:
@param create_table_callback: Create table if it doesn't exist.
@param reset_table_callback: Reset table if it exists.
+        @param connection_only: Only establish the database connection pool; skip default_table setup and table initialization.
@return:
"""
try:
@@ -28,10 +34,6 @@ def _setup_storage(self, config: Dict[str, Any], create_table_callback: Optional
username = config["username"]
password = config["password"]
database = config["database"]
- self.table = config["default_table"]
-
- if self.table is None or self.table == "":
- raise ValueError("default_table cannot be empty")
max_connections = 10
self.postgreSQL_pool = psycopg2.pool.SimpleConnectionPool(
@@ -43,11 +45,20 @@ def _setup_storage(self, config: Dict[str, Any], create_table_callback: Optional
host=hostname,
port=port,
)
+
+ if connection_only:
+ self.logger.info(f"Connected to postgresql database: {config}")
+ return
+
+ self.table = config["default_table"]
+ if self.table is None or self.table == "":
+ raise ValueError("default_table cannot be empty")
+
self._init_table(create_table_callback, reset_table_callback)
except Exception as e:
raise BadConfigSource(
- f'Cannot connect to postgresql database: {config}, {e}'
+ f"Cannot connect to postgresql database: {config}, {e}"
)
def __enter__(self):
@@ -70,8 +81,11 @@ def _get_connection(self):
connection.autocommit = False
return connection
- def _init_table(self, create_table_callback: Optional[Callable] = None,
- reset_table_callback: Optional[Callable] = None) -> None:
+ def _init_table(
+ self,
+ create_table_callback: Optional[Callable] = None,
+ reset_table_callback: Optional[Callable] = None,
+ ) -> None:
"""
Use table if exists or create one if it doesn't.
"""
@@ -102,9 +116,14 @@ def _table_exists(self) -> bool:
(self.table,),
).fetchall()[0][0]
- def _execute_sql_gracefully(self, statement, data=tuple(), *,
- named_cursor_name: Optional[str] = None,
- itersize: Optional[int] = 10000) -> psycopg2.extras.DictCursor:
+ def _execute_sql_gracefully(
+ self,
+ statement,
+ data=tuple(),
+ *,
+ named_cursor_name: Optional[str] = None,
+ itersize: Optional[int] = 10000,
+ ) -> psycopg2.extras.DictCursor:
try:
if named_cursor_name:
cursor = self.connection.cursor(named_cursor_name)
diff --git a/marie/utils/pydantic.py b/marie/utils/pydantic.py
new file mode 100644
index 00000000..a46df370
--- /dev/null
+++ b/marie/utils/pydantic.py
@@ -0,0 +1,37 @@
+import pydantic
+
+major_version = int(pydantic.__version__.split('.')[0])
+
+print(major_version)
+
+
+def patch_pydantic_schema(cls):
+ print(major_version)
+ raise NotImplementedError
+
+
+if major_version >= 2:
+ from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
+ from pydantic_core import PydanticOmit, core_schema
+
+ class PydanticJsonSchema(GenerateJsonSchema):
+ def handle_invalid_for_json_schema(
+ self, schema: core_schema.CoreSchema, error_info: str
+ ) -> JsonSchemaValue:
+ if "core_schema.PlainValidatorFunctionSchema" in error_info:
+ raise PydanticOmit
+ return super().handle_invalid_for_json_schema(schema, error_info)
+
+ def patch_pydantic_schema(cls):
+ major_version = int(pydantic.__version__.split('.')[0])
+ # Check if the major version is 2 or higher
+ if major_version < 2:
+ schema = cls.model_json_schema(mode="validation")
+ else:
+ schema = cls.model_json_schema(
+ mode="validation", schema_generator=PydanticJsonSchema
+ )
+ return schema
+
+
+patch_pydantic_schema_2x = patch_pydantic_schema
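+
+# Usage sketch (mirrors the commented-out call in marie_server/__main__.py):
+#   LegacyDocument.schema = classmethod(patch_pydantic_schema_2x)
+#   schema = LegacyDocument.schema()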
diff --git a/marie/utils/timer.py b/marie/utils/timer.py
new file mode 100644
index 00000000..7fdff524
--- /dev/null
+++ b/marie/utils/timer.py
@@ -0,0 +1,34 @@
+import threading
+import time
+
+
+class RepeatedTimer(object):
+    """A timer that repeatedly executes a function at a fixed interval until stopped."""
+
+ def __init__(self, interval, function, *args, **kwargs):
+ self._timer = None
+ self.interval = interval
+ self.function = function
+ self.args = args
+ self.kwargs = kwargs
+ self.is_running = False
+ self.next_call = time.time()
+ self.start()
+
+ def _run(self):
+ self.is_running = False
+ self.start()
+ self.function(*self.args, **self.kwargs)
+
+ def start(self):
+ """Start the timer."""
+ if not self.is_running:
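+            # Schedule against the previous target time instead of time.time()
+            # so that repeated invocations do not drift.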
+ self.next_call += self.interval
+ self._timer = threading.Timer(self.next_call - time.time(), self._run)
+ self._timer.start()
+ self.is_running = True
+
+ def stop(self):
+ """Stop the timer."""
+ self._timer.cancel()
+ self.is_running = False
diff --git a/marie_server/__main__.py b/marie_server/__main__.py
index dd8dcad9..7347f652 100644
--- a/marie_server/__main__.py
+++ b/marie_server/__main__.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Optional
from docarray import BaseDoc, DocList
+from docarray.documents.legacy import LegacyDocument
from rich.traceback import install
import marie.helper
@@ -31,6 +32,8 @@
from marie.messaging.publisher import event_builder
from marie.storage import S3StorageHandler, StorageManager
from marie.utils.device import gpu_device_count
+from marie.utils.json import store_json_object
+from marie.utils.pydantic import patch_pydantic_schema_2x
from marie.utils.types import strtobool
from marie_server.rest_extension import extend_rest_interface
@@ -176,6 +179,19 @@ def main(
__main__(yml_config, env, env_file)
+def patch_libs():
+ """Patch the libraries"""
+ logger.warning("Patching libraries")
+ # patch pydantic schema
+ # LegacyDocument.schema = classmethod(patch_pydantic_schema_2x)
+ #
+ # schema = LegacyDocument.schema()
+ # logger.info(f"Schema : {schema}")
+
+ # store_json_object(schema, os.path.join("/home/greg/tmp/marie", "schema-2.x.json"))
+ # print(schema)
+
+
def __main__(
yml_config: str,
env: Optional[Dict[str, str]] = None,
@@ -265,21 +281,26 @@ def __main__(
for k, v in os.environ.items():
print(f"{k} = {v}")
+ patch_libs()
+
# Load the config file and set up the toast events
config = load_yaml(yml_config, substitute=True, context=context)
prefetch = config.get("prefetch", 1)
# flow or deployment
- f = Flow.load_config(
- config,
- extra_search_paths=[os.path.dirname(inspect.getfile(inspect.currentframe()))],
- substitute=True,
- context=context,
- include_gateway=True,
- noblock_on_start=False,
- prefetch=prefetch,
- external=True,
- ).config_gateway(prefetch=prefetch)
+ if True:
+ f = Flow.load_config(
+ config,
+ extra_search_paths=[
+ os.path.dirname(inspect.getfile(inspect.currentframe()))
+ ],
+ substitute=True,
+ context=context,
+ include_gateway=True,
+ noblock_on_start=False,
+ prefetch=prefetch,
+ external=True,
+ ).config_gateway(prefetch=prefetch)
if False:
f = Deployment.load_config(
@@ -289,9 +310,10 @@ def __main__(
],
substitute=True,
context=context,
- include_gateway=False,
+ include_gateway=True,
noblock_on_start=False,
prefetch=prefetch,
+ statefull=False,
external=True,
)
@@ -337,6 +359,7 @@ def setup_server(config: Dict[str, Any]) -> None:
setup_toast_events(config.get("toast", {}))
setup_storage(config.get("storage", {}))
setup_auth(config.get("auth", {}))
+
# setup_scheduler(config.get("scheduler", {}))
diff --git a/marie_server/job/common.py b/marie_server/job/common.py
index 8afb32d8..3da59c46 100644
--- a/marie_server/job/common.py
+++ b/marie_server/job/common.py
@@ -14,6 +14,8 @@
INTERNAL_NAMESPACE_PREFIX = "marie_internal"
JOB_STATUS_KEY = f"{INTERNAL_NAMESPACE_PREFIX}/job_status"
+ActorHandle = Any
+
class JobStatus(str, Enum):
"""An enumeration for describing the status of a job."""
diff --git a/marie_server/job/gateway_job_distributor.py b/marie_server/job/gateway_job_distributor.py
index bc00acd5..e80e3cd6 100644
--- a/marie_server/job/gateway_job_distributor.py
+++ b/marie_server/job/gateway_job_distributor.py
@@ -1,9 +1,6 @@
-from typing import Any, Optional
+from typing import Optional
-from m import Document
-
-from marie import DocumentArray
-from marie.clients.request import asyncio
+from marie import Document, DocumentArray
from marie.logging.logger import MarieLogger
from marie.serve.runtimes.gateway.streamer import GatewayStreamer
from marie.types.request.data import DataRequest
@@ -17,11 +14,11 @@ def __init__(
gateway_streamer: Optional[GatewayStreamer] = None,
logger: Optional[MarieLogger] = None,
):
- self.gateway_streamer = gateway_streamer
- self._logger = logger or MarieLogger(self.__class__.__name__)
+ self.streamer = gateway_streamer
+ self.logger = logger or MarieLogger(self.__class__.__name__)
- async def submit_job(self, job_info: JobInfo) -> DataRequest:
- self._logger.info(f"Publishing job {job_info} to gateway")
+ async def submit_job(self, job_info: JobInfo, doc: Document) -> DataRequest:
+ self.logger.info(f"Publishing job {job_info} to gateway")
curr_status = job_info.status
curr_message = job_info.message
@@ -31,21 +28,37 @@ async def submit_job(self, job_info: JobInfo) -> DataRequest:
f"Current status is {curr_status} with message {curr_message}."
)
- # attempt to get gateway streamer if not initialized
- if self.gateway_streamer is None:
- self._logger.warning(f"Gateway streamer is not initialized")
- self.gateway_streamer = GatewayStreamer.get_streamer()
+        # attempt to get gateway streamer if not initialized
+ if self.streamer is None:
+ self.logger.warning(f"Gateway streamer is not initialized")
+ raise RuntimeError("Gateway streamer is not initialized")
+
+ async for docs in self.streamer.stream_docs(
+ doc=doc,
+ # exec_endpoint="/extract", # _jina_dry_run_
+ exec_endpoint="_jina_dry_run_", # _jina_dry_run_
+ # target_executor="executor0",
+ return_results=False,
+ ):
+ self.logger.info(f"Received {len(docs)} docs from gateway")
+ print(docs)
+ result = docs[0].text
+
+ return result
- if self.gateway_streamer is None:
- raise Exception("Gateway streamer is not initialized")
+ if False:
+ # convert job_info to DataRequest
+ request = DataRequest()
+ # request.header.exec_endpoint = on
+ request.header.target_executor = job_info.entrypoint
+ request.parameters = job_info.metadata
- # convert job_info to DataRequest
- request = DataRequest()
- # request.header.exec_endpoint = on
- request.header.target_executor = job_info.entrypoint
- request.parameters = job_info.metadata
+ request.data.docs = DocumentArray([Document(text="sample text")])
+ response = await self.streamer.process_single_data(request=request)
- request.data.docs = DocumentArray([Document(text="sample text")])
- response = await self.gateway_streamer.process_single_data(request=request)
+ return response
- return response
+ async def close(self):
+ self.logger.debug(f"Closing GatewayJobDistributor")
+ await self.streamer.close()
+ self.logger.debug(f"GatewayJobDistributor closed")
diff --git a/marie_server/job/job_distributor.py b/marie_server/job/job_distributor.py
index d8e29ec0..3f2fae5f 100644
--- a/marie_server/job/job_distributor.py
+++ b/marie_server/job/job_distributor.py
@@ -1,20 +1,22 @@
import abc
+from marie import Document
from marie.types.request.data import DataRequest
from marie_server.job.common import JobInfo
class JobDistributor(abc.ABC):
"""
- Job Distributor is responsible for publishing jobs to the underlying executor which can be a gateway/flow.
+ Job Distributor is responsible for publishing jobs to the underlying executor which can be a gateway/flow/deployment.
"""
@abc.abstractmethod
- async def submit_job(self, job_info: JobInfo) -> DataRequest:
+ async def submit_job(self, job_info: JobInfo, doc: Document) -> DataRequest:
"""
Publish a job.
:param job_info: The job info to publish.
+ :param doc: The document to process.
:return:
"""
...
diff --git a/marie_server/job/job_manager.py b/marie_server/job/job_manager.py
index 4dca646f..e41e0259 100644
--- a/marie_server/job/job_manager.py
+++ b/marie_server/job/job_manager.py
@@ -7,16 +7,20 @@
from marie._core.utils import run_background_task
from marie.logging.logger import MarieLogger
-from marie_server.job.common import JobInfo, JobInfoStorageClient, JobStatus
+from marie_server.job.common import (
+ ActorHandle,
+ JobInfo,
+ JobInfoStorageClient,
+ JobStatus,
+)
from marie_server.job.job_distributor import JobDistributor
+from marie_server.job.job_supervisor import JobSupervisor
from marie_server.job.scheduling_strategies import (
NodeAffinitySchedulingStrategy,
SchedulingStrategyT,
)
from marie_server.storage.storage_client import StorageArea
-ActorHandle = Any
-
# The max time to wait for the JobSupervisor to start before failing the job.
DEFAULT_JOB_START_TIMEOUT_SECONDS = 60 * 15
JOB_START_TIMEOUT_SECONDS_ENV_VAR = "JOB_START_TIMEOUT_SECONDS"
@@ -33,80 +37,6 @@ def get_event_logger():
return None
-class JobSupervisor:
- """
- Supervise jobs and keep track of their status.
- """
-
- DEFAULT_JOB_STOP_WAIT_TIME_S = 3
-
- def __init__(
- self,
- job_id: str,
- job_info_client: JobInfoStorageClient,
- job_distributor: JobDistributor,
- ):
- self._job_id = job_id
- self._job_info_client = job_info_client
- self._job_distributor = job_distributor
-
- def ping(self):
- """Used to check the health of the actor/executor/deployment."""
- pass
-
- async def run(
- self,
- # Signal actor used in testing to capture PENDING -> RUNNING cases
- _start_signal_actor: Optional[ActorHandle] = None,
- ):
- """
- Stop and start both happen asynchronously, coordinated by asyncio event
- and coroutine, respectively.
-
- 1) Sets job status as running
- 2) Pass runtime env and metadata to subprocess as serialized env
- variables.
- 3) Handle concurrent events of driver execution and
- """
- curr_info = await self._job_info_client.get_info(self._job_id)
- if curr_info is None:
- raise RuntimeError(f"Status could not be retrieved for job {self._job_id}.")
- curr_status = curr_info.status
- curr_message = curr_info.message
- if curr_status == JobStatus.RUNNING:
- raise RuntimeError(
- f"Job {self._job_id} is already in RUNNING state. "
- f"JobSupervisor.run() should only be called once. "
- )
- if curr_status != JobStatus.PENDING:
- raise RuntimeError(
- f"Job {self._job_id} is not in PENDING state. "
- f"Current status is {curr_status} with message {curr_message}."
- )
- if _start_signal_actor:
- # Block in PENDING state until start signal received.
- await _start_signal_actor.wait.remote()
-
- driver_agent_http_address = "grpc://127.0.0.1"
- driver_node_id = "GET_NODE_ID_FROM_CLUSTER"
-
- await self._job_info_client.put_status(
- self._job_id,
- JobStatus.RUNNING,
- jobinfo_replace_kwargs={
- "driver_agent_http_address": driver_agent_http_address,
- "driver_node_id": driver_node_id,
- },
- )
-
- response = await self._job_distributor.submit_job(curr_info)
- # format the response
- print("Response: ", response)
- print("Response type: ", type(response))
- print("Response data: ", response.data)
- print("Response status: ", response.status)
-
-
class JobManager:
"""Provide python APIs for job submission and management.
diff --git a/marie_server/job/job_supervisor.py b/marie_server/job/job_supervisor.py
new file mode 100644
index 00000000..174766a4
--- /dev/null
+++ b/marie_server/job/job_supervisor.py
@@ -0,0 +1,83 @@
+from typing import Any, Optional
+
+from marie_server.job.common import (
+ ActorHandle,
+ JobInfo,
+ JobInfoStorageClient,
+ JobStatus,
+)
+from marie_server.job.job_distributor import JobDistributor
+
+
+class JobSupervisor:
+ """
+ Supervise jobs and keep track of their status.
+ """
+
+ DEFAULT_JOB_STOP_WAIT_TIME_S = 3
+
+ def __init__(
+ self,
+ job_id: str,
+ job_info_client: JobInfoStorageClient,
+ job_distributor: JobDistributor,
+ ):
+ self._job_id = job_id
+ self._job_info_client = job_info_client
+ self._job_distributor = job_distributor
+
+ def ping(self):
+ """Used to check the health of the actor/executor/deployment."""
+ pass
+
+ async def run(
+ self,
+ # Signal actor used in testing to capture PENDING -> RUNNING cases
+ _start_signal_actor: Optional[ActorHandle] = None,
+ ):
+ """
+ Stop and start both happen asynchronously, coordinated by asyncio event
+ and coroutine, respectively.
+
+ 1) Sets job status as running
+ 2) Pass runtime env and metadata to subprocess as serialized env
+ variables.
+ 3) Handle concurrent events of driver execution and
+ """
+ curr_info = await self._job_info_client.get_info(self._job_id)
+ if curr_info is None:
+ raise RuntimeError(f"Status could not be retrieved for job {self._job_id}.")
+ curr_status = curr_info.status
+ curr_message = curr_info.message
+ if curr_status == JobStatus.RUNNING:
+ raise RuntimeError(
+ f"Job {self._job_id} is already in RUNNING state. "
+ f"JobSupervisor.run() should only be called once. "
+ )
+ if curr_status != JobStatus.PENDING:
+ raise RuntimeError(
+ f"Job {self._job_id} is not in PENDING state. "
+ f"Current status is {curr_status} with message {curr_message}."
+ )
+ if _start_signal_actor:
+ # Block in PENDING state until start signal received.
+ await _start_signal_actor.wait.remote()
+
+ driver_agent_http_address = "grpc://127.0.0.1"
+ driver_node_id = "GET_NODE_ID_FROM_CLUSTER"
+
+ await self._job_info_client.put_status(
+ self._job_id,
+ JobStatus.RUNNING,
+ jobinfo_replace_kwargs={
+ "driver_agent_http_address": driver_agent_http_address,
+ "driver_node_id": driver_node_id,
+ },
+ )
+
+ response = await self._job_distributor.submit_job(curr_info)
+ # format the response
+ print("Response: ", response)
+ print("Response type: ", type(response))
+ print("Response data: ", response.data)
+ print("Response status: ", response.status)
diff --git a/marie_server/scheduler/fixtures.py b/marie_server/scheduler/fixtures.py
new file mode 100644
index 00000000..94c4346e
--- /dev/null
+++ b/marie_server/scheduler/fixtures.py
@@ -0,0 +1,121 @@
+from marie_server.scheduler.state import States
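+
+# Schema helpers for the PostgreSQL-backed job scheduler (the layout loosely
+# follows the pg-boss queue schema). Each function returns one DDL statement;
+# PostgreSQLJobScheduler.create_tables joins them and runs the batch as a whole.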
+
+
+def create_schema(schema: str):
+ return f"CREATE SCHEMA IF NOT EXISTS {schema}"
+
+
+def create_version_table(schema: str):
+ return f"""
+ CREATE TABLE {schema}.version (
+ version int primary key,
+ maintained_on timestamp with time zone,
+ cron_on timestamp with time zone
+ )
+ """
+
+
+def create_job_state_enum(schema: str):
+ return f"""
+ CREATE TYPE {schema}.job_state AS ENUM (
+ '{States.CREATED.value}',
+ '{States.RETRY.value}',
+ '{States.ACTIVE.value}',
+ '{States.COMPLETED.value}',
+ '{States.EXPIRED.value}',
+ '{States.CANCELLED.value}',
+ '{States.FAILED.value}'
+ )
+ """
+
+
+def create_job_table(schema: str):
+ return f"""
+ CREATE TABLE {schema}.job (
+ id uuid primary key not null default gen_random_uuid(),
+ name text not null,
+ priority integer not null default(0),
+ data jsonb,
+ state {schema}.job_state not null default('{States.CREATED.value}'),
+ retry_limit integer not null default(0),
+ retry_count integer not null default(0),
+ retry_delay integer not null default(0),
+ retry_backoff boolean not null default false,
+ start_after timestamp with time zone not null default now(),
+ started_on timestamp with time zone,
+ singleton_key text,
+ singleton_on timestamp without time zone,
+ expire_in interval not null default interval '15 minutes',
+ created_on timestamp with time zone not null default now(),
+ completed_on timestamp with time zone,
+ keep_until timestamp with time zone NOT NULL default now() + interval '14 days',
+ on_complete boolean not null default false,
+ output jsonb
+ )
+ """
+
+
+def clone_job_table_for_archive(schema):
+ return f"CREATE TABLE {schema}.archive (LIKE {schema}.job)"
+
+
+def create_schedule_table(schema):
+ return f"""
+ CREATE TABLE {schema}.schedule (
+ name text primary key,
+ cron text not null,
+ timezone text,
+ data jsonb,
+ options jsonb,
+ created_on timestamp with time zone not null default now(),
+ updated_on timestamp with time zone not null default now()
+ )
+ """
+
+
+def create_subscription_table(schema):
+ return f"""
+ CREATE TABLE {schema}.subscription (
+ event text not null,
+ name text not null,
+ created_on timestamp with time zone not null default now(),
+ updated_on timestamp with time zone not null default now(),
+ PRIMARY KEY(event, name)
+ )
+ """
+
+
+def add_archived_on_to_archive(schema):
+ return f"ALTER TABLE {schema}.archive ADD archived_on timestamptz NOT NULL DEFAULT now()"
+
+
+def add_archived_on_index_to_archive(schema):
+ return f"CREATE INDEX archive_archivedon_idx ON {schema}.archive(archived_on)"
+
+
+def add_id_index_to_archive(schema):
+ return f"CREATE INDEX archive_id_idx ON {schema}.archive(id)"
+
+
+def create_index_singleton_on(schema):
+ return f"""
+ CREATE UNIQUE INDEX job_singleton_on ON {schema}.job (name, singleton_on) WHERE state < '{States.EXPIRED.value}' AND singleton_key IS NULL
+ """
+
+
+def create_index_singleton_key_on(schema):
+ return f"""
+ CREATE UNIQUE INDEX job_singleton_key_on ON {schema}.job (name, singleton_on, singleton_key) WHERE state < '{States.EXPIRED.value}'
+ """
+
+
+def create_index_job_name(schema):
+ return f"""
+ CREATE INDEX job_name ON {schema}.job (name text_pattern_ops)
+ """
+
+
+def create_index_job_fetch(schema):
+ return f"""
+ CREATE INDEX job_fetch ON {schema}.job (name text_pattern_ops, start_after) WHERE state < '{States.ACTIVE.value}'
+ """
diff --git a/marie_server/scheduler/psql.py b/marie_server/scheduler/psql.py
index 652cec73..8872c4ec 100644
--- a/marie_server/scheduler/psql.py
+++ b/marie_server/scheduler/psql.py
@@ -1,6 +1,7 @@
import asyncio
import threading
import traceback
+from enum import Enum
from typing import Any, Dict, Generator, List
import psycopg2
@@ -9,62 +10,91 @@
from marie.logging.logger import MarieLogger
from marie.logging.predefined import default_logger as logger
from marie.storage.database.postgres import PostgresqlMixin
+from marie_server.scheduler.fixtures import *
from marie_server.scheduler.scheduler import Scheduler
+from marie_server.scheduler.state import States
INIT_POLL_PERIOD = 1.250  # 1.25s
MAX_POLL_PERIOD = 16.0 # 16s
+DEFAULT_SCHEMA = "marie_scheduler"
+COMPLETION_JOB_PREFIX = f"__state__{States.COMPLETED.value}__"
+
+
class PostgreSQLJobScheduler(PostgresqlMixin, Scheduler):
def __init__(self, config: Dict[str, Any]):
super().__init__()
self.logger = MarieLogger("PostgreSQLJobScheduler")
print("config", config)
self.running = False
- self._setup_storage(config)
+ self._setup_storage(config, connection_only=True)
+
+ def create_tables(self, schema: str):
+ """
+ :param schema: The name of the schema where the tables will be created.
+ :return: None
+ """
+ commands = [
+ create_schema(schema),
+ create_version_table(schema),
+ create_job_state_enum(schema),
+ create_job_table(schema),
+ clone_job_table_for_archive(schema),
+ create_schedule_table(schema),
+ create_subscription_table(schema),
+ add_archived_on_to_archive(schema),
+ add_archived_on_index_to_archive(schema),
+ add_id_index_to_archive(schema),
+ create_index_singleton_on(schema),
+ create_index_singleton_key_on(schema),
+ create_index_job_name(schema),
+ create_index_job_fetch(schema),
+ ]
+
+ query = ";\n".join(commands)
+
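+        # Execute the whole DDL batch in a single transaction guarded by an
+        # advisory lock so concurrent scheduler instances do not race on setup.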
+ locked_query = f"""
+ BEGIN;
+ SET LOCAL statement_timeout = '30s';
+ SELECT pg_try_advisory_lock(1);
+ {query};
+ SELECT pg_advisory_unlock(1);
+ COMMIT;
+ """
+
+ with self:
+ self._execute_sql_gracefully(locked_query)
def start_schedule(self) -> None:
+ """
+ Starts the job scheduling agent.
+
+ :return: None
+ """
logger.info("Starting job scheduling agent")
+ self.create_tables(DEFAULT_SCHEMA)
- def _run():
- try:
+ if False:
+
+ def _run():
try:
- loop = asyncio.get_running_loop()
- except RuntimeError:
- loop = None
-
- if loop is None:
- asyncio.run(self.__poll())
- else:
- loop.run_until_complete(self.__poll())
- except Exception as e:
- logger.error(f"Unable to setup job scheduler: {e}")
- logger.error(traceback.format_exc())
-
- t = threading.Thread(target=_run, daemon=True)
- t.start()
-
- def _create_table(self, table_name: str) -> None:
- """Create the table if it doesn't exist."""
- print("creating table : ", table_name)
-
- self._execute_sql_gracefully(
- f"""
- CREATE TABLE IF NOT EXISTS queue (
- id UUID PRIMARY KEY,
- created_at TIMESTAMP WITH TIME ZONE NOT NULL,
- updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
-
- scheduled_for TIMESTAMP WITH TIME ZONE NOT NULL,
- failed_attempts INT NOT NULL,
- status INT NOT NULL,
- message JSONB NOT NULL
- );
-
- CREATE INDEX index_queue_on_scheduled_for ON queue (scheduled_for);
- CREATE INDEX index_queue_on_status ON queue (status);
- """,
- )
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+
+ if loop is None:
+ asyncio.run(self.__poll())
+ else:
+ loop.run_until_complete(self.__poll())
+ except Exception as e:
+ logger.error(f"Unable to setup job scheduler: {e}")
+ logger.error(traceback.format_exc())
+
+ t = threading.Thread(target=_run, daemon=True)
+ t.start()
+ t.join() # FOR TESTING PURPOSES ONLY
async def __poll(self):
print("Starting poller with psql")
@@ -109,22 +139,24 @@ def get_document_iterator(
:param limit: the maximal number records to get
:return:
"""
- try:
- cursor = self.connection.cursor("doc_iterator")
- cursor.itersize = 10000
- cursor.execute(
- f"""
- SELECT * FROM job_queue
- """
- # + (f" limit = {limit}" if limit > 0 else "")
- )
- for record in cursor:
- print(record)
- doc_id = record[0]
-
- yield doc_id
- except (Exception, psycopg2.Error) as error:
- self.logger.error(f"Error importing snapshot: {error}")
- self.connection.rollback()
- self.connection.commit()
+ with self:
+ try:
+ cursor = self.connection.cursor("doc_iterator")
+ cursor.itersize = 10000
+ cursor.execute(
+ f"""
+ SELECT * FROM job_queue
+ """
+ # + (f" limit = {limit}" if limit > 0 else "")
+ )
+ for record in cursor:
+ print(record)
+ doc_id = record[0]
+
+ yield doc_id
+
+ except (Exception, psycopg2.Error) as error:
+ self.logger.error(f"Error importing snapshot: {error}")
+ self.connection.rollback()
+ self.connection.commit()
diff --git a/marie_server/scheduler/state.py b/marie_server/scheduler/state.py
new file mode 100644
index 00000000..7f69ab88
--- /dev/null
+++ b/marie_server/scheduler/state.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class States(Enum):
+ CREATED = "created"
+ RETRY = "retry"
+ ACTIVE = "active"
+ COMPLETED = "completed"
+ EXPIRED = "expired"
+ CANCELLED = "cancelled"
+ FAILED = "failed"
diff --git a/poc/custom_gateway/deployment_gateway.py b/poc/custom_gateway/deployment_gateway.py
new file mode 100644
index 00000000..99f4890d
--- /dev/null
+++ b/poc/custom_gateway/deployment_gateway.py
@@ -0,0 +1,50 @@
+import asyncio
+
+from marie import Gateway as BaseGateway
+from marie.serve.runtimes.servers.composite import CompositeServer
+from marie.serve.runtimes.servers.grpc import GRPCServer
+
+
+class MariePodGateway(BaseGateway, CompositeServer):
+    """A custom Gateway for Marie deployment pods (worker nodes)."""
+
+ def __init__(self, **kwargs):
+ """Initialize a new Gateway."""
+ print("MariePodGateway init called")
+ super().__init__(**kwargs)
+
+ async def setup_server(self):
+ """
+ setup servers inside MariePodGateway
+ """
+ self.logger.debug(f"Setting up MariePodGateway server")
+ await super().setup_server()
+ self.logger.debug(f"MariePodGateway server setup successful")
+
+ for server in self.servers:
+ if isinstance(server, GRPCServer):
+ print(f"Registering GRPC server {server}")
+ host = server.host
+ port = server.port
+ ctrl_address = f"{host}:{port}"
+ self._register_gateway(ctrl_address)
+
+ def _register_gateway(self, ctrl_address: str):
+ """Check if the gateway is ready."""
+ print("Registering gateway with controller at : ", ctrl_address)
+
+ async def _async_wait_all(ctrl_address: str):
+ """Wait for all servers to be ready."""
+
+ print("waiting for all servers to be ready at : ", ctrl_address)
+ while True:
+ print(f"checking is ready at {ctrl_address}")
+ res = GRPCServer.is_ready(ctrl_address)
+ print(f"res: {res}")
+ if res:
+ print(f"Gateway is ready at {ctrl_address}")
+ break
+ await asyncio.sleep(1)
+
+ asyncio.create_task(_async_wait_all(ctrl_address))
+        print("Scheduled background task to wait for all servers to be ready")
diff --git a/poc/custom_gateway/direct-flow.py b/poc/custom_gateway/direct-flow.py
index c1f05c3e..50452379 100644
--- a/poc/custom_gateway/direct-flow.py
+++ b/poc/custom_gateway/direct-flow.py
@@ -1,48 +1,120 @@
+import inspect
+import os
+import time
+
from docarray import DocList
from docarray.documents import TextDoc
-from marie import Executor, Flow, requests
-from marie.serve.runtimes.gateway.http.fastapi import FastAPIBaseGateway
+from marie import Deployment, Executor, Flow, requests
+from marie.conf.helper import load_yaml
+from poc.custom_gateway.deployment_gateway import MariePodGateway
+
+class TestExecutorXYZ(Executor):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ print("TestExecutorXYZ init called")
-class MyGateway(FastAPIBaseGateway):
- @property
- def app(self):
- from fastapi import FastAPI
+ # emulate the long loading time
+ # time.sleep(5)
- app = FastAPI(title="Custom FastAPI Gateway")
+ @requests(on="/classify")
+ def func(
+ self,
+ docs: DocList[TextDoc],
+ parameters: dict = {},
+ *args,
+ **kwargs,
+ ):
+        print(f"TestExecutorXYZ.func called : {len(docs)}, {parameters}")
+ for doc in docs:
+ doc.text += " First"
- @app.get("/endpoint")
- async def get(text: str):
- result = None
- async for docs in self.streamer.stream_docs(
- docs=DocList[TextDoc]([TextDoc(text=text)]),
- exec_endpoint="/",
- target_executor="executor0",
- ):
- result = docs[0].text
- return {"result": result}
+ return {
+ "parameters": parameters,
+ "data": "Data reply",
+ }
- return app
+class TestExecutor(Executor):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ print("TestExecutor init called")
+ # emulate the long loading time
+ time.sleep(1)
-class FirstExec(Executor):
- @requests
- def func(self, docs, **kwargs):
+ @requests(on="/extract")
+ def func(
+ self,
+ docs: DocList[TextDoc],
+ parameters: dict = {},
+ *args,
+ **kwargs,
+ ):
+        print(f"TestExecutor.func called : {len(docs)}, {parameters}")
for doc in docs:
doc.text += " First"
+ return {
+ "parameters": parameters,
+ "data": "Data reply",
+ }
-class SecondExec(Executor):
- @requests
- def func(self, docs, **kwargs):
- for doc in docs:
- doc.text += " Second"
+
+def main_deployment():
+ context = {"name": "test"}
+ yml_config = "/mnt/data/marie-ai/config/service/deployment.yml"
+
+ # Load the config file and set up the toast events
+ config = load_yaml(yml_config, substitute=True, context=context)
+ f = Deployment.load_config(
+ config,
+ extra_search_paths=[os.path.dirname(inspect.getfile(inspect.currentframe()))],
+ substitute=True,
+ context=context,
+ include_gateway=False,
+ noblock_on_start=False,
+ prefetch=1,
+ statefull=False,
+ )
+
+ with Deployment(
+ uses=TestExecutor,
+ timeout_ready=-1,
+ protocol="grpc",
+ port=61000,
+ include_gateway=True,
+ replicas=3,
+ ):
+ f.block()
+
+ # with (Flow().add(uses=FirstExec) as f):
+ # f.block()
+
+
+def main():
+ context = {"name": "test"}
+ # yml_config = "/mnt/data/marie-ai/config/service/deployment.yml"
+ # # Load the config file and set up the toast events
+ # config = load_yaml(yml_config, substitute=True, context=context)
+ print("Bootstrapping server gateway")
+ with (
+ Flow(
+            discovery=True,  # register this gateway with the etcd-based discovery service
+ discovery_host="127.0.0.1",
+ discovery_port=2379,
+ discovery_watchdog_interval=5,
+ )
+ .add(uses=TestExecutor, name="executor0", replicas=3)
+ .config_gateway(
+ # uses=MariePodGateway, protocols=["GRPC", "HTTP"], ports=[61000, 61001]
+ ) as f
+ ):
+ f.block()
-with Flow(port=12345).config_gateway(uses=MyGateway, protocol="http", port=50975).add(
- uses=FirstExec, name="executor0"
-).add(uses=SecondExec, name="executor1") as flow:
- flow.block()
+if __name__ == "__main__":
+ main()
-# curl -X GET "http://localhost:50975/endpoint?text=abc"
+# curl -X GET "http://localhost:51000/endpoint?text=abc"
+# https://docs.jina.ai/concepts/serving/gateway/customization/
diff --git a/poc/custom_gateway/etcd3-leader-election.py b/poc/custom_gateway/etcd3-leader-election.py
new file mode 100644
index 00000000..a69f8c19
--- /dev/null
+++ b/poc/custom_gateway/etcd3-leader-election.py
@@ -0,0 +1,152 @@
+import sys
+import time
+from threading import Event
+
+import etcd3
+
+__all__ = ["LeaderElection"]
+
+
+class LeaderElection:
+ def __init__(self, etcd_client, leader_key, my_id, lease_ttl):
+ self.client = etcd_client
+ self.leader_key = leader_key
+ self.my_id = my_id
+ self.lease_ttl = lease_ttl
+ self.leader = None
+
+ def elect_leader(self, leaderCb):
+ """
+ elect a leader.
+
+ Args:
+ leaderCb - leader callback function. If leader is changed, the
+ leaderCb will be called
+ """
+ while True:
+ try:
+ status, lease, self.leader = self._elect_leader()
+ if self.leader is not None:
+ leaderCb(self.leader)
+ if status:
+ self._refresh_lease(lease)
+ else:
+ self._wait_leader_release()
+ time.sleep(5)
+ except Exception as ex:
+ print(ex)
+
+ def _elect_leader(self):
+ try:
+ lease = self.client.lease(self.lease_ttl)
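+            # Compare-and-set: the key version is 0 only when it does not exist,
+            # so the first client to write it under a lease becomes the leader.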
+ status, responses = self.client.transaction(
+ compare=[self.client.transactions.version(self.leader_key) == 0],
+ success=[
+ self.client.transactions.put(self.leader_key, self.my_id, lease)
+ ],
+ failure=[self.client.transactions.get(self.leader_key)],
+ )
+ if status:
+ return status, lease, self.my_id
+ elif len(responses) == 1 and len(responses[0]) == 1:
+ return status, lease, responses[0][0][0]
+ except Exception as ex:
+ print(ex)
+ return None, None, None
+
+ def _refresh_lease(self, lease):
+ """
+ refresh the lease period
+ """
+ try:
+ while True:
+ lease.refresh()
+ time.sleep(self.lease_ttl / 3.0 - 0.01)
+ except (Exception, KeyboardInterrupt):
+ pass
+ finally:
+ lease.revoke()
+
+ def _wait_leader_release(self):
+ """
+ wait for the leader key deleted
+ """
+ leader_release_event = Event()
+
+ def leader_delete_watch_cb(event):
+ if isinstance(event, etcd3.events.DeleteEvent):
+ leader_release_event.set()
+
+ watch_id = None
+ try:
+ watch_id = self.client.add_watch_callback(
+ self.leader_key, leader_delete_watch_cb
+ )
+ leader_release_event.wait()
+ except:
+ pass
+ finally:
+ if watch_id is not None:
+ self.client.cancel_watch(watch_id)
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser(description="elect leader from etcd cluster")
+ parser.add_argument(
+ "--host",
+ help="the etcd host, default = 127.0.0.1",
+ required=False,
+ default="127.0.0.1",
+ )
+ parser.add_argument(
+ "--port",
+ help="the etcd port, default = 2379",
+ required=False,
+ default=2379,
+ type=int,
+ )
+ parser.add_argument("--ca-cert", help="the etcd ca-cert", required=False)
+ parser.add_argument("--cert-key", help="the etcd cert key", required=False)
+ parser.add_argument("--cert-cert", help="the etcd cert", required=False)
+ parser.add_argument("--leader-key", help="the election leader key", required=True)
+ parser.add_argument(
+ "--lease-ttl",
+ help="the lease ttl in seconds, default is 10",
+ required=False,
+ default=10,
+ type=int,
+ )
+ parser.add_argument("--my-id", help="my identifier", required=True)
+ parser.add_argument(
+ "--timeout",
+ help="the etcd operation timeout in seconds, default is 2",
+ required=False,
+ type=int,
+ default=2,
+ )
+ args = parser.parse_args()
+
+ params = {"host": args.host, "port": args.port, "timeout": args.timeout}
+ if args.ca_cert:
+ params["ca_cert"] = args.ca_cert
+ if args.cert_key:
+ params["cert_key"] = args.cert_key
+ if args.cert_cert:
+ params["cert_cert"] = args.cert_cert
+
+ client = etcd3.client(**params)
+
+ leader_election = LeaderElection(
+ client, args.leader_key, args.my_id, args.lease_ttl
+ )
+
+ def print_leader(leader):
+ print("leader is %s" % leader)
+
+ leader_election.elect_leader(print_leader)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/poc/custom_gateway/leader_election.py b/poc/custom_gateway/leader_election.py
new file mode 100644
index 00000000..17612275
--- /dev/null
+++ b/poc/custom_gateway/leader_election.py
@@ -0,0 +1,73 @@
+"""etcd3 Leader election."""
+
+import sys
+import time
+from threading import Event
+
+import etcd3
+
+LEADER_KEY = '/leader'
+LEASE_TTL = 5
+SLEEP = 1
+
+
+def put_not_exist(client, key, value, lease=None):
+ status, _ = client.transaction(
+ compare=[client.transactions.version(key) == 0],
+ success=[client.transactions.put(key, value, lease)],
+ failure=[],
+ )
+ return status
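+
+# A note on the transaction above, assuming standard etcd3 semantics: comparing
+# version(key) == 0 succeeds only when the key does not currently exist, so the put
+# happens only if no other process holds `key`; binding the lease makes the key
+# vanish automatically once its holder stops refreshing.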
+
+
+def leader_election(client, me):
+    lease = None
+    try:
+        lease = client.lease(LEASE_TTL)
+        status = put_not_exist(client, LEADER_KEY, me, lease)
+    except Exception:
+        # without the default above, a failed lease() call would leave `lease` unbound
+        status = False
+    return status, lease
+
+
+def main(me):
+ client = etcd3.client(timeout=2)
+
+ while True:
+ print('leader election')
+ leader, lease = leader_election(client, me)
+
+ if leader:
+ print('leader')
+ try:
+ while True:
+ # do work
+ lease.refresh()
+ time.sleep(SLEEP)
+ except (Exception, KeyboardInterrupt):
+ return
+ finally:
+ lease.revoke()
+ else:
+ print('follower; standby')
+
+ election_event = Event()
+
+ def watch_cb(event):
+ if isinstance(event, etcd3.events.DeleteEvent):
+ election_event.set()
+
+ watch_id = client.add_watch_callback(LEADER_KEY, watch_cb)
+
+ try:
+ while not election_event.is_set():
+ time.sleep(SLEEP)
+ print('new election')
+ except (Exception, KeyboardInterrupt):
+ return
+ finally:
+ client.cancel_watch(watch_id)
+
+
+if __name__ == '__main__':
+ me = sys.argv[1]
+ main(me)
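+
+# Usage sketch (assumes etcd running locally on its default port):
+#
+#   terminal 1: python poc/custom_gateway/leader_election.py node-1   # prints "leader"
+#   terminal 2: python poc/custom_gateway/leader_election.py node-2   # prints "follower; standby"
+#
+# Stopping terminal 1 revokes the lease (or lets it expire), the watch callback in the
+# follower fires on the DeleteEvent, and a new election round starts.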
diff --git a/poc/custom_gateway/send_request_to_gateway.py b/poc/custom_gateway/send_request_to_gateway.py
new file mode 100644
index 00000000..85d8db1c
--- /dev/null
+++ b/poc/custom_gateway/send_request_to_gateway.py
@@ -0,0 +1,59 @@
+import asyncio
+
+from docarray import DocList
+from docarray.documents import TextDoc
+
+from marie import Client
+from marie.serve.runtimes.servers.grpc import GRPCServer
+
+
+async def main():
+ """
+ This function sends a request to a Marie server gateway.
+ """
+
+ ctrl_address = "0.0.0.0:61000"
+ res = GRPCServer.is_ready(ctrl_address)
+ print(f"res: {res}")
+
+    # NOTE: early return while testing the readiness check; the client request below is skipped
+    return
+ parameters = {}
+ parameters["payload"] = {"payload": "sample payload"}
+ docs = DocList[TextDoc]([TextDoc(text="Sample Text")])
+
+ client = Client(
+ host="127.0.0.1", port=52000, protocol="grpc", request_size=-1, asyncio=True
+ )
+
+ ready = await client.is_flow_ready()
+ print(f"Flow is ready: {ready}")
+
+ async for resp in client.post(
+ on="/",
+ inputs=docs,
+ parameters=parameters,
+ request_size=-1,
+ return_responses=True, # return DocList instead of Response
+ return_exceptions=True,
+ ):
+ print(resp)
+ # for doc in resp:
+ # print(doc.text)
+
+ print(resp.parameters)
+ print(resp.data)
+ # await asyncio.sleep(1)
+
+ print("DONE")
+
+
+if __name__ == "__main__":
+ loop = asyncio.get_event_loop()
+ try:
+ asyncio.ensure_future(main())
+ loop.run_forever()
+ except KeyboardInterrupt:
+ pass
+ finally:
+ print("Closing Loop")
+ loop.close()
diff --git a/poc/custom_gateway/server_gateway.py b/poc/custom_gateway/server_gateway.py
new file mode 100644
index 00000000..00fc3122
--- /dev/null
+++ b/poc/custom_gateway/server_gateway.py
@@ -0,0 +1,456 @@
+import asyncio
+import json
+import time
+from datetime import datetime
+from typing import Callable, Optional
+from urllib.parse import urlparse
+
+import grpc
+from docarray import DocList
+from docarray.documents import TextDoc
+from grpc.aio import ClientInterceptor
+
+import marie
+import marie.helper
+from marie import Gateway as BaseGateway
+from marie.excepts import RuntimeFailToStart
+from marie.helper import get_or_reuse_loop
+from marie.logging.logger import MarieLogger
+from marie.proto import jina_pb2, jina_pb2_grpc
+from marie.serve.discovery import JsonAddress
+from marie.serve.discovery.resolver import EtcdServiceResolver
+from marie.serve.networking.balancer.interceptor import LoadBalancerInterceptor
+from marie.serve.networking.balancer.load_balancer import LoadBalancerType
+from marie.serve.networking.balancer.round_robin_balancer import RoundRobinLoadBalancer
+from marie.serve.networking.connection_stub import _ConnectionStubs
+from marie.serve.networking.utils import get_grpc_channel
+from marie.serve.runtimes.gateway.streamer import GatewayStreamer
+from marie.serve.runtimes.servers.composite import CompositeServer
+from marie.serve.runtimes.servers.grpc import GRPCServer
+from marie_server.job.common import JobInfo, JobStatus
+from marie_server.job.gateway_job_distributor import GatewayJobDistributor
+
+
+def create_trace_interceptor() -> ClientInterceptor:
+ return CustomClientInterceptor()
+
+
+def create_balancer_interceptor() -> LoadBalancerInterceptor:
+ def notify(event, connection):
+ print(f"notify: {event}, {connection}")
+
+ return GatewayLoadBalancerInterceptor(notifier=notify)
+
+
+class MarieServerGateway(BaseGateway, CompositeServer):
+ """A custom Gateway for Marie server.
+ Effectively we are providing a custom implementation of the Gateway class that providers communication between individual executors and the server.
+
+ This utilizes service discovery to find deployed Executors from discovered gateways that could have spawned them(Flow/Deployment).
+
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.logger = MarieLogger(self.__class__.__name__)
+ self.logger.info(f"Setting up MarieServerGateway")
+ self._loop = get_or_reuse_loop()
+ self.deployment_nodes = {}
+ self.event_queue = asyncio.Queue()
+ self.distributor = GatewayJobDistributor(
+ gateway_streamer=self.streamer, logger=self.logger
+ )
+
+ def _extend_rest_function(app):
+ @app.on_event("shutdown")
+ async def _shutdown():
+ await self.distributor.close()
+
+ @app.get("/endpoint")
+ async def get(text: str):
+ self.logger.info(f"Received request at {datetime.now()}")
+ docs = DocList[TextDoc]([TextDoc(text=text)])
+ doc = TextDoc(text=text)
+
+ if False:
+ result = await self.distributor.submit_job(
+ JobInfo(status=JobStatus.PENDING, entrypoint="_jina_dry_run_"),
+ doc=doc,
+ )
+
+ return {"result": result}
+
+ if True:
+ result = None
+ async for docs in self.streamer.stream_docs(
+ docs=DocList[TextDoc]([TextDoc(text=text)]),
+ # doc=TextDoc(text=text),
+ # exec_endpoint="/extract", # _jina_dry_run_
+ exec_endpoint="_jina_dry_run_", # _jina_dry_run_
+ # exec_endpoint="/endpoint",
+ # target_executor="executor0",
+ return_results=False,
+ ):
+ result = docs[0].text
+ # result = docs
+ print(f"result: {result}")
+ return {"result": result}
+
+ @app.get("/check")
+ async def get_health(text: str):
+ self.logger.info(f"Received request at {datetime.now()}")
+ return {"result": "ok"}
+
+ return app
+
+ marie.helper.extend_rest_interface = _extend_rest_function
+
+ async def setup_server(self):
+ """
+ setup servers inside CompositeServer
+ """
+ self.logger.debug(f"Setting up MarieGateway server")
+ await super().setup_server()
+ await self.setup_service_discovery()
+
+ async def run_server(self):
+ """Run servers inside CompositeServer forever"""
+ run_server_tasks = []
+ for server in self.servers:
+ run_server_tasks.append(asyncio.create_task(server.run_server()))
+
+ # task for processing events
+ run_server_tasks.append(asyncio.create_task(self.process_events(max_errors=5)))
+ await asyncio.gather(*run_server_tasks)
+
+ async def setup_service_discovery(
+ self,
+ etcd_host: Optional[str] = "0.0.0.0",
+ etcd_port: Optional[int] = 2379,
+ watchdog_interval: Optional[int] = 2,
+ ):
+ """
+ Setup service discovery for the gateway.
+
+ :param etcd_host: Optional[str] - The host address of the ETCD service. Default is "0.0.0.0".
+ :param etcd_port: Optional[int] - The port of the ETCD service. Default is 2379.
+ :param watchdog_interval: Optional[int] - The interval in seconds between each service address check. Default is 2.
+ :return: None
+
+ """
+ self.logger.info("Setting up service discovery ")
+ service_name = "gateway/service_test"
+
+ async def _start_watcher():
+ resolver = EtcdServiceResolver(
+ etcd_host,
+ etcd_port,
+ namespace="marie",
+ start_listener=False,
+ listen_timeout=5,
+ )
+
+ self.logger.info(f"checking : {resolver.resolve(service_name)}")
+ resolver.watch_service(service_name, self.handle_discovery_event)
+
+ # validate the service address
+ if False:
+ while True:
+ self.logger.info("Checking service address...")
+ await asyncio.sleep(watchdog_interval)
+
+ task = asyncio.create_task(_start_watcher())
+ try:
+ await task # This raises an exception if the task had an exception
+ except Exception as e:
+ self.logger.error(
+ f"Initialize etcd client failed failed on {etcd_host}:{etcd_port}"
+ )
+ if isinstance(e, RuntimeFailToStart):
+ raise e
+ raise RuntimeFailToStart(
+ f"Initialize etcd client failed failed on {etcd_host}:{etcd_port}, ensure the etcd server is running."
+ )
+
+ def handle_discovery_event(self, service: str, event: str) -> None:
+ """
+ Enqueue the event to be processed.
+ :param service: The name of the service that is available.
+ :param event: The event that triggered the method.
+ :return:
+ """
+
+ self._loop.call_soon_threadsafe(
+ lambda: asyncio.ensure_future(self.event_queue.put((service, event)))
+ )
+
+ async def process_events(self, max_errors=5) -> None:
+ """
+ Handle a discovery event.
+ :param max_errors: The maximum number of errors to allow before stopping the event processing.
+ :return: None
+ """
+
+ error_counter = 0
+ while True:
+ service, event = await self.event_queue.get()
+ try:
+ self.logger.info(
+ f"Queue size : {self.event_queue.qsize()} event = {service}, {event}"
+ )
+ ev_type = event.event
+ ev_key = event.key
+ ev_value = event.value
+ if ev_type == "put":
+ await self.gateway_server_online(service, ev_value)
+ elif ev_type == "delete":
+ self.logger.info(f"Service {service} is unavailable")
+ await self.gateway_server_offline(service, ev_value)
+ else:
+                    raise TypeError(f"Unrecognized event type: {ev_type}")
+ error_counter = 0 # reset error counter on successful processing
+ except Exception as ex:
+ self.logger.error(f"Error processing event: {ex}")
+ error_counter += 1
+ if error_counter >= max_errors:
+ self.logger.error(f"Reached maximum error limit: {max_errors}")
+ break
+ await asyncio.sleep(1)
+ finally:
+ self.event_queue.task_done()
+
+ async def gateway_server_online(self, service, event_value):
+ """
+ Handle the event when a gateway server comes online.
+
+ :param service: The name of the service that is available.
+ :param event_value: The value of the event that triggered the method.
+ :return: None
+
+        It waits for the gateway to become ready, discovers all executors exposed by it, and then updates the gateway streamer with the discovered nodes.
+
+ """
+ self.logger.info(f"Service {service} is available @ {event_value}")
+
+ # convert event_value to JsonAddress
+ json_address = JsonAddress.from_value(event_value)
+ ctrl_address = json_address._addr
+ metadata = json.loads(json_address._metadata)
+
+ self.logger.info(f"JsonAddress : {ctrl_address}, {metadata}")
+
+ max_tries = 10
+ tries = 0
+ is_ready = False
+ while tries < max_tries:
+ self.logger.info(f"checking is ready at {ctrl_address}")
+ is_ready = GRPCServer.is_ready(ctrl_address)
+ self.logger.info(f"gateway status: {is_ready}")
+ if is_ready:
+ break
+ time.sleep(1)
+ tries += 1
+
+ if is_ready is False:
+ self.logger.warning(
+ f"Gateway is not ready at {ctrl_address} after {max_tries}, will retry on next event"
+ )
+ return
+
+ self.logger.info(f"Gateway is ready at {ctrl_address}")
+ # discover all executors from the gateway
+ # stub = jina_pb2_grpc.JinaDiscoverEndpointsRPCStub(GRPCServer.get_channel(ctrl_address))
+ TLS_PROTOCOL_SCHEMES = ["grpcs"]
+
+ parsed_address = urlparse(ctrl_address)
+ address = parsed_address.netloc if parsed_address.netloc else ctrl_address
+ use_tls = parsed_address.scheme in TLS_PROTOCOL_SCHEMES
+ channel_options = None
+ timeout = 1
+
+ for executor, deployment_addresses in metadata.items():
+ for deployment_address in deployment_addresses:
+ endpoints = []
+ tries = 0
+ while tries < max_tries:
+ try:
+ with get_grpc_channel(
+ address,
+ tls=use_tls,
+ root_certificates=None,
+ options=channel_options,
+ ) as channel:
+                        call_metadata = ()  # per-call gRPC metadata; avoid shadowing the discovered executor metadata
+                        stub = jina_pb2_grpc.JinaDiscoverEndpointsRPCStub(channel)
+                        response, call = stub.endpoint_discovery.with_call(
+                            jina_pb2.google_dot_protobuf_dot_empty__pb2.Empty(),
+                            timeout=timeout,
+                            metadata=call_metadata,
+ )
+ self.logger.info(f"response: {response.endpoints}")
+ endpoints = response.endpoints
+ break
+ except grpc.RpcError as e:
+ time.sleep(1)
+ tries += 1
+ if (
+ e.code() != grpc.StatusCode.UNAVAILABLE
+ or tries >= max_tries
+ ):
+ raise
+
+ for endpoint in endpoints:
+ if executor not in self.deployment_nodes:
+ self.deployment_nodes[executor] = []
+ deployment_details = {
+ "address": deployment_address,
+ "endpoint": endpoint,
+ "executor": executor,
+ "gateway": ctrl_address,
+ }
+ self.deployment_nodes[executor].append(deployment_details)
+ self.logger.info(
+ f"Discovered endpoint: {executor} : {deployment_details}"
+ )
+
+ for executor, nodes in self.deployment_nodes.items():
+ self.logger.info(f"Discovered nodes for executor : {executor}")
+ for node in nodes:
+ self.logger.info(f"\tNode : {node}")
+
+ await self.update_gateway_streamer()
+
+ async def update_gateway_streamer(self):
+ """Update the gateway streamer with the discovered executors."""
+ self.logger.info("Updating gateway streamer")
+
+ # FIXME: testing with only one executor
+ deployments_addresses = {}
+ graph_description = {
+ "start-gateway": ["executor0"],
+ "executor0": ["end-gateway"],
+ }
+ deployments_metadata = {"deployment0": {"key": "value"}}
+ for i, (executor, nodes) in enumerate(self.deployment_nodes.items()):
+ connections = []
+ for node in nodes:
+ address = node["address"]
+ parsed_address = urlparse(address)
+ port = parsed_address.port
+ host = parsed_address.hostname
+ connections.append(f"{host}:{port}")
+ deployments_addresses[executor] = list(set(connections))
+
+ self.logger.info(f"graph_description: {graph_description}")
+ self.logger.info(f"deployments_addresses: {deployments_addresses}")
+
+ load_balancer = RoundRobinLoadBalancer(
+ "deployment-gateway",
+ self.logger,
+ tracing_interceptors=[create_balancer_interceptor()],
+ )
+
+ streamer = GatewayStreamer(
+ graph_representation=graph_description,
+ executor_addresses=deployments_addresses,
+ deployments_metadata=deployments_metadata,
+ load_balancer_type=LoadBalancerType.ROUND_ROBIN.name,
+ load_balancer=load_balancer,
+ aio_tracing_client_interceptors=[create_trace_interceptor()],
+ )
+ self.streamer = streamer
+ self.distributor.streamer = streamer
+ request_models_map = streamer._endpoints_models_map
+ self.logger.info(f"request_models_map: {request_models_map}")
+
+ async def gateway_server_offline(self, service: str, ev_value):
+ """
+ Handle the event when a gateway server goes offline.
+
+ :param service: The name of the service.
+ :param ev_value: The value representing the offline gateway.
+ :return: None
+ """
+ ctrl_address = service.split("/")[-1]
+ self.logger.info(
+ f"Service {service} is offline @ {ctrl_address}, removing nodes"
+ )
+ for executor, nodes in self.deployment_nodes.items():
+ self.deployment_nodes[executor] = [
+ node for node in nodes if node["gateway"] != ctrl_address
+ ]
+ await self.update_gateway_streamer()
+
+
+class GatewayLoadBalancerInterceptor(LoadBalancerInterceptor):
+ def __init__(self, notifier: Optional[Callable] = None):
+ super().__init__()
+ self.active_connection = None
+ self.notifier = notifier
+
+ def notify(self, event: str, connection: _ConnectionStubs):
+ """
+ :param event: The event that triggered the notification.
+ :param connection: The connection that initiated the event.
+ :return: None
+
+ """
+ if self.notifier:
+ self.notifier(event, connection)
+
+ def on_connection_released(self, connection):
+ print(f"on_connection_released: {connection}")
+ self.active_connection = None
+ self.notify("released", connection)
+
+ def on_connection_failed(self, connection: _ConnectionStubs, exception):
+ print(f"on_connection_failed: {connection}, {exception}")
+ self.active_connection = None
+ self.notify("failed", connection)
+
+ def on_connection_acquired(self, connection: _ConnectionStubs):
+ print(f"on_connection_acquired: {connection}")
+ self.active_connection = connection
+ self.notify("acquired", connection)
+
+ def on_connections_updated(self, connections: list[_ConnectionStubs]):
+ print(f"on_connections_updated: {connections}")
+ self.notify("updated", connections)
+
+ def get_active_connection(self):
+ """
+ Get the active connection.
+ :return:
+ """
+ return self.active_connection
+
+
+class CustomClientInterceptor(
+ grpc.aio.UnaryUnaryClientInterceptor,
+ grpc.aio.UnaryStreamClientInterceptor,
+ grpc.aio.StreamUnaryClientInterceptor,
+ grpc.aio.StreamStreamClientInterceptor,
+):
+ async def intercept_unary_unary(self, continuation, client_call_details, request):
+ print(f"intercept_unary_unary: {client_call_details}, {request}")
+ return await continuation(client_call_details, request)
+
+ async def intercept_unary_stream(self, continuation, client_call_details, request):
+ print(f"intercept_unary_stream: {client_call_details}, {request}")
+ return await continuation(client_call_details, request)
+
+ async def intercept_stream_unary(
+ self, continuation, client_call_details, request_iterator
+ ):
+ print(f"intercept_stream_unary: {client_call_details}, {request_iterator}")
+ return await continuation(client_call_details, request_iterator)
+
+ async def intercept_stream_stream(
+ self, continuation, client_call_details, request_iterator
+ ):
+ print(f"intercept_stream_stream: {client_call_details}, {request_iterator}")
+ return await continuation(client_call_details, request_iterator)
+
+
+# clear;for i in {0..10};do curl localhost:51000/endpoint?text=x_${i} ;done;
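+#
+# The loop above assumes the HTTP protocol of this gateway is bound to port 51000 (as in
+# start_gateway.py); each curl call hits the /endpoint route added in _extend_rest_function,
+# which streams the doc through the discovered executors and returns {"result": ...}.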
diff --git a/poc/custom_gateway/start_gateway.py b/poc/custom_gateway/start_gateway.py
new file mode 100644
index 00000000..8a748dac
--- /dev/null
+++ b/poc/custom_gateway/start_gateway.py
@@ -0,0 +1,44 @@
+import time
+
+from marie import Flow
+from marie.serve.runtimes.servers.grpc import GRPCServer
+from poc.custom_gateway.server_gateway import MarieServerGateway
+
+
+def main():
+ print("Bootstrapping server gateway")
+
+ if False:
+ ctrl_address = "0.0.0.0:61000"
+ print('waiting for all servers to be ready at : ', ctrl_address)
+ while True:
+ print(f"checking is ready at {ctrl_address}")
+ res = GRPCServer.is_ready(ctrl_address)
+
+ print(f"res: {res}")
+ if res:
+ print(f"Gateway is ready at {ctrl_address}")
+
+ break
+ time.sleep(1)
+ return
+
+ # gateway --protocol http --discovery --discovery-host 127.0.0.1 --discovery-port 8500 --host 192.168.102.65 --port 5555
+
+ with (
+ Flow(
+            # the server gateway does not need the discovery service; this will be available as runtime_args.discovery: bool
+ discovery=False,
+ ).config_gateway(
+ uses=MarieServerGateway,
+ protocols=["GRPC", "HTTP"],
+ ports=[52000, 51000],
+ )
+ # .add(tls=False, host="0.0.0.0", external=True, port=61000)
+ as flow
+ ):
+ flow.block()
+
+
+if __name__ == "__main__":
+ main()
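+
+# Running this module (e.g. `python poc/custom_gateway/start_gateway.py`) starts the
+# MarieServerGateway with GRPC on port 52000 and HTTP on port 51000, matching the ports
+# expected by send_request_to_gateway.py and the curl check in server_gateway.py.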
diff --git a/tests/check_qrcode.py b/tests/check_datamatrix_code.py
similarity index 96%
rename from tests/check_qrcode.py
rename to tests/check_datamatrix_code.py
index 0e986f59..551ca37b 100644
--- a/tests/check_qrcode.py
+++ b/tests/check_datamatrix_code.py
@@ -23,7 +23,7 @@ def process_data_matrix(filename):
if len(unique_values) > 2:
ret, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
- pad = 50
+ pad = 100
gray = cv2.copyMakeBorder(gray, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=255)
decoded_objects = decode(gray, corrections=3)
@@ -31,6 +31,7 @@ def process_data_matrix(filename):
output = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
for obj in decoded_objects:
+ print("Decoded Data : ", obj.data.decode('utf-8'))
rect = obj.rect
cv2.rectangle(output, (rect.left, rect.top), (rect.left + rect.width, rect.top + rect.height), (0, 0, 255), 2)
# cv2.putText(output, obj.data.decode('utf-8'), (rect.left, rect.top - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 255), 1)
diff --git a/tests/core/test_job_manager.py b/tests/core/test_job_manager.py
index 7d0c50aa..bd6ef9ff 100644
--- a/tests/core/test_job_manager.py
+++ b/tests/core/test_job_manager.py
@@ -1,5 +1,4 @@
import asyncio
-import json
import multiprocessing
import os
import random
@@ -8,13 +7,13 @@
import time
import pytest
-from docarray import Document
+from docarray import DocList
+from docarray.documents import TextDoc
-from marie import Deployment, DocumentArray, Executor, requests
+from marie import Deployment, Document, DocumentArray, Executor, requests
from marie.enums import PollingType
from marie.parsers import set_deployment_parser
from marie.proto import jina_pb2
-from marie.serve.networking import GrpcConnectionPool, _ReplicaList
from marie.serve.networking.balancer.load_balancer import LoadBalancerType
from marie.serve.runtimes.asyncio import AsyncNewLoopRuntime
from marie.serve.runtimes.gateway.streamer import GatewayStreamer
@@ -61,13 +60,19 @@ async def check_job_running(job_manager, job_id):
@pytest.fixture
async def job_manager(tmp_path):
storage = InMemoryKV()
+ # TODO: Externalize the storage configuration
+ storage_config = {
+ "hostname": "127.0.0.1",
+ "port": 5432,
+ "username": "postgres",
+ "password": "123456",
+ "database": "postgres",
+ "default_table": "kv_store_a",
+ "max_pool_size": 5,
+ "max_connections": 5,
+ }
- storage_config = {"hostname": "127.0.0.1", "port": 5432, "username": "postgres", "password": "123456",
- "database": "postgres",
- "default_table": "kv_store_a", "max_pool_size": 5,
- "max_connections": 5}
-
- # storage = PostgreSQLKV(config=storage_config, reset=True)
+ storage = PostgreSQLKV(config=storage_config, reset=True)
yield JobManager(storage=storage, job_distributor=NoopJobDistributor())
@@ -89,8 +94,12 @@ async def test_list_jobs(job_manager: JobManager):
metadata=metadata,
)
- _ = asyncio.create_task(async_delay(update_job_status(job_manager, "1", JobStatus.SUCCEEDED), 1))
- _ = asyncio.create_task(async_delay(update_job_status(job_manager, "2", JobStatus.SUCCEEDED), 1))
+ _ = asyncio.create_task(
+ async_delay(update_job_status(job_manager, "1", JobStatus.SUCCEEDED), 1)
+ )
+ _ = asyncio.create_task(
+ async_delay(update_job_status(job_manager, "2", JobStatus.SUCCEEDED), 1)
+ )
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="1"
@@ -123,7 +132,11 @@ async def test_pass_job_id(job_manager):
)
assert returned_id == submission_id
- _ = asyncio.create_task(async_delay(update_job_status(job_manager, submission_id, JobStatus.SUCCEEDED), 1))
+ _ = asyncio.create_task(
+ async_delay(
+ update_job_status(job_manager, submission_id, JobStatus.SUCCEEDED), 1
+ )
+ )
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=submission_id
@@ -146,7 +159,9 @@ async def test_simultaneous_submit_job(job_manager):
)
for job_id in job_ids:
- _ = asyncio.create_task(async_delay(update_job_status(job_manager, job_id, JobStatus.SUCCEEDED), 1))
+ _ = asyncio.create_task(
+ async_delay(update_job_status(job_manager, job_id, JobStatus.SUCCEEDED), 1)
+ )
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
@@ -167,7 +182,9 @@ async def test_simultaneous_with_same_id(job_manager):
assert "Job with submission_id 1 already exists" in str(excinfo.value)
# Check that the (first) job can still succeed.
- _ = asyncio.create_task(async_delay(update_job_status(job_manager, "1", JobStatus.SUCCEEDED), 1))
+ _ = asyncio.create_task(
+ async_delay(update_job_status(job_manager, "1", JobStatus.SUCCEEDED), 1)
+ )
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="1"
@@ -177,12 +194,12 @@ async def test_simultaneous_with_same_id(job_manager):
class StreamerTestExecutor(Executor):
@requests
def foo(self, docs, parameters, **kwargs):
- text_to_add = parameters.get('text_to_add', 'default ')
+ text_to_add = parameters.get("text_to_add", "default ")
for doc in docs:
doc.text += text_to_add
-def _create_worker_runtime(port, uses, name=''):
+def _create_worker_runtime(port, uses, name=""):
args = _generate_pod_args()
args.port = [port]
args.name = name
@@ -193,40 +210,40 @@ def _create_worker_runtime(port, uses, name=''):
def _setup(pod0_port, pod1_port):
pod0_process = multiprocessing.Process(
- target=_create_worker_runtime, args=(pod0_port, 'StreamerTestExecutor')
+ target=_create_worker_runtime, args=(pod0_port, "StreamerTestExecutor")
)
pod0_process.start()
pod1_process = multiprocessing.Process(
- target=_create_worker_runtime, args=(pod1_port, 'StreamerTestExecutor')
+ target=_create_worker_runtime, args=(pod1_port, "StreamerTestExecutor")
)
pod1_process.start()
assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
- ctrl_address=f'0.0.0.0:{pod0_port}',
+ ctrl_address=f"0.0.0.0:{pod0_port}",
ready_or_shutdown_event=multiprocessing.Event(),
)
assert BaseServer.wait_for_ready_or_shutdown(
timeout=5.0,
- ctrl_address=f'0.0.0.0:{pod1_port}',
+ ctrl_address=f"0.0.0.0:{pod1_port}",
ready_or_shutdown_event=multiprocessing.Event(),
)
return pod0_process, pod1_process
@pytest.mark.parametrize(
- 'parameters, target_executor, expected_text',
+ "parameters, target_executor, expected_text",
[ # (None, None, 'default default '),
- ({'pod0__text_to_add': 'param_pod0 '}, None, 'param_pod0 default '),
- (None, 'pod1', 'default '),
- ({'pod0__text_to_add': 'param_pod0 '}, 'pod0', 'param_pod0 '),
+ ({"pod0__text_to_add": "param_pod0 "}, None, "param_pod0 default "),
+ (None, "pod1", "default "),
+ ({"pod0__text_to_add": "param_pod0 "}, "pod0", "param_pod0 "),
],
)
-@pytest.mark.parametrize('results_in_order', [False, True])
+@pytest.mark.parametrize("results_in_order", [False, True])
@pytest.mark.asyncio
async def test_gateway_job_manager(
- port_generator, parameters, target_executor, expected_text, results_in_order
+ port_generator, parameters, target_executor, expected_text, results_in_order
):
pod0_port = port_generator()
pod1_port = port_generator()
@@ -243,15 +260,15 @@ async def test_gateway_job_manager(
)
try:
- input_da = DocumentArray.empty(60)
- resp = DocumentArray.empty(0)
+ input_da = DocList([TextDoc(text="default ") for _ in range(60)])
+ resp = DocList([])
num_resp = 0
async for r in gateway_streamer.stream_docs(
- docs=input_da,
- request_size=10,
- parameters=parameters,
- target_executor=target_executor,
- results_in_order=results_in_order,
+ docs=input_da,
+ request_size=10,
+ parameters=parameters,
+ target_executor=target_executor,
+ results_in_order=results_in_order,
):
num_resp += 1
resp.extend(r)
@@ -277,18 +294,18 @@ async def test_gateway_job_manager(
def _create_regular_deployment(
- port,
- name='',
- executor=None,
- noblock_on_start=True,
- polling=PollingType.ANY,
- shards=None,
- replicas=None,
+ port,
+ name="",
+ executor=None,
+ noblock_on_start=True,
+ polling=PollingType.ANY,
+ shards=None,
+ replicas=None,
):
# return Deployment(uses=executor, include_gateway=False, noblock_on_start=noblock_on_start, replicas=replicas,
# shards=shards)
- args = set_deployment_parser().parse_args(['--port', str(port)])
+ args = set_deployment_parser().parse_args(["--port", str(port)])
args.name = name
if shards:
args.shards = shards
@@ -307,7 +324,7 @@ def encode(self, docs, **kwargs):
assert len(docs) == 1
doc = docs[0]
r = 0
- if doc.text == 'slow':
+ if doc.text == "slow":
# random sleep between 0.1 and 0.5
# time.sleep(.5)
r = random.random() / 2 + 0.1
@@ -315,8 +332,8 @@ def encode(self, docs, **kwargs):
time.sleep(r)
print(f"{os.getpid()} : {doc.id} >> {doc.text} : {r}")
- doc.text += f'return encode {os.getpid()}'
- doc.tags['pid'] = os.getpid()
+ doc.text += f"return encode {os.getpid()}"
+ doc.tags["pid"] = os.getpid()
class NoopJobDistributor(JobDistributor):
@@ -324,32 +341,45 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
async def submit_job(self, job_info: JobInfo) -> DataRequest:
- print(f'NoopJobDistributor: {job_info}')
+ print(f"NoopJobDistributor: {job_info}")
if job_info.status != JobStatus.PENDING:
- raise Exception('Job status is not PENDING')
+ raise Exception("Job status is not PENDING")
r = DataRequest()
r.status.code = jina_pb2.StatusProto.ERROR
return r
+
@pytest.mark.asyncio
async def test_deployment_gateway_streamer(port_generator):
deployment_port = port_generator()
- graph_description = {"start-gateway": ["deployment0"], "deployment0": ["end-gateway"]}
+ graph_description = {
+ "start-gateway": ["deployment0"],
+ "deployment0": ["end-gateway"],
+ }
replica_count = 4
- deployment = _create_regular_deployment(deployment_port, 'deployment0', executor=FastSlowPIDExecutor.__name__,
- noblock_on_start=False, replicas=replica_count, shards=None)
+ deployment = _create_regular_deployment(
+ deployment_port,
+ "deployment0",
+ executor=FastSlowPIDExecutor.__name__,
+ noblock_on_start=False,
+ replicas=replica_count,
+ shards=None,
+ )
deployment.start()
- connections = [f'{host}:{port}' for host, port in zip(deployment.hosts, deployment.ports)]
+ connections = [
+ f"{host}:{port}" for host, port in zip(deployment.hosts, deployment.ports)
+ ]
deployments_addresses = {"deployment0": connections}
deployments_metadata = {"deployment0": {"key": "value"}}
# manually start the deployment
gateway_streamer = GatewayStreamer(
- graph_representation=graph_description, executor_addresses=deployments_addresses,
+ graph_representation=graph_description,
+ executor_addresses=deployments_addresses,
deployments_metadata=deployments_metadata,
load_balancer_type=LoadBalancerType.ROUND_ROBIN.name,
# load_balancer_type=LoadBalancerType.LEAST_CONNECTION.name,
@@ -366,7 +396,7 @@ async def test_deployment_gateway_streamer(port_generator):
print(f"scheduling request : {i}")
request = DataRequest()
# request.data.docs = DocumentArray([Document(text='slow' if i % 2 == 0 else 'fast')])
- request.data.docs = DocumentArray([Document(text='slow')])
+ request.data.docs = DocumentArray([Document(text="slow")])
response = gateway_streamer.process_single_data(request=request)
tasks.append(response)
# time.sleep(.2)
@@ -377,7 +407,7 @@ async def test_deployment_gateway_streamer(port_generator):
for response in futures:
assert len(response.docs) == 1
for doc in response.docs:
- pid = int(doc.tags['pid'])
+ pid = int(doc.tags["pid"])
if pid not in pids:
pids[pid] = 0
pids[pid] += 1
@@ -394,12 +424,14 @@ async def test_deployment_gateway_streamer(port_generator):
print("--" * 10)
print(f"sending request : {i}")
request = DataRequest()
- request.data.docs = DocumentArray([Document(text='slow' if i % 2 == 0 else 'fast')])
+ request.data.docs = DocumentArray(
+ [Document(text="slow" if i % 2 == 0 else "fast")]
+ )
response = await gateway_streamer.process_single_data(request=request)
assert len(response.docs) == 1
for doc in response.docs:
- pid = int(doc.tags['pid'])
+ pid = int(doc.tags["pid"])
print(pid)
if pid not in pids:
pids[pid] = 0
@@ -415,20 +447,32 @@ async def test_deployment_gateway_streamer(port_generator):
@pytest.mark.asyncio
async def test_deployment_with_job_manager(port_generator, job_manager):
deployment_port = port_generator()
- graph_description = {"start-gateway": ["deployment0"], "deployment0": ["end-gateway"]}
+ graph_description = {
+ "start-gateway": ["deployment0"],
+ "deployment0": ["end-gateway"],
+ }
replica_count = 4
- deployment = _create_regular_deployment(deployment_port, 'deployment0', executor=FastSlowPIDExecutor.__name__,
- noblock_on_start=False, replicas=replica_count, shards=None)
+ deployment = _create_regular_deployment(
+ deployment_port,
+ "deployment0",
+ executor=FastSlowPIDExecutor.__name__,
+ noblock_on_start=False,
+ replicas=replica_count,
+ shards=None,
+ )
deployment.start()
- connections = [f'{host}:{port}' for host, port in zip(deployment.hosts, deployment.ports)]
+ connections = [
+ f"{host}:{port}" for host, port in zip(deployment.hosts, deployment.ports)
+ ]
deployments_addresses = {"deployment0": connections}
deployments_metadata = {"deployment0": {"key": "value"}}
# manually start the deployment
gateway_streamer = GatewayStreamer(
- graph_representation=graph_description, executor_addresses=deployments_addresses,
+ graph_representation=graph_description,
+ executor_addresses=deployments_addresses,
deployments_metadata=deployments_metadata,
load_balancer_type=LoadBalancerType.ROUND_ROBIN.name,
# load_balancer_type=LoadBalancerType.LEAST_CONNECTION.name,
@@ -437,7 +481,6 @@ async def test_deployment_with_job_manager(port_generator, job_manager):
stop_event = threading.Event()
await gateway_streamer.warmup(stop_event=stop_event)
-
pids = {}
if False:
pids = {}
@@ -449,7 +492,7 @@ async def test_deployment_with_job_manager(port_generator, job_manager):
print(f"scheduling request : {i}")
request = DataRequest()
# request.data.docs = DocumentArray([Document(text='slow' if i % 2 == 0 else 'fast')])
- request.data.docs = DocumentArray([Document(text='slow')])
+ request.data.docs = DocumentArray([Document(text="slow")])
response = gateway_streamer.process_single_data(request=request)
tasks.append(response)
# time.sleep(.2)
@@ -460,7 +503,7 @@ async def test_deployment_with_job_manager(port_generator, job_manager):
for response in futures:
assert len(response.docs) == 1
for doc in response.docs:
- pid = int(doc.tags['pid'])
+ pid = int(doc.tags["pid"])
if pid not in pids:
pids[pid] = 0
pids[pid] += 1
@@ -476,6 +519,5 @@ async def test_deployment_with_job_manager(port_generator, job_manager):
await gateway_streamer.close()
-
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
diff --git a/tests/helper.py b/tests/helper.py
index b1221c22..6a32a7d3 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -1,28 +1,28 @@
-from docarray import DocumentArray
-
+# from docarray import DocumentArray
from marie import Executor, requests
+from marie._docarray import Document, DocumentArray
from marie.parsers import set_pod_parser
class ProcessExecutor(Executor):
- @requests(on='/')
+ @requests(on="/")
def process(self, docs: DocumentArray, **kwargs):
for doc in docs:
- doc.text = doc.text + 'world'
- doc.tags['processed'] = True
+ doc.text = doc.text + "world"
+ doc.tags["processed"] = True
def _validate_dummy_custom_gateway_response(port, expected):
import requests
- resp = requests.get(f'http://127.0.0.1:{port}/').json()
+ resp = requests.get(f"http://127.0.0.1:{port}/").json()
assert resp == expected
def _validate_custom_gateway_process(port, text, expected):
import requests
- resp = requests.get(f'http://127.0.0.1:{port}/stream?text={text}').json()
+ resp = requests.get(f"http://127.0.0.1:{port}/stream?text={text}").json()
assert resp == expected
diff --git a/tests/integration/check_document_boundry.py b/tests/integration/check_document_boundry.py
index 685f7a1c..1fb69237 100644
--- a/tests/integration/check_document_boundry.py
+++ b/tests/integration/check_document_boundry.py
@@ -28,8 +28,10 @@ def check_boundry_registration():
model_name_or_path="../../model_zoo/unilm/dit/object_detection/document_boundary",
use_gpu=True,
)
- filepath = "~/PID_402_8220_0_200802683.tif"
- filepath = "~/PID_3736_11058_0_200942298.tif" # => 201458362
+ filepath = "~/PID_3736_11058_0_200942298.tif"
+ filepath = "~/tmp/analysis/document-boundary/samples/204169581/PID_3585_10907_0_204169581.tif"
+ filepath = "~/tmp/analysis/document-boundary/samples/203822066/PID_3585_10907_0_203822066.tif"
+ filepath = "~/tmp/analysis/document-boundary/samples/204446542/PID_3585_10907_0_204446542.tif"
basename = filepath.split("/")[-1].split(".")[0]
documents = docs_from_file(filepath)
diff --git a/tests/integration/check_ocr_eninge_renderer.py b/tests/integration/check_ocr_eninge_renderer.py
index 25562b16..bca20544 100644
--- a/tests/integration/check_ocr_eninge_renderer.py
+++ b/tests/integration/check_ocr_eninge_renderer.py
@@ -3,6 +3,7 @@
import cv2
import torch
+from PIL import Image
from marie.boxes.box_processor import PSMode
from marie.ocr import CoordinateFormat, DefaultOcrEngine, OcrEngine
@@ -42,7 +43,19 @@ def process_file(ocr_engine: OcrEngine, img_path: str):
key = img_path.split("/")[-1]
frames = frames_from_file(img_path)
- results = ocr_engine.extract(frames, PSMode.SPARSE, CoordinateFormat.XYWH)
+ # test failure of small region
+ frame = frames[0]
+ xywh = (1691, 473, 255, 28)
+ fragment = frame[xywh[1]:xywh[1] + xywh[3], xywh[0]:xywh[0] + xywh[2]]
+ cv2.imwrite(f"/tmp/fragments/{key}.png", fragment)
+
+ pil_frag = Image.open(img_path).crop(
+ (xywh[0], xywh[1], xywh[0] + xywh[2], xywh[1] + xywh[3])
+ )
+ pil_frag.save(f"/tmp/fragments/{key}-PIL.png")
+ # save the fragment
+
+ results = ocr_engine.extract([pil_frag], PSMode.SPARSE, CoordinateFormat.XYWH)
print("Testing text renderer")
json_path = os.path.join("/tmp/fragments", f"results-{key}.json")
@@ -142,7 +155,7 @@ def extract_bouding_boxes(img_path: str, metadata_path: str, ngram: int = 2):
img_path = "~/tmp/4007/176073139.tif"
img_path = "~/tmp/demo/159581778_1.png"
- img_path = os.path.expanduser("~/tmp/demo")
+ # img_path = os.path.expanduser("~/tmp/demo")
# img_path = os.path.expanduser("/home/gbugaj/dev/marieai/marie-ai/assets/template_matching/sample-001.png")
use_cuda = torch.cuda.is_available()
diff --git a/tests/integration/core/text/check_text_executor.py b/tests/integration/core/text/check_text_executor.py
index bfc30683..f1078e61 100644
--- a/tests/integration/core/text/check_text_executor.py
+++ b/tests/integration/core/text/check_text_executor.py
@@ -7,11 +7,9 @@
from docarray.documents import TextDoc
from marie import Client
-from marie.executor.text import TextExtractionExecutor, TextExtractionExecutorMock
+from marie.executor.text import TextExtractionExecutorMock
from marie.storage import S3StorageHandler, StorageManager
-from marie.utils.docs import docs_from_file, frames_from_docs
from marie.utils.json import load_json_file, store_json_object
-from marie.utils.utils import ensure_exists
from marie_server.rest_extension import parse_payload_to_docs_sync
diff --git a/tests/integration/discovery/__init__.py b/tests/integration/discovery/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration/discovery/conftest.py b/tests/integration/discovery/conftest.py
new file mode 100644
index 00000000..2a8ab6fb
--- /dev/null
+++ b/tests/integration/discovery/conftest.py
@@ -0,0 +1,27 @@
+import pytest
+
+from marie.serve.discovery.registry import EtcdServiceRegistry
+from marie.serve.discovery.resolver import EtcdServiceResolver
+
+
+@pytest.fixture(scope='function')
+def grpc_resolver(mocker, grpc_addr):
+ client = mocker.Mock()
+ resolver = EtcdServiceResolver(etcd_client=client, start_listener=False)
+
+ def resolve(service_name):
+ return grpc_addr
+
+ old_resolve = resolver.resolve
+ setattr(resolver, 'resolve', resolve)
+ yield resolver
+
+ setattr(resolver, 'resolve', old_resolve)
+
+
+@pytest.fixture(scope='function')
+def etcd_registry(mocker):
+ client = mocker.Mock()
+ registry = EtcdServiceRegistry(etcd_client=client)
+
+ return registry
diff --git a/tests/integration/discovery/test_address.py b/tests/integration/discovery/test_address.py
new file mode 100644
index 00000000..525e870f
--- /dev/null
+++ b/tests/integration/discovery/test_address.py
@@ -0,0 +1,55 @@
+import json
+
+import pytest
+
+from marie.serve.discovery.address import JsonAddress, PlainAddress
+
+
+@pytest.mark.parametrize(
+ 'addr, exp_addr', (
+ ('1.2.3.4', '1.2.3.4'),
+ ('5.6.7.8', '5.6.7.8'),
+ )
+)
+def test_to_plain_address(addr, exp_addr):
+ assert PlainAddress(addr).add_value() == exp_addr
+ assert PlainAddress(addr).delete_value() == exp_addr
+
+
+@pytest.mark.parametrize(
+ 'addr, exp_addr', (
+ (b'1.2.3.4', '1.2.3.4'),
+ (b'5.6.7.8', '5.6.7.8'),
+ (b'11.2.3.4', '11.2.3.4'),
+ (b'55.6.7.8', '55.6.7.8'),
+ )
+)
+def test_from_plain_address(addr, exp_addr):
+ assert PlainAddress.from_value(addr) == exp_addr
+
+
+@pytest.mark.parametrize(
+ 'val', (
+ '1.2.3.4',
+ '5.6.7.8',
+ )
+)
+def test_to_json_address(val):
+ assert JsonAddress(val).add_value() == json.dumps({
+ 'Op': 0, 'Addr': val, 'Metadata': "{}"})
+ assert JsonAddress(
+ val, metadata={'name': 'host1'}).delete_value() == json.dumps({
+ 'Op': 1, 'Addr': val, 'Metadata': json.dumps({'name': 'host1'})})
+
+
+@pytest.mark.parametrize(
+ 'val, op, addr', (
+ (b'{"Op": 1, "Addr": "1.2.3.4", "Metadata": "{}"}', False, '1.2.3.4'),
+ (b'{"Op": 0, "Addr": "5.6.7.8", "Metadata": "{}"}', True, '5.6.7.8'),
+ (b'{"Op": 1, "Addr": "11.2.3.4", "Metadata": "{}"}', False, '11.2.3.4'),
+ (b'{"Op": 0, "Addr": "55.6.7.8", "Metadata": "{}"}', True, '55.6.7.8'),
+ )
+)
+def test_from_json_address(val, op, addr):
+ data = JsonAddress.from_value(val)
+ assert (op, addr) == data
diff --git a/tests/integration/discovery/test_registry.py b/tests/integration/discovery/test_registry.py
new file mode 100644
index 00000000..0dee07a1
--- /dev/null
+++ b/tests/integration/discovery/test_registry.py
@@ -0,0 +1,20 @@
+import time
+
+import pytest
+
+
+@pytest.mark.parametrize(
+ 'service_names, service_addr, service_ttl', (
+ (('grpc.service_test', 'grpc.service_list'), '10.30.1.1.50011', 120),
+ (('grpc.service_create', 'grpc.service_update'), '10.30.1.1.50011', 120),
+ )
+)
+def test_service_registry(
+ etcd_registry, service_names, service_addr, service_ttl):
+ etcd_registry.register(service_names, service_addr, service_ttl)
+ assert etcd_registry._services[service_addr] == set(service_names)
+ assert service_addr in etcd_registry._leases
+
+ etcd_registry.unregister((service_names[0],), service_addr)
+ assert etcd_registry._services[service_addr] == {service_names[1]}
+ assert service_addr in etcd_registry._leases
diff --git a/tests/integration/gateway_clients/test_long_flow_keepalive.py b/tests/integration/gateway_clients/test_long_flow_keepalive.py
index 34d8e9d3..6e7d4b44 100644
--- a/tests/integration/gateway_clients/test_long_flow_keepalive.py
+++ b/tests/integration/gateway_clients/test_long_flow_keepalive.py
@@ -1,9 +1,10 @@
import time
import pytest
-from docarray import DocumentArray
+from docarray import DocList
+from docarray.documents import TextDoc
-from marie import Executor, Flow, requests
+from marie import DocumentArray, Executor, Flow, requests
@pytest.fixture()
@@ -23,7 +24,7 @@ def test_long_flow_keep_alive(slow_executor):
# it tests that the connection to a flow that take a lot of time to process will not be killed by the keepalive feature
with Flow().add(uses=slow_executor) as f:
- docs = f.search(inputs=DocumentArray.empty(10))
+ docs = f.search(inputs=DocList[TextDoc]([TextDoc()]))
for doc_ in docs:
assert doc_.text == 'process'
diff --git a/tests/integration/scheduler/__init__.py b/tests/integration/scheduler/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/integration/scheduler/test_job_scheduler_core.py b/tests/integration/scheduler/test_job_scheduler_core.py
new file mode 100644
index 00000000..d92ab7b0
--- /dev/null
+++ b/tests/integration/scheduler/test_job_scheduler_core.py
@@ -0,0 +1,27 @@
+import pytest
+
+from marie_server.scheduler import PostgreSQLJobScheduler
+
+
+@pytest.fixture
+def num_jobs(request):
+ print("request.param", request.param)
+ return request.param
+
+
+@pytest.mark.parametrize("num_jobs", [1], indirect=True)
+def test_job_scheduler_setup(num_jobs):
+ print("num_jobs", num_jobs)
+ scheduler_config = {
+ "hostname": "localhost",
+ "port": 5432,
+ "database": "postgres",
+ "username": "postgres",
+ "password": "123456",
+ }
+
+ scheduler = PostgreSQLJobScheduler(config=scheduler_config)
+ scheduler.start_schedule()
+ print(scheduler)
+
+ assert scheduler.running
diff --git a/tests/unit/schemas/test_legacdocument_schema.py b/tests/unit/schemas/test_legacdocument_schema.py
new file mode 100644
index 00000000..02e29830
--- /dev/null
+++ b/tests/unit/schemas/test_legacdocument_schema.py
@@ -0,0 +1,9 @@
+from docarray.documents.legacy import LegacyDocument
+
+from marie.utils.pydantic import patch_pydantic_schema_2x
+
+
+def test_legacy_schema():
+ LegacyDocument.schema = classmethod(patch_pydantic_schema_2x)
+ legacy_doc_schema = LegacyDocument.schema()
+ print(legacy_doc_schema)
diff --git a/tests/unit/serve/executors/test_bad_executor_constructor.py b/tests/unit/serve/executors/test_bad_executor_constructor.py
index 4563f7c0..e17fe241 100644
--- a/tests/unit/serve/executors/test_bad_executor_constructor.py
+++ b/tests/unit/serve/executors/test_bad_executor_constructor.py
@@ -1,6 +1,8 @@
import pytest
+from docarray.documents.legacy import LegacyDocument
from marie import Executor, Flow, requests
+from marie.utils.pydantic import patch_pydantic_schema_2x
class GoodExecutor(Executor):
@@ -22,6 +24,7 @@ def foo(self, docs, parameters, docs_matrix):
def test_bad_executor_constructor():
+
# executor can be used as out of Flow as Python object
exec1 = GoodExecutor()
exec2 = GoodExecutor2({}, {}, {}, {})
diff --git a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py
index 9fbedd69..32035f54 100644
--- a/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py
+++ b/tests/unit/serve/runtimes/gateway/grpc/test_grpc_gateway_runtime.py
@@ -7,8 +7,8 @@
import grpc
import pytest
-from docarray import Document, DocumentArray
+from marie import Document, DocumentArray
from marie.clients.request import request_generator
from marie.helper import random_port
from marie.parsers import set_gateway_parser
@@ -493,7 +493,7 @@ def _create_runtime():
p = multiprocessing.Process(target=_create_runtime)
p.start()
- time.sleep(1.0)
+ time.sleep(2.0)
assert BaseServer.is_ready(ctrl_address=f'127.0.0.1:{port}')
async with grpc.aio.insecure_channel(f'127.0.0.1:{port}') as channel:
diff --git a/tools/check_line_frgments_ocr.py b/tools/check_line_frgments_ocr.py
index 8e8a329b..900e7a51 100644
--- a/tools/check_line_frgments_ocr.py
+++ b/tools/check_line_frgments_ocr.py
@@ -23,7 +23,7 @@ def build_ocr_engines():
trocr_processor = TrOcrProcessor(
models_dir="/mnt/data/marie-ai/model_zoo/trocr",
- model_name_or_path="/data/models/unilm/trocr/ft_SROIE_LINES_SET30/checkpoint_best.pt",
+ model_name_or_path="/data/models/unilm/trocr/ft_SROIE_LINES_SET41/checkpoint_best.pt",
cuda=use_cuda,
)
# craft_processor = CraftOcrProcessor(cuda=True)
diff --git a/tools/ocr_diff.py b/tools/ocr_diff.py
index fce29518..b76f1c1f 100644
--- a/tools/ocr_diff.py
+++ b/tools/ocr_diff.py
@@ -31,7 +31,7 @@ def build_ocr_engines():
)
ocr2_processor = TrOcrProcessor(
- model_name_or_path="/data/models/unilm/trocr/ft_SROIE_LINES_SET27/checkpoint_best.pt",
+ model_name_or_path="/data/models/unilm/trocr/ft_SROIE_LINES_SET38/checkpoint_best.pt",
cuda=use_cuda,
)
@@ -42,27 +42,18 @@ def build_ocr_engines():
def create_image_with_text(snippet, text1, text2, output_path):
- # Create a new image with space for the text
width, height = snippet.size
new_image = Image.new("RGB", (width, height + 60), (255, 255, 255))
-
- # Paste the snippet onto the new image
new_image.paste(snippet, (0, 0))
-
- # Get a drawing context
draw = ImageDraw.Draw(new_image)
# Load a font (this font should be available on your system)
# font = ImageFont.truetype("arial", 15)
-
- # Load the default font
font = ImageFont.load_default()
- # Draw the text onto the new image
draw.text((10, height + 10), text1, font=font, fill=(0, 0, 0))
draw.text((10, height + 30), text2, font=font, fill=(0, 0, 0))
- # Save the new image
new_image.save(output_path)
@@ -87,7 +78,7 @@ def process_image(img_path, box_processor, ocr1_processor, ocr2_processor):
"gradio ", "00000", image, boxes, fragments, lines, return_overlay=True
)
- output_dir = os.path.expanduser("~/tmp/ocr-diffs/v3")
+ output_dir = os.path.expanduser("~/tmp/ocr-diffs/v5")
output_dir_raw = os.path.join(output_dir, "raw")
output_dir_diff = os.path.join(output_dir, "diff")
output_dir_ocr1 = os.path.join(output_dir, "ocr1")
@@ -107,10 +98,14 @@ def process_image(img_path, box_processor, ocr1_processor, ocr2_processor):
print("word : ", word1, word2)
# filter only for words that contain digits
- if not any(char.isdigit() for char in word1["text"]):
- continue
+ # if not any(char.isdigit() for char in word1["text"]):
+ # continue
+
+ # only compare words with low confidence
+ conf1 = word1["confidence"]
+ conf2 = word2["confidence"]
- if word1["text"] != word2["text"]:
+ if word1["text"] != word2["text"] and (conf1 > 0.8 and conf1 > conf2):
print("DIFFERENT")
print(word1)
print(word2)
@@ -119,8 +114,20 @@ def process_image(img_path, box_processor, ocr1_processor, ocr2_processor):
conf2 = word2["confidence"]
mix_word_len = min(len(word1["text"]), len(word2["text"]))
- if mix_word_len < 3:
- print("skipping short word : " + word1["text"] + " " + word2["text"])
+ # if mix_word_len < 3:
+ # print("skipping short word : " + word1["text"] + " " + word2["text"])
+ # continue
+
+ # check for any digit or dollar sign
+ if (
+ not any(char.isdigit() for char in word1["text"])
+ and not any(char.isdigit() for char in word2["text"])
+ and not any(char == "$" for char in word1["text"])
+ and not any(char == "$" for char in word2["text"])
+ ):
+ print(
+ "skipping non digit word : " + word1["text"] + " " + word2["text"]
+ )
continue
# clip the image snippet from the original image
@@ -186,7 +193,8 @@ def process_dir(image_dir: str, box_processor, ocr1_processor, ocr2_processor):
box_processor, ocr1_processor, ocr2_processor = build_ocr_engines()
process_dir(
- "/home/greg/datasets/funsd_dit/IMAGES/LbxIDImages_boundingBox_6292023",
+ # "/home/greg/datasets/funsd_dit/IMAGES/LbxIDImages_boundingBox_6292023",
+ "~/datasets/private/eob-extract/converted/imgs/eob-extract/eob-003",
box_processor,
ocr1_processor,
ocr2_processor,
diff --git a/workspaces/bounding-boxes-gradio/app.py b/workspaces/bounding-boxes-gradio/app.py
index e535b460..5d28e2ad 100644
--- a/workspaces/bounding-boxes-gradio/app.py
+++ b/workspaces/bounding-boxes-gradio/app.py
@@ -1,45 +1,73 @@
+import uuid
+
import gradio as gr
from PIL import Image
from marie.boxes import BoxProcessorUlimDit, PSMode
from marie.boxes.dit.ulim_dit_box_processor import visualize_bboxes
-
-box = BoxProcessorUlimDit(
- models_dir="../../model_zoo/unilm/dit/text_detection",
- cuda=True,
-)
-
-sel_bbox_optimization = False
-sel_content_aware = False
-
-
-def update_content_aware(value):
- global sel_content_aware
- sel_content_aware = value
-
-
-def update_bbox_optimzation(value):
- global sel_bbox_optimization
- sel_bbox_optimization = value
-
-
-def process_image(image):
- print("Processing image")
- print(f"Content Aware: {sel_content_aware}")
- print(f"BBox Optimization: {sel_bbox_optimization}")
-
- (boxes, fragments, lines, _, lines_bboxes,) = box.extract_bounding_boxes(
- "gradio",
- "field",
- image,
- PSMode.SPARSE,
- sel_content_aware,
- sel_bbox_optimization,
- )
-
- bboxes_img = visualize_bboxes(image, boxes, format="xywh")
- lines_img = visualize_bboxes(image, lines_bboxes, format="xywh")
- return bboxes_img, lines_img
+from marie.models.utils import setup_torch_optimizations
+from marie.utils.json import store_json_object
+
+
+class GradioBoxProcessor:
+ def __init__(self):
+ self.sel_content_aware = False
+ self.sel_bbox_optimization = False
+ self.sel_bbox_refinement = True
+
+ self.box = BoxProcessorUlimDit(
+ models_dir="../../model_zoo/unilm/dit/text_detection",
+ cuda=True,
+ )
+
+ def update_bbox_refinement(self, value):
+ self.sel_bbox_refinement = value
+
+ def update_content_aware(self, value):
+ self.sel_content_aware = value
+
+ def update_bbox_optimization(self, value):
+ self.sel_bbox_optimization = value
+
+ def process_image(self, image):
+ print("Processing image")
+ print(f"Content Aware: {self.sel_content_aware}")
+ print(f"BBox Optimization: {self.sel_bbox_optimization}")
+
+ (
+ boxes,
+ fragments,
+ lines,
+ _,
+ lines_bboxes,
+ ) = self.box.extract_bounding_boxes(
+ "gradio",
+ "field",
+ image,
+ PSMode.SPARSE,
+ self.sel_content_aware,
+ self.sel_bbox_optimization,
+ self.sel_bbox_refinement,
+ )
+
+ bboxes_img = visualize_bboxes(
+ image,
+ boxes,
+ format="xywh",
+ # blackout_color=(100, 100, 200, 100),
+ # blackout=True,
+ )
+
+ name = str(uuid.uuid4())
+ store_json_object(boxes, f"/tmp/boxes/boxes-{name}.json")
+ lines_img = visualize_bboxes(image, lines_bboxes, format="xywh")
+
+ return bboxes_img, lines_img, len(boxes), len(lines_bboxes)
+
+
+processor = GradioBoxProcessor()
+processor.update_content_aware(True)
+processor.update_bbox_optimization(True)
def interface():
@@ -53,52 +81,66 @@ def interface():
with gr.Row():
with gr.Column():
- src = gr.Image(type="pil", source="upload")
+ src = gr.Image(type="pil", sources=["upload"])
with gr.Column():
chk_apply_bbox_optimization = gr.Checkbox(
label="Bounding Box optimization",
- default=True,
interactive=True,
)
chk_apply_bbox_optimization.change(
- lambda x: update_bbox_optimzation(x),
+ lambda x: processor.update_bbox_optimization(x),
inputs=[chk_apply_bbox_optimization],
outputs=[],
)
chk_apply_content_aware = gr.Checkbox(
label="Content aware transformation",
- default=True,
+ interactive=True,
+ )
+
+ chk_apply_bbox_refinement = gr.Checkbox(
+ label="Bounding box refinement",
interactive=True,
)
chk_apply_content_aware.change(
- lambda x: update_content_aware(x),
+ lambda x: processor.update_content_aware(x),
inputs=[chk_apply_content_aware],
outputs=[],
)
+ chk_apply_bbox_refinement.change(
+ lambda x: processor.update_bbox_refinement(x),
+ inputs=[chk_apply_bbox_refinement],
+ outputs=[],
+ )
+
with gr.Row():
btn_reset = gr.Button("Clear")
btn_submit = gr.Button("Submit", variant="primary")
+ with gr.Row():
+ with gr.Column():
+ txt_bboxes = gr.components.Textbox(label="Bounding Boxes", value="0")
+ with gr.Column():
+ txt_lines = gr.components.Textbox(label="Lines", value="0")
+
with gr.Row():
with gr.Column():
boxes = gr.components.Image(type="pil", label="boxes")
with gr.Column():
lines = gr.components.Image(type="pil", label="lines")
- btn_submit.click(process_image, inputs=[src], outputs=[boxes, lines])
+ btn_submit.click(
+ processor.process_image,
+ inputs=[src],
+ outputs=[boxes, lines, txt_bboxes, txt_lines],
+ )
iface.launch(debug=True, share=True, server_name="0.0.0.0")
if __name__ == "__main__":
- import torch
-
- torch.set_float32_matmul_precision('high')
- torch.backends.cudnn.benchmark = False
- # torch._dynamo.config.suppress_errors = False
-
+ setup_torch_optimizations()
interface()
diff --git a/workspaces/ocr-diff/app.py b/workspaces/ocr-diff/app.py
new file mode 100644
index 00000000..78dd31a9
--- /dev/null
+++ b/workspaces/ocr-diff/app.py
@@ -0,0 +1,288 @@
+import base64
+import difflib
+import io
+import json
+import logging
+import os
+from io import BytesIO
+
+import cv2
+import pandas as pd
+import streamlit as st
+import streamlit_shortcuts
+from canvas_util import ImageUtils
+from PIL import Image
+from streamlit import session_state as ss
+from streamlit_drawable_canvas import st_canvas
+
+from marie.utils.json import load_json_file, store_json_object
+
+src_dir = os.path.expanduser("~/tmp/ocr-diffs/v5/ocr1")
+output_dir = os.path.expanduser("~/tmp/ocr-diffs/json_data")
+
+os.makedirs(output_dir, exist_ok=True)
+
+json_files = [
+ os.path.join(src_dir, f) for f in os.listdir(src_dir) if f.endswith(".json")
+]
+json_files.sort()
+
+json_files = [
+ os.path.join(src_dir, f)
+ for f in os.listdir(src_dir)
+ if f.endswith(".json")
+ and not os.path.exists(os.path.join(output_dir, os.path.splitext(f)[0] + ".txt"))
+]
+json_files.sort()
+
+
+def load_json(file_index):
+ # Load the JSON file
+ with open(json_files[file_index]) as f:
+ data = json.load(f)
+ return data
+
+
+def display_image(base64_string, scale=1):
+ decoded_image = base64.b64decode(base64_string)
+ image = Image.open(io.BytesIO(decoded_image))
+ image = image.resize((int(image.width * scale), int(image.height * scale)))
+ st.image(image)
+
+
+def save_image_and_text(image_data, text, base_name):
+ # Save the image and text to output_dir
+ image = Image.open(io.BytesIO(base64.b64decode(image_data)))
+ image.save(os.path.join(output_dir, f"{base_name}.png"))
+ with open(os.path.join(output_dir, f"{base_name}.txt"), "w") as f:
+ f.write(text)
+
+
+def accept_word(snippet, word, base_name):
+ def callback():
+ # st.write(f"You accepted {word}")
+ st.session_state.selected_word = word
+ save_image_and_text(snippet, word, base_name)
+ st.session_state.current_file_index += 1
+
+ return callback
+
+
+def get_canvas(self, resized_image, key="canvas", update_streamlit=True, mode="rect"):
+ """Retrieves the canvas to receive the bounding boxes
+ Args:
+ resized_image(Image.Image): the resized uploaded image
+ key(str): the key to initiate the canvas component in streamlit
+ """
+ width, height = resized_image.size
+
+ canvas_result = st_canvas(
+ # fill_color="rgba(255,0, 0, 0.1)",
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
+ stroke_width=2,
+ stroke_color="rgba(255,0,0,1)",
+ background_color="rgba(0,0,0,1)",
+ background_image=resized_image,
+ update_streamlit=update_streamlit,
+ height=960,
+ width=960,
+ drawing_mode=mode,
+ key=key,
+ )
+ return canvas_result
+
+
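+# Build the Fabric.js-style initial drawing expected by st_canvas: one "rect" object per
+# (x, y, width, height) box, styled to match the canvas stroke/fill settings.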
+def create_initial_drawing(bboxes):
+ template = {
+ "type": "rect",
+ "version": "4.4.0",
+ "originX": "left",
+ "originY": "top",
+ "left": 0,
+ "top": 0,
+ "width": 0,
+ "height": 0,
+ "fill": "rgba(255, 165, 0, 0.3)",
+ "stroke": "rgba(255,0,0,1)",
+ "strokeWidth": 2,
+ "strokeLineCap": "butt",
+ "strokeDashOffset": 0,
+ "strokeLineJoin": "miter",
+ "strokeMiterLimit": 4,
+ "scaleX": 1,
+ "scaleY": 1,
+ "angle": 0,
+ "opacity": 1,
+ "backgroundColor": "",
+ "fillRule": "nonzero",
+ "paintFirst": "fill",
+ "globalCompositeOperation": "source-over",
+ "skewX": 0,
+ "skewY": 0,
+ "rx": 0,
+ "ry": 0,
+ }
+ objects = []
+ for bbox in bboxes:
+ obj = template.copy()
+ obj["left"] = bbox[0]
+ obj["top"] = bbox[1]
+ obj["width"] = bbox[2]
+ obj["height"] = bbox[3]
+ objects.append(obj)
+
+ return {
+ "version": "4.4.0",
+ "objects": objects,
+ }
+
+
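+# Each diff entry stores word1's box as (x, y, width, height); collect them to seed the canvas.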
+def get_diff_boxes(diffs):
+ bboxes = []
+ for diff in diffs:
+ word1 = diff["word1"] # xywh
+ box = word1["box"]
+ bboxes.append(box)
+
+ print(bboxes)
+ return bboxes
+
+
+def main():
+ st.set_page_config(
+ page_title="OCR-DIFF",
+ page_icon="🧊",
+ layout="wide",
+ initial_sidebar_state="expanded",
+ )
+
+ if "current_file_index" not in st.session_state:
+ st.session_state.current_file_index = 0
+ if "selected_word" not in st.session_state:
+ st.session_state.selected_word = ""
+
+ uploaded_image = st.sidebar.file_uploader(
+ "Upload template source: ",
+ type=["jpg", "jpeg", "png", "webp", "tiff", "tif"],
+ key="source",
+ )
+
+ utils = ImageUtils()
+
+ logging.basicConfig(
+ level=logging.ERROR,
+        format="%(asctime)s %(levelname)s %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ st.sidebar.write("Matching parameters")
+ realtime_update = st.sidebar.checkbox("Update in realtime", True)
+ submit = st.sidebar.button("Submit")
+
+ # Convert bounding boxes to the format required by st_canvas
+ # initial_drawing = load_json_file(os.path.join(output_dir, "canvas.json"))
+ word_diffs = load_json_file(os.path.join(output_dir, "diff.json"))
+
+ raw_boxes = get_diff_boxes(word_diffs)
+ initial_drawing = create_initial_drawing(raw_boxes)
+
+ if "selected_row_index" not in ss:
+ ss.selected_row_index = None
+
+ # for bbox in bboxes
+ canvas_result = None
+ col1, col2 = st.columns([2, 1])
+ with st.container(border=True):
+ with col1:
+ # st.header("Source image")
+ if uploaded_image is not None:
+ raw_image, raw_size = utils.read_image(uploaded_image, False, 0.75)
+ resized_image, resized_size = utils.resize_image_for_canvas(raw_image)
+ # read bbox input
+ canvas_result = utils.get_canvas(
+ resized_image,
+ key="canvas-source",
+ update_streamlit=True,
+ mode="rect",
+ initial_drawing=initial_drawing,
+ )
+
+ with col2:
+ st.write("Diff-Boxes")
+
+ def image_to_base64(img) -> str:
+ with BytesIO() as buffer:
+ img.save(buffer, "png") # or 'jpeg'
+ return base64.b64encode(buffer.getvalue()).decode()
+
+ def image_formatter(img) -> str:
+ return f"data:image/png;base64,{image_to_base64(img)}"
+
+ # max_width = max([box[2] - box[0] for box in raw_boxes])
+ max_width = max([box[2] for box in raw_boxes])
+ column_configuration = {
+ "Select": st.column_config.CheckboxColumn(),
+ "snippet": st.column_config.ImageColumn(
+ "snippet", help="Document snippet", width=max_width
+ ),
+ }
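+            # One table row per word diff: the snippet image plus both engines' text and confidence.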
+ rows = []
+ for word_diff in word_diffs:
+ snippet = word_diff["snippet"]
+ txt_1 = word_diff["word1"]["text"]
+ txt_2 = word_diff["word2"]["text"]
+ conf_1 = word_diff["confidence1"]
+ conf_2 = word_diff["confidence2"]
+
+ rows.append(
+ {
+ "snippet": f"data:image/png;base64,{snippet}",
+ "engine-1": txt_1,
+ "engine-2": txt_2,
+ "conf-1": conf_1,
+ "conf-2": conf_2,
+ }
+ )
+
+ df = pd.DataFrame(rows)
+ # Add a new column for radio buttons
+ # df["Select"] = [False] * len(df)
+
+ # def add_click_events(df):
+ # df_temp = df.copy()
+ # df_temp["Select"] = False # Temporary column
+ # edited_df = st.data_editor(
+ # df_temp,
+ # column_config=column_configuration,
+ # disabled=df.columns,
+ # )
+ # selected_rows = edited_df[edited_df["Select"]].drop("Select", axis=1)
+ # return selected_rows
+ #
+ # selection = add_click_events(df)
+ # st.write("Selected rows:", selection)
+
+ # https://github.com/streamlit/streamlit/pull/8411
+
+ st.dataframe(
+ df,
+ use_container_width=True,
+ column_config=column_configuration,
+ # on_select="rerun",
+ # selection_mode="single-row",
+ )
+
+            if ss.selected_row_index is not None:
+                selected_row = df.iloc[ss.selected_row_index]
+                st.text_input("Selected row:", str(selected_row))
+
+ if submit:
+ st.write("Submitted")
+ if canvas_result is not None:
+ st.write(canvas_result.json_data)
+ store_json_object(
+ canvas_result.json_data, os.path.join(output_dir, "canvas.json")
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/workspaces/ocr-diff/canvas_util.py b/workspaces/ocr-diff/canvas_util.py
new file mode 100644
index 00000000..bcd42189
--- /dev/null
+++ b/workspaces/ocr-diff/canvas_util.py
@@ -0,0 +1,96 @@
+from PIL import Image
+from streamlit_drawable_canvas import st_canvas
+
+from marie.utils.resize_image import resize_image_progressive
+
+
+class ImageUtils:
+ def __init__(self):
+ pass
+
+ def read_image(self, uploaded_image, progressive_rescale=False, scale=75):
+ """Read the uploaded image"""
+ raw_image = Image.open(uploaded_image).convert("RGB")
+ print("progressive_rescale", progressive_rescale, scale)
+ if progressive_rescale:
+ raw_image = resize_image_progressive(
+ raw_image,
+ reduction_percent=scale / 100,
+ reductions=2,
+ return_intermediate_states=False,
+ )
+ width, height = raw_image.size
+ return raw_image, (width, height)
+
+ def resize_image_for_canvas(self, raw_image, square=960):
+        """Resize the image to a square canvas (default 960x960); aspect ratio is not preserved"""
+ width, height = raw_image.size
+ # return raw_image.resize((width, height)), (width, height)
+
+ return raw_image.resize((square, square)), (square, square)
+
+ def get_canvas(
+ self,
+ resized_image,
+ key="canvas",
+ update_streamlit=True,
+ mode="rect",
+ initial_drawing=None,
+ ):
+ """Retrieves the canvas to receive the bounding boxes
+ Args:
+ resized_image(Image.Image): the resized uploaded image
+ key(str): the key to initiate the canvas component in streamlit
+ """
+ width, height = resized_image.size
+
+ canvas_result = st_canvas(
+ # fill_color="rgba(255,0, 0, 0.1)",
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
+ stroke_width=2,
+ stroke_color="rgba(255,0,0,1)",
+ background_color="rgba(0,0,0,1)",
+ # background_color="",
+ background_image=resized_image,
+ update_streamlit=update_streamlit,
+ height=960,
+ width=960,
+ drawing_mode=mode,
+ key=key,
+ initial_drawing=initial_drawing,
+ )
+ return canvas_result
+
+ def get_resized_boxes(self, canvas_result):
+ """Get the resized boxes from the canvas result"""
+ objects = canvas_result.json_data["objects"]
+ resized_boxes = []
+ for obj in objects:
+ print("object", obj)
+ left, top = int(obj["left"]), int(obj["top"]) # upper left corner
+ width, height = int(obj["width"]), int(
+ obj["height"]
+ ) # box width and height
+ right, bottom = left + width, top + height # lower right corner
+ resized_boxes.append([left, top, right, bottom])
+ return resized_boxes
+
+ def get_raw_boxes(self, resized_boxes, raw_size, resized_size):
+ """Convert the resized boxes to raw boxes"""
+ print("raw_size", raw_size)
+ print("resized_size", resized_size)
+ print("resized_boxes", resized_boxes)
+
+ raw_width, raw_height = raw_size
+ resized_width, resized_height = resized_size
+ raw_boxes = []
+
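+        # Scale each corner from canvas (resized) coordinates back to the original image size.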
+ for box in resized_boxes:
+ left, top, right, bottom = box
+ raw_left = int(left * raw_width / resized_width)
+ raw_top = int(top * raw_height / resized_height)
+ raw_right = int(right * raw_width / resized_width)
+ raw_bottom = int(bottom * raw_height / resized_height)
+ box2 = [raw_left, raw_top, raw_right, raw_bottom]
+ raw_boxes.append(box2)
+ return raw_boxes
diff --git a/workspaces/ocr-diff/run.sh b/workspaces/ocr-diff/run.sh
new file mode 100755
index 00000000..44a5ddd7
--- /dev/null
+++ b/workspaces/ocr-diff/run.sh
@@ -0,0 +1 @@
+streamlit run app.py --server.fileWatcherType none --server.enableXsrfProtection=false
\ No newline at end of file
diff --git a/workspaces/ocr-review/app.py b/workspaces/ocr-review/app.py
index 201faec1..7bc3b2ce 100644
--- a/workspaces/ocr-review/app.py
+++ b/workspaces/ocr-review/app.py
@@ -11,8 +11,8 @@
from rich.console import Console
from rich.theme import Theme
-src_dir = os.path.expanduser("~/tmp/ocr-diffs/v3/ocr2")
-output_dir = os.path.expanduser("~/tmp/ocr-diffs/reviewed-v3")
+src_dir = os.path.expanduser("~/tmp/ocr-diffs/v5/ocr1")
+output_dir = os.path.expanduser("~/tmp/ocr-diffs/reviewed-v5")
os.makedirs(output_dir, exist_ok=True)
@@ -52,6 +52,16 @@ def save_image_and_text(image_data, text, base_name):
f.write(text)
+def accept_word(snippet, word, base_name):
+ def callback():
+ # st.write(f"You accepted {word}")
+ st.session_state.selected_word = word
+ save_image_and_text(snippet, word, base_name)
+ st.session_state.current_file_index += 1
+
+ return callback
+
+
def main():
if "current_file_index" not in st.session_state:
st.session_state.current_file_index = 0
@@ -83,42 +93,28 @@ def on_prev():
col1, col2 = st.columns(2)
with col1:
- # if st.button("Previous"):
- # on_prev()
streamlit_shortcuts.button(
- "Previous", on_click=on_prev, shortcut="Shift+ArrowUp"
+ "Previous", on_click=on_prev, shortcut="Shift+Ctrl+ArrowUp"
)
with col2:
- # if st.button("Next"):
- # on_next()
- streamlit_shortcuts.button("Next", on_click=on_next, shortcut="Shift+ArrowDown")
- # add a separator
+ streamlit_shortcuts.button(
+ "Next", on_click=on_next, shortcut="Shift+Ctrl+ArrowDown"
+ )
st.markdown("---")
- # Create three columns
- col1, col2, col3 = st.columns(3)
-
- def accept_word(snippet, word, base_name):
- def callback():
- # st.write(f"You accepted {word}")
- st.session_state.selected_word = word
- save_image_and_text(snippet, word, base_name)
- st.session_state.current_file_index += 1
+ st.text(data["word1"]["text"])
+ st.text(data["word2"]["text"])
- return callback
+ col1, col2, col3 = st.columns(3)
with col1:
word_1 = st.text_input("Word 1", data["word1"]["text"], key="Word1")
streamlit_shortcuts.button(
"Accept Word 1",
on_click=accept_word(data["snippet"], word_1, base_name),
- shortcut="Shift+ArrowLeft",
+ shortcut="Shift+Ctrl+ArrowLeft",
)
- # if st.button("X"):
- # st.write("You accepted Word 1:", word_1)
- # save_image_and_text(data["snippet"], word_1, base_name)
- # st.session_state.current_file_index += 1
with col2:
word_2 = st.text_input("Word 2", data["word2"]["text"], key="Word2")
@@ -126,32 +122,23 @@ def callback():
streamlit_shortcuts.button(
"Accept Word 2",
on_click=accept_word(data["snippet"], word_2, base_name),
- shortcut="Shift+ArrowRight",
+ shortcut="Shift+Ctrl+ArrowRight",
)
- # if st.button("Accept Word 2"):
- # st.write("You accepted Word 2:", word_2)
- # save_image_and_text(data["snippet"], word_2, base_name)
- # st.session_state.current_file_index += 1
-
st.write("Last selected word:", st.session_state.selected_word)
# Add a legend for the shortcut keys
st.markdown(
"""
## Shortcut Keys
- - **Shift + ArrowUp**: Previous
- - **Shift + ArrowDown**: Next
- - **Shift + ArrowLeft**: Accept Word 1
- - **Shift + ArrowRight**: Accept Word 2
+ - **Shift + Ctrl + ArrowUp**: Previous
+ - **Shift + Ctrl + ArrowDown**: Next
+ - **Shift + Ctrl + ArrowLeft**: Accept Word 1
+ - **Shift + Ctrl + ArrowRight**: Accept Word 2
"""
)
with col3:
- st.text(word_1)
- st.text(word_2)
-
- # Create a diff line
diff = difflib.ndiff(word_1, word_2)
diff_line = "\n".join(diff)
st.text("Diff:")
diff --git a/workspaces/template-matching/app-g.py b/workspaces/template-matching/app-g.py
deleted file mode 100644
index 096671b5..00000000
--- a/workspaces/template-matching/app-g.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import gradio as gr
-import torch as torch
-
-use_cuda = torch.cuda.is_available()
-
-# https://github.com/gradio-app/gradio/issues/2316
-ROI_coordinates = {
- 'x_temp': 0,
- 'y_temp': 0,
- 'x_new': 0,
- 'y_new': 0,
- 'clicks': 0,
-}
-
-
-def get_select_coordinates(img, evt: gr.SelectData):
- sections = []
- # update new coordinates
- ROI_coordinates['clicks'] += 1
- ROI_coordinates['x_temp'] = ROI_coordinates['x_new']
- ROI_coordinates['y_temp'] = ROI_coordinates['y_new']
- ROI_coordinates['x_new'] = evt.index[0]
- ROI_coordinates['y_new'] = evt.index[1]
- # compare start end coordinates
- x_start = (
- ROI_coordinates['x_new']
- if (ROI_coordinates['x_new'] < ROI_coordinates['x_temp'])
- else ROI_coordinates['x_temp']
- )
- y_start = (
- ROI_coordinates['y_new']
- if (ROI_coordinates['y_new'] < ROI_coordinates['y_temp'])
- else ROI_coordinates['y_temp']
- )
- x_end = (
- ROI_coordinates['x_new']
- if (ROI_coordinates['x_new'] > ROI_coordinates['x_temp'])
- else ROI_coordinates['x_temp']
- )
- y_end = (
- ROI_coordinates['y_new']
- if (ROI_coordinates['y_new'] > ROI_coordinates['y_temp'])
- else ROI_coordinates['y_temp']
- )
- if ROI_coordinates['clicks'] % 2 == 0:
- # both start and end point get
- sections.append(((x_start, y_start, x_end, y_end), "ROI of Face Detection"))
- return (img, sections)
- else:
- point_width = int(img.shape[0] * 0.05)
- sections.append(
- (
- (
- ROI_coordinates['x_new'],
- ROI_coordinates['y_new'],
- ROI_coordinates['x_new'] + point_width,
- ROI_coordinates['y_new'] + point_width,
- ),
- "Click second point for ROI",
- )
- )
- return (img, sections)
-
-
-def process_image(image):
- return image
-
-
-def interface():
- article = """
- # Zero-shot Template Matching
- """
-
- with gr.Blocks() as iface:
- gr.Markdown(article)
-
- i = gr.Image(source="canvas", shape=(512, 512)).style(width=512, height=512)
- o = gr.Image().style(width=512, height=512)
-
- with gr.Row():
- input_img = gr.Image(label="Click")
- img_output = gr.AnnotatedImage(
- label="ROI",
- color_map={
- "ROI of Face Detection": "#9987FF",
- "Click second point for ROI": "#f44336",
- },
- )
- input_img.select(get_select_coordinates, input_img, img_output)
-
- if False:
- with gr.Row():
- src = gr.Image(type="pil", source="upload")
-
- with gr.Row():
- btn_reset = gr.Button("Clear")
- btn_submit = gr.Button("Submit", variant="primary")
-
- with gr.Row():
- with gr.Column():
- boxes = gr.components.Image(type="pil", label="boxes")
- with gr.Column():
- lines = gr.components.Image(type="pil", label="icr")
-
- with gr.Row():
- with gr.Column():
- txt = gr.components.Textbox(label="text", max_lines=100)
-
- with gr.Row():
- with gr.Column():
- results = gr.components.JSON()
-
- btn_submit.click(
- process_image, inputs=[src], outputs=[boxes, lines, results, txt]
- )
-
- iface.launch(debug=True, share=False, server_name="0.0.0.0")
-
-
-if __name__ == "__main__":
- import torch
-
- torch.set_float32_matmul_precision("high")
- torch.backends.cudnn.benchmark = False
- # torch._dynamo.config.suppress_errors = False
-
- interface()
diff --git a/workspaces/template-matching/app.py b/workspaces/template-matching/app.py
index 9e4db8b8..a3ab61cd 100644
--- a/workspaces/template-matching/app.py
+++ b/workspaces/template-matching/app.py
@@ -8,6 +8,7 @@
import pandas as pd
import streamlit as st
import torch
+from canvas_util import ImageUtils
from PIL import Image
from streamlit_drawable_canvas import st_canvas
@@ -32,91 +33,6 @@ def get_template_matchers():
return matcher, matcher_meta, matcher_vqnnft
-class ImageUtils:
- def __init__(self):
- pass
-
- def read_image(self, uploaded_image, progressive_rescale=False, scale=75):
- """Read the uploaded image"""
- raw_image = Image.open(uploaded_image).convert("RGB")
- print("progressive_rescale", progressive_rescale, scale)
- if progressive_rescale:
- raw_image = resize_image_progressive(
- raw_image,
- reduction_percent=scale / 100,
- reductions=2,
- return_intermediate_states=False,
- )
- width, height = raw_image.size
- return raw_image, (width, height)
-
- def resize_image_for_canvas(self, raw_image, square=960):
- """Resize the mask so it fits inside a 544x544 square"""
- width, height = raw_image.size
- # return raw_image.resize((width, height)), (width, height)
-
- return raw_image.resize((square, square)), (square, square)
-
- def get_canvas(
- self, resized_image, key="canvas", update_streamlit=True, mode="rect"
- ):
- """Retrieves the canvas to receive the bounding boxes
- Args:
- resized_image(Image.Image): the resized uploaded image
- key(str): the key to initiate the canvas component in streamlit
- """
- width, height = resized_image.size
-
- canvas_result = st_canvas(
- # fill_color="rgba(255,0, 0, 0.1)",
- fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
- stroke_width=2,
- stroke_color="rgba(255,0,0,1)",
- background_color="rgba(0,0,0,1)",
- background_image=resized_image,
- update_streamlit=update_streamlit,
- height=960,
- width=960,
- drawing_mode=mode,
- key=key,
- )
- return canvas_result
-
- def get_resized_boxes(self, canvas_result):
- """Get the resized boxes from the canvas result"""
- objects = canvas_result.json_data["objects"]
- resized_boxes = []
- for obj in objects:
- print("object", obj)
- left, top = int(obj["left"]), int(obj["top"]) # upper left corner
- width, height = int(obj["width"]), int(
- obj["height"]
- ) # box width and height
- right, bottom = left + width, top + height # lower right corner
- resized_boxes.append([left, top, right, bottom])
- return resized_boxes
-
- def get_raw_boxes(self, resized_boxes, raw_size, resized_size):
- """Convert the resized boxes to raw boxes"""
- print("raw_size", raw_size)
- print("resized_size", resized_size)
- print("resized_boxes", resized_boxes)
-
- raw_width, raw_height = raw_size
- resized_width, resized_height = resized_size
- raw_boxes = []
-
- for box in resized_boxes:
- left, top, right, bottom = box
- raw_left = int(left * raw_width / resized_width)
- raw_top = int(top * raw_height / resized_height)
- raw_right = int(right * raw_width / resized_width)
- raw_bottom = int(bottom * raw_height / resized_height)
- box2 = [raw_left, raw_top, raw_right, raw_bottom]
- raw_boxes.append(box2)
- return raw_boxes
-
-
@st.cache_resource
def get_ocr_engine() -> OcrEngine:
"""Get the OCR engine"""
diff --git a/workspaces/template-matching/canvas_util.py b/workspaces/template-matching/canvas_util.py
new file mode 100644
index 00000000..576dfb09
--- /dev/null
+++ b/workspaces/template-matching/canvas_util.py
@@ -0,0 +1,89 @@
+from PIL import Image
+from streamlit_drawable_canvas import st_canvas
+
+from marie.utils.resize_image import resize_image_progressive
+
+
+class ImageUtils:
+ def __init__(self):
+ pass
+
+ def read_image(self, uploaded_image, progressive_rescale=False, scale=75):
+ """Read the uploaded image"""
+ raw_image = Image.open(uploaded_image).convert("RGB")
+ print("progressive_rescale", progressive_rescale, scale)
+ if progressive_rescale:
+ raw_image = resize_image_progressive(
+ raw_image,
+ reduction_percent=scale / 100,
+ reductions=2,
+ return_intermediate_states=False,
+ )
+ width, height = raw_image.size
+ return raw_image, (width, height)
+
+ def resize_image_for_canvas(self, raw_image, square=960):
+        """Resize the image to a square canvas (default 960x960); aspect ratio is not preserved"""
+ width, height = raw_image.size
+ # return raw_image.resize((width, height)), (width, height)
+
+ return raw_image.resize((square, square)), (square, square)
+
+ def get_canvas(
+ self, resized_image, key="canvas", update_streamlit=True, mode="rect"
+ ):
+ """Retrieves the canvas to receive the bounding boxes
+ Args:
+ resized_image(Image.Image): the resized uploaded image
+ key(str): the key to initiate the canvas component in streamlit
+ """
+ width, height = resized_image.size
+
+ canvas_result = st_canvas(
+ # fill_color="rgba(255,0, 0, 0.1)",
+ fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
+ stroke_width=2,
+ stroke_color="rgba(255,0,0,1)",
+ background_color="rgba(0,0,0,1)",
+ background_image=resized_image,
+ update_streamlit=update_streamlit,
+ height=960,
+ width=960,
+ drawing_mode=mode,
+ key=key,
+ )
+ return canvas_result
+
+ def get_resized_boxes(self, canvas_result):
+ """Get the resized boxes from the canvas result"""
+ objects = canvas_result.json_data["objects"]
+ resized_boxes = []
+ for obj in objects:
+ print("object", obj)
+ left, top = int(obj["left"]), int(obj["top"]) # upper left corner
+ width, height = int(obj["width"]), int(
+ obj["height"]
+ ) # box width and height
+ right, bottom = left + width, top + height # lower right corner
+ resized_boxes.append([left, top, right, bottom])
+ return resized_boxes
+
+ def get_raw_boxes(self, resized_boxes, raw_size, resized_size):
+ """Convert the resized boxes to raw boxes"""
+ print("raw_size", raw_size)
+ print("resized_size", resized_size)
+ print("resized_boxes", resized_boxes)
+
+ raw_width, raw_height = raw_size
+ resized_width, resized_height = resized_size
+ raw_boxes = []
+
+ for box in resized_boxes:
+ left, top, right, bottom = box
+ raw_left = int(left * raw_width / resized_width)
+ raw_top = int(top * raw_height / resized_height)
+ raw_right = int(right * raw_width / resized_width)
+ raw_bottom = int(bottom * raw_height / resized_height)
+ box2 = [raw_left, raw_top, raw_right, raw_bottom]
+ raw_boxes.append(box2)
+ return raw_boxes