chore: Refactor Trajectory class for improved code readability and efficient multi-processing
KeplerC committed Aug 31, 2024
1 parent eccf7b2 commit e4913b1
Showing 1 changed file with 62 additions and 42 deletions.
104 changes: 62 additions & 42 deletions fog_x/trajectory.py
@@ -10,6 +10,7 @@
import h5py
import asyncio
from concurrent.futures import ThreadPoolExecutor
import sys

logger = logging.getLogger(__name__)

@@ -361,7 +362,7 @@ def _load_from_cache(self):

def _load_from_container(self):
"""
Load the container file with the entire VLA trajectory.
Load the container file with the entire VLA trajectory using multi-processing for image streams.
args:
save_to_cache: save the decoded data to the cache file
@@ -372,8 +373,11 @@ def _load_from_container(self):
Workflow:
- Get schema of the container file.
- Preallocate decoded streams.
- Decode frame by frame and store in the preallocated memory.
- Use multi-processing to decode image streams separately.
- Decode non-image streams in the main process.
- Combine results from all processes.
"""
import multiprocessing as mp

def _get_length_of_stream(container, stream):
"""
@@ -385,6 +389,25 @@ def _get_length_of_stream(container, stream):
length += 1
return length

def process_image_stream(stream, feature_name, feature_type, length, path, result_queue):
container = av.open(path, mode="r", format="matroska")
np_cache = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype)
feature_length = 0

for packet in container.demux([stream]):
frames = packet.decode()
for frame in frames:
if feature_type.dtype == "float32":
data = frame.to_ndarray(format="gray").reshape(feature_type.shape)
else:
data = frame.to_ndarray(format="rgb24").reshape(feature_type.shape)
np_cache[feature_length] = data
feature_length += 1

container.close()
result_queue.put((feature_name, np_cache[:feature_length]))
os._exit(0)

try:
container_to_get_length = av.open(self.path, mode="r", format="matroska")
except Exception as e:
@@ -398,11 +421,12 @@ def _get_length_of_stream(container, stream):
container = av.open(self.path, mode="r", format="matroska")
streams = container.streams


# Dictionary to store preallocated numpy arrays
np_cache = {}

# Preallocate memory for the streams in numpy arrays
# Prepare for multi-processing
image_streams = []
other_streams = []
for stream in streams:
feature_name = stream.metadata.get("FEATURE_NAME")
if feature_name is None:
@@ -412,54 +436,50 @@ def _get_length_of_stream(container, stream):
self.feature_name_to_stream[feature_name] = stream
self.feature_name_to_feature_type[feature_name] = feature_type

logger.debug(
f"Creating a cache for {feature_name} with shape {feature_type.shape}"
)

# Allocate numpy array with shape [None, X, Y, Z] where X, Y, Z are feature dimensions
if feature_type.dtype == "string":
np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=object)
if stream.codec_context.codec.name == "h264":
image_streams.append((stream, feature_name, feature_type))
else:
np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype)
other_streams.append((stream, feature_name, feature_type))
if feature_type.dtype == "string":
np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=object)
else:
np_cache[feature_name] = np.empty((length,) + feature_type.shape, dtype=feature_type.dtype)

# Process image streams with multi-processing
result_queue = mp.Queue()
processes = []
for stream, feature_name, feature_type in image_streams:
p = mp.Process(target=process_image_stream, args=(stream, feature_name, feature_type, length, self.path, result_queue))
processes.append(p)
p.start()


# Decode the frames and store them in the preallocated numpy memory
d_feature_length = {feature: 0 for feature in self.feature_name_to_stream}
for packet in container.demux(list(streams)):
# Process other streams in the main process
d_feature_length = {feature_name: 0 for _, feature_name, _ in other_streams}
for packet in container.demux([stream for stream, _, _ in other_streams]):
feature_name = packet.stream.metadata.get("FEATURE_NAME")
if feature_name is None:
logger.debug(f"Skipping stream without FEATURE_NAME: {packet.stream}")
continue
feature_type = FeatureType.from_str(packet.stream.metadata.get("FEATURE_TYPE"))

logger.debug(
f"Decoding {feature_name} with shape {feature_type.shape} and dtype {feature_type.dtype} with time {packet.dts}"
)

feature_codec = packet.stream.codec_context.codec.name
if feature_codec == "h264":
frames = packet.decode()
for frame in frames:
if feature_type.dtype == "float32":
data = frame.to_ndarray(format="gray").reshape(feature_type.shape)
else:
data = frame.to_ndarray(format="rgb24").reshape(feature_type.shape)

# Append data to the numpy array
np_cache[feature_name][d_feature_length[feature_name]] = data
d_feature_length[feature_name] += 1
packet_in_bytes = bytes(packet)
if packet_in_bytes:
data = pickle.loads(packet_in_bytes)
np_cache[feature_name][d_feature_length[feature_name]] = data
d_feature_length[feature_name] += 1
else:
packet_in_bytes = bytes(packet)
if packet_in_bytes:
# Decode the packet
data = pickle.loads(packet_in_bytes)

# Append data to the numpy array
np_cache[feature_name][d_feature_length[feature_name]] = data
d_feature_length[feature_name] += 1
else:
logger.debug(f"Skipping empty packet: {packet} for {feature_name}")
logger.debug(f"Length of the stream {feature_name} is {d_feature_length[feature_name]}")
logger.debug(f"Skipping empty packet: {packet} for {feature_name}")
container.close()
# Wait for all image-decoding worker processes to finish (join() blocks until each exits)
for p in processes:
p.join()

# Collect results from image processing
while not result_queue.empty():
feature_name, data = result_queue.get()
np_cache[feature_name] = data

return np_cache
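
The core of the change is a fan-out/fan-in pattern in _load_from_container: one worker process per H.264 image stream, each decoding into its own preallocated array and sending it back through a multiprocessing.Queue, while the main process decodes the pickled non-image packets. The sketch below is an illustrative, runnable reduction of that pattern and is not part of the commit; the stream names, shapes, and the decode_stream stand-in that replaces the PyAV decode loop are hypothetical.

import multiprocessing as mp
import numpy as np

def decode_stream(feature_name, length, shape, result_queue):
    # Stand-in for process_image_stream: fill a preallocated array
    # (zeros here instead of decoded frames) and hand it back to the parent.
    frames = np.zeros((length,) + shape, dtype=np.uint8)
    result_queue.put((feature_name, frames))

if __name__ == "__main__":
    # Hypothetical image streams: (feature name, frame count, frame shape).
    streams = [
        ("observation/image", 10, (64, 64, 3)),
        ("observation/wrist_image", 10, (64, 64, 3)),
    ]
    result_queue = mp.Queue()
    processes = []
    for name, length, shape in streams:
        p = mp.Process(target=decode_stream, args=(name, length, shape, result_queue))
        processes.append(p)
        p.start()

    # Drain the queue before joining: a child process cannot exit until the
    # data it put on the queue has been flushed, so collecting results first
    # avoids a potential deadlock when the arrays are large.
    np_cache = {}
    for _ in processes:
        name, data = result_queue.get()
        np_cache[name] = data
    for p in processes:
        p.join()

    print({name: arr.shape for name, arr in np_cache.items()})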


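For the per-stream decode itself, each worker reopens the container and runs the same demux/decode loop as process_image_stream. A minimal PyAV round trip illustrating that loop (encode a few synthetic frames into an in-memory Matroska stream, then demux and convert them back to numpy arrays) might look like the following sketch; it assumes a PyAV build with an H.264 encoder and is not tied to fog_x's on-disk layout.

import io
import av
import numpy as np

# Encode a few synthetic frames into an in-memory Matroska container.
buf = io.BytesIO()
out = av.open(buf, mode="w", format="matroska")
stream = out.add_stream("h264", rate=10)
stream.width, stream.height = 64, 64
stream.pix_fmt = "yuv420p"
for i in range(5):
    img = np.full((64, 64, 3), i * 40, dtype=np.uint8)
    frame = av.VideoFrame.from_ndarray(img, format="rgb24")
    for packet in stream.encode(frame):
        out.mux(packet)
for packet in stream.encode():  # flush the encoder
    out.mux(packet)
out.close()

# Demux and decode the frames back, as process_image_stream does per packet.
buf.seek(0)
container = av.open(buf, mode="r", format="matroska")
decoded = []
for packet in container.demux(container.streams.video[0]):
    for frame in packet.decode():
        decoded.append(frame.to_ndarray(format="rgb24"))
container.close()
print(len(decoded), decoded[0].shape)  # expect 5 frames of shape (64, 64, 3)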