Add video segmentation feature #784

Open

wants to merge 1 commit into main
20 changes: 20 additions & 0 deletions README.md
@@ -84,6 +84,26 @@ See the examples notebooks on [using SAM with prompts](/notebooks/predictor_example.ipynb)
<img src="assets/notebook2.png?raw=true" width="48.9%" />
</p>

## Video Segmentation

To use the new video segmentation feature, follow these steps:

1. Import the necessary modules and initialize the SAM model and predictor:

```python
from segment_anything import SamPredictor, sam_model_registry, segment_video
sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
```
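
Here `<model_type>` is one of `vit_h`, `vit_l`, or `vit_b` (`default` is an alias for `vit_h`) and must match the downloaded checkpoint.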

2. Call the `segment_video` function with the path to your video file and the predictor:

```python
segment_video("<path/to/video>", predictor)
```

`segment_video` reads the video frame by frame, segments objects in each frame with SAM, and displays the overlaid frames in an OpenCV window; press `q` to stop early.
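
For example, a minimal end-to-end run, assuming the ViT-B checkpoint `sam_vit_b_01ec64.pth` has been downloaded and an input video at a hypothetical path `videos/example.mp4`:

```python
from segment_anything import SamPredictor, sam_model_registry, segment_video

# Build the ViT-B SAM model from its checkpoint and wrap it in a predictor.
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
predictor = SamPredictor(sam)

# Segment every frame of the video and display the overlaid results.
segment_video("videos/example.mp4", predictor)
```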

## ONNX Export

SAM's lightweight mask decoder can be exported to ONNX format so that it can be run in any environment that supports ONNX runtime, such as in-browser as showcased in the [demo](https://segment-anything.com/demo). Export the model with
Expand Down
18 changes: 12 additions & 6 deletions demo/src/App.tsx
@@ -1,9 +1,3 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import { InferenceSession, Tensor } from "onnxruntime-web";
import React, { useContext, useEffect, useState } from "react";
import "./assets/scss/App.scss";
@@ -13,6 +7,7 @@ import { onnxMaskToImage } from "./components/helpers/maskUtils";
import { modelData } from "./components/helpers/onnxModelAPI";
import Stage from "./components/Stage";
import AppContext from "./components/hooks/createContext";
import { segment_video } from "segment_anything/predictor"; // assumes a JS binding that exposes the Python helper
const ort = require("onnxruntime-web");
/* @ts-ignore */
import npyjs from "npyjs";
@@ -30,6 +25,7 @@ const App = () => {
} = useContext(AppContext)!;
const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
const [isVideo, setIsVideo] = useState<boolean>(false); // Whether the current input is a video

// The ONNX model expects the input to be rescaled to 1024.
// The modelScale state variable keeps track of the scale values.
@@ -124,6 +120,16 @@ const App = () => {
}
};

// Switch into video mode, wrap the uploaded file in an object URL, and
// hand the <video> element to the segmentation routine once its first
// frame has loaded.
const handleVideoUpload = async (videoFile: File) => {
  setIsVideo(true);
  const videoURL = URL.createObjectURL(videoFile);
  const videoElement = document.createElement("video");
  videoElement.src = videoURL;
  videoElement.onloadeddata = () => {
    segment_video(videoElement, model);
  };
};

return <Stage />;
};

58 changes: 32 additions & 26 deletions demo/src/components/Tool.tsx
@@ -1,23 +1,19 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.

// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

import React, { useContext, useEffect, useState } from "react";
import AppContext from "./hooks/createContext";
import { ToolProps } from "./helpers/Interfaces";
import * as _ from "underscore";

const Tool = ({ handleMouseMove }: ToolProps) => {
// Props for Tool, extended with a flag indicating video input.
interface ToolProps {
  handleMouseMove: (e: any) => void;
  isVideo: boolean;
}

const Tool = ({ handleMouseMove, isVideo }: ToolProps) => {
const {
image: [image],
maskImg: [maskImg, setMaskImg],
} = useContext(AppContext)!;

// Determine if we should shrink or grow the images to match the
// width or the height of the page and setup a ResizeObserver to
// monitor changes in the size of the page
const [shouldFitToWidth, setShouldFitToWidth] = useState(true);
const bodyEl = document.body;
const fitToPage = () => {
@@ -44,27 +40,37 @@ const Tool = ({ handleMouseMove }: ToolProps) => {
const imageClasses = "";
const maskImageClasses = `absolute opacity-40 pointer-events-none`;

// Render the uploaded video, or the image with the predicted mask on top
return (
  <>
    {isVideo ? (
      <video
        onMouseMove={handleMouseMove}
        onMouseOut={() => _.defer(() => setMaskImg(null))}
        onTouchStart={handleMouseMove}
        src={image?.src}
        className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
        autoPlay
        loop
        muted
      ></video>
    ) : (
      <>
        {image && (
          <img
            onMouseMove={handleMouseMove}
            onMouseOut={() => _.defer(() => setMaskImg(null))}
            onTouchStart={handleMouseMove}
            src={image.src}
            className={`${shouldFitToWidth ? "w-full" : "h-full"} ${imageClasses}`}
          ></img>
        )}
        {maskImg && (
          <img
            src={maskImg.src}
            className={`${shouldFitToWidth ? "w-full" : "h-full"} ${maskImageClasses}`}
          ></img>
        )}
      </>
    )}
  </>
);
32 changes: 32 additions & 0 deletions segment_anything/predictor.py
@@ -6,6 +6,7 @@

import numpy as np
import torch
import cv2

from segment_anything.modeling import Sam

@@ -267,3 +268,34 @@ def reset_image(self) -> None:
self.orig_w = None
self.input_h = None
self.input_w = None


def segment_video(video_path: str, sam_predictor: SamPredictor) -> None:
    """Segment each frame of a video with SAM and display the result.

    Press 'q' in the display window to stop early.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video {video_path}")
        return

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; SAM expects RGB input.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Compute the image embedding for this frame (the slow step;
        # it is recomputed from scratch on every frame).
        sam_predictor.set_image(frame_rgb)

        # Example prompt: one fixed foreground point (label 1) and one
        # fixed background point (label 0), in (x, y) pixel coordinates.
        point_coords = np.array([[100, 100], [200, 200]])
        point_labels = np.array([1, 0])
        masks, _, _ = sam_predictor.predict(
            point_coords=point_coords, point_labels=point_labels
        )

        # Overlay every returned mask on the BGR frame in solid green.
        # `predict` returns several candidate masks by default.
        for mask in masks:
            frame[mask > 0] = [0, 255, 0]

        cv2.imshow("Segmented Video", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    cap.release()
    cv2.destroyAllWindows()
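
For headless environments without a display, a minimal sketch of a variant that writes the overlaid frames to a file instead of calling `cv2.imshow`. The function name `segment_video_to_file`, the `mp4v` codec, and the default output path are illustrative assumptions, not part of the patch above:

```python
import cv2
import numpy as np

from segment_anything import SamPredictor


def segment_video_to_file(
    video_path: str, sam_predictor: SamPredictor, out_path: str = "segmented.mp4"
) -> None:
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video {video_path}")

    # Mirror the input video's frame rate and size in the output file
    # (assume 30 fps if the container does not report a rate).
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(
        out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Same per-frame prediction as segment_video above.
        sam_predictor.set_image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        masks, _, _ = sam_predictor.predict(
            point_coords=np.array([[100, 100], [200, 200]]),
            point_labels=np.array([1, 0]),
        )
        for mask in masks:
            frame[mask > 0] = [0, 255, 0]  # same green overlay as above

        writer.write(frame)

    cap.release()
    writer.release()
```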