Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to speed up #309

Merged
merged 5 commits into from
Nov 3, 2023
Merged

Refactor to speed up #309

merged 5 commits into from
Nov 3, 2023

Conversation

iory
Copy link
Owner

@iory iory commented Nov 3, 2023

from datetime import datetime

import numpy as np
from skrobot.coordinates.math import matrix2quaternion


class RobotDataset(object):

    def __init__(self, robot_model):
        self.robot_model = robot_model

    def sample(self, n):
        robot_model = self.robot_model
        ndof = len(robot_model.joint_max_angles)

        j_max = robot_model.joint_max_angles.copy()
        j_min = robot_model.joint_min_angles.copy()
        j_max[j_max == np.inf] = 2 * np.pi
        j_min[j_min == -np.inf] = - 2 * np.pi
        return np.random.uniform(low=j_min,
                                 high=j_max,
                                 size=(n, ndof))

    def end_coords_from_samples(self, samples):
        robot_model = self.robot_model
        end_coords_list = []
        from tqdm import tqdm
        for s in tqdm(samples):
            robot_model.angle_vector(s)
            end_coords = robot_model.end_coords
            t_xyz = end_coords.worldpos()
            q_wxyz = matrix2quaternion(end_coords.worldrot())
            end_coords_list.append(
                np.concatenate(
                    [t_xyz, q_wxyz]))
        return np.array(end_coords_list)


if __name__ == '__main__':
    from skrobot.models import PR2
    robot_model = PR2()
    robot_dataset = RobotDataset(robot_model.rarm)

    n = 100000
    start = datetime.now()
    samples = robot_dataset.sample(n)
    end_coords_array = robot_dataset.end_coords_from_samples(samples)
    end = datetime.now()
    print(end - start)

0:00:10.482203 -> 0:00:06.810289

@iory iory merged commit 26a18f5 into master Nov 3, 2023
6 checks passed
@iory iory deleted the refactor branch November 3, 2023 05:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant