Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add stability-ai image generation in workflow #933

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions inference/core/workflows/core_steps/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@
from inference.core.workflows.core_steps.models.foundation.segment_anything2.v1 import (
SegmentAnything2BlockV1,
)
from inference.core.workflows.core_steps.models.foundation.stability_ai.image_gen.v1 import (
StabilityAIImageGenBlockV1,
)
from inference.core.workflows.core_steps.models.foundation.stability_ai.inpainting.v1 import (
StabilityAIInpaintingBlockV1,
)
Expand Down Expand Up @@ -572,6 +575,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
SIFTComparisonBlockV2,
SegmentAnything2BlockV1,
StabilityAIInpaintingBlockV1,
StabilityAIImageGenBlockV1,
StabilizeTrackedDetectionsBlockV1,
StitchImagesBlockV1,
StitchOCRDetectionsBlockV1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import base64
import uuid
from typing import List, Literal, Optional, Type, Union

import cv2
import numpy as np
import requests
from pydantic import ConfigDict, Field

from inference.core.workflows.execution_engine.entities.base import (
ImageParentMetadata,
OutputDefinition,
WorkflowImageData,
)
from inference.core.workflows.execution_engine.entities.types import (
FLOAT_ZERO_TO_ONE_KIND,
IMAGE_KIND,
SECRET_KIND,
STRING_KIND,
Selector,
)
from inference.core.workflows.prototypes.block import (
BlockResult,
WorkflowBlock,
WorkflowBlockManifest,
)

# Long description rendered in the workflow builder UI for this block.
LONG_DESCRIPTION = """
The block wraps [Stability AI image generation API](https://platform.stability.ai/docs/api-reference#tag/Generate) and let users generate new images from text, or create variations of existing images.
"""

# One-line summary shown in block listings / search results.
SHORT_DESCRIPTION = (
    "generate new images from text, or create variations of existing images."
)

# Base URL of the Stability AI REST API.
API_HOST = "https://api.stability.ai"
# Maps a user-facing model name to its generation endpoint path on the API host.
ENDPOINT = {
    "ultra": "/v2beta/stable-image/generate/ultra",
    "core": "/v2beta/stable-image/generate/core",
    "sd3": "/v2beta/stable-image/generate/sd3",
}


class BlockManifest(WorkflowBlockManifest):
    """Manifest for the Stability AI image generation workflow block.

    Declares the block's identity, inputs (prompt, optional source image,
    model choice, API key) and its single image output.
    """

    model_config = ConfigDict(
        json_schema_extra={
            "name": "Stability AI Image Generation",
            "version": "v1",
            "short_description": SHORT_DESCRIPTION,
            "long_description": LONG_DESCRIPTION,
            "license": "Apache-2.0",
            "block_type": "model",
            "search_keywords": [
                "Stability AI",
                "stability.ai",
                "image variation",
                "image generation",
            ],
            "ui_manifest": {
                "section": "model",
                "icon": "far fa-palette",
            },
        }
    )
    type: Literal["roboflow_core/stability_ai_image_gen@v1"]
    # Optional: when provided, the block performs image-to-image generation.
    image: Selector(kind=[IMAGE_KIND]) = Field(
        description="The image to use as the starting point for the generation.",
        examples=["$inputs.image"],
        default=None,
    )
    # Only meaningful when `image` is provided (image-to-image mode).
    strength: Union[float, Selector(kind=[FLOAT_ZERO_TO_ONE_KIND])] = Field(
        description="controls how much influence the image parameter has on the generated image. A value of 0 would yield an image that is identical to the input. A value of 1 would be as if you passed in no image at all.",
        default=0.3,
        examples=[0.3, "$inputs.strength"],
    )
    # FIX: removed the duplicated Selector(kind=[STRING_KIND]) union member
    # that appeared twice in prompt/negative_prompt/model.
    prompt: Union[Selector(kind=[STRING_KIND]), str] = Field(
        description="Prompt to generate new images from text (what you wish to see)",
        examples=["my prompt", "$inputs.prompt"],
    )
    negative_prompt: Optional[Union[Selector(kind=[STRING_KIND]), str]] = Field(
        default=None,
        # FIX: examples were copy-pasted from `prompt`; use field-appropriate ones.
        description="Negative prompt to image generation model (what you do not wish to see)",
        examples=["blurry, low quality", "$inputs.negative_prompt"],
    )
    model: Optional[Union[Selector(kind=[STRING_KIND]), str]] = Field(
        default="core",
        description="choose one of {'core', 'ultra', 'sd3'}. Default 'core' ",
        # FIX: examples were copy-pasted from `prompt`; use field-appropriate ones.
        examples=["core", "$inputs.model"],
    )
    api_key: Union[Selector(kind=[STRING_KIND, SECRET_KIND]), str] = Field(
        description="Your Stability AI API key",
        examples=["xxx-xxx", "$inputs.stability_ai_api_key"],
        private=True,
    )

    @classmethod
    def describe_outputs(cls) -> List[OutputDefinition]:
        # Single output: the generated image.
        return [
            OutputDefinition(name="image", kind=[IMAGE_KIND]),
        ]

    @classmethod
    def get_execution_engine_compatibility(cls) -> Optional[str]:
        return ">=1.4.0,<2.0.0"


class StabilityAIImageGenBlockV1(WorkflowBlock):
    """Workflow block wrapping the Stability AI image generation REST API.

    Performs text-to-image generation, or image-to-image generation when an
    input image is supplied (influence controlled by ``strength``).
    """

    @classmethod
    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
        return BlockManifest

    def run(
        self,
        prompt: str,
        negative_prompt: Optional[str],
        model: Optional[str],
        api_key: str,
        image: Optional[WorkflowImageData],
        strength: float = 0.3,
    ) -> BlockResult:
        """Call the Stability AI generation endpoint and return the new image.

        Raises:
            RuntimeError: if the API responds with a non-200 status code.
        """
        request_data = {
            "prompt": prompt,
            "output_format": "jpeg",
        }
        # The endpoint expects multipart/form-data even for pure text-to-image;
        # a dummy file part keeps `requests` in multipart mode.
        files_to_send = {"none": ""}
        if image is not None:
            files_to_send = {
                "image": numpy_array_to_jpeg_bytes(image=image.numpy_image),
            }
            # `strength` only applies in image-to-image mode.
            request_data["strength"] = strength

        if negative_prompt is not None:
            request_data["negative_prompt"] = negative_prompt
        # Fall back to the default model for unrecognized names.
        if model not in ENDPOINT:
            model = "core"
        response = requests.post(
            f"{API_HOST}{ENDPOINT[model]}",
            headers={"authorization": f"Bearer {api_key}", "accept": "image/*"},
            files=files_to_send,
            data=request_data,
        )
        if response.status_code != 200:
            # FIX: error bodies are not guaranteed to be JSON — response.json()
            # would raise and mask the real API error; fall back to raw text.
            try:
                error_details = str(response.json())
            except ValueError:
                error_details = response.text
            raise RuntimeError(
                f"Request to StabilityAI API failed: {error_details}"
            )
        new_image_base64 = base64.b64encode(response.content).decode("utf-8")
        # FIX: uuid4 instead of uuid1 — uuid1 embeds the host MAC address and
        # timestamp, leaking machine-identifying data into workflow metadata.
        parent_metadata = ImageParentMetadata(parent_id=str(uuid.uuid4()))
        return {
            "image": WorkflowImageData(parent_metadata, base64_image=new_image_base64),
        }


def numpy_array_to_jpeg_bytes(
    image: np.ndarray,
) -> bytes:
    """Encode a numpy image (OpenCV BGR layout assumed — confirm) to JPEG bytes.

    Raises:
        ValueError: if OpenCV fails to encode the image.
    """
    # FIX: the original ignored the success flag returned by cv2.imencode
    # (silently returning garbage on failure) and wrapped the result in a
    # redundant np.array() call — imencode already returns an ndarray.
    success, encoded = cv2.imencode(".jpg", image)
    if not success:
        raise ValueError("Could not encode image to JPEG format.")
    return encoded.tobytes()