From 47e349156a7bfe12daf371b9408b895ebac43844 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20St=C3=B6lzle?= Date: Sat, 18 May 2024 08:27:12 +0200 Subject: [PATCH] Fix bug in implementation of `potential_energy_fn` for planar_pcs system --- pyproject.toml | 2 +- src/jsrm/systems/planar_pcs.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 626a8c5..acb5a91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ name = "jsrm" # Required # # For a discussion on single-sourcing the version, see # https://packaging.python.org/guides/single-sourcing-package-version/ -version = "0.0.8" # Required +version = "0.0.9" # Required # This is a one-line description or tagline of what your project does. This # corresponds to the "Summary" metadata field: diff --git a/src/jsrm/systems/planar_pcs.py b/src/jsrm/systems/planar_pcs.py index 2b6bbed..14bee58 100644 --- a/src/jsrm/systems/planar_pcs.py +++ b/src/jsrm/systems/planar_pcs.py @@ -111,6 +111,9 @@ def select_params_for_lambdify(params: Dict[str, Array]) -> List[Array]: G_lambda = sp.lambdify( params_syms_cat + sym_exps["state_syms"]["xi"], sym_exps["exps"]["G"], "jax" ) + U_lambda = sp.lambdify( + params_syms_cat + sym_exps["state_syms"]["xi"], sym_exps["exps"]["U"], "jax" + ) compute_stiffness_matrix_for_all_segments_fn = vmap( compute_planar_stiffness_matrix, in_axes=(0, 0, 0, 0), out_axes=0 @@ -300,9 +303,8 @@ def potential_energy_fn(params: Dict[str, Array], q: Array, eps: float = 1e4 * g U_K = (xi - xi_eq).T @ K @ (xi - xi_eq) # evaluate K(xi) = K @ xi # gravitational potential energy - U_G = sp.Matrix([[0]]) params_for_lambdify = select_params_for_lambdify(params) - U_G = G_lambda(*params_for_lambdify, *xi_epsed).squeeze() @ xi_epsed + U_G = U_lambda(*params_for_lambdify, *xi_epsed) # total potential energy U = (U_G + U_K).squeeze()