22# MIT Licence, see details in top-level file: LICENCE
33
44"""
5- Classes for parameterizing a trajectory in SE3 with B-splines.
6-
7- Copies parts of the API from scipy's B-spline class.
5+ Classes for parameterizing a trajectory in SE3 with splines.
86"""
97
10- from typing import Any , Dict , List , Optional
11- from scipy . interpolate import BSpline
12- from spatialmath import SE3
13- import numpy as np
8+ from abc import ABC , abstractmethod
9+ from functools import cached_property
10+ from typing import List , Optional , Tuple , Set
11+
1412import matplotlib .pyplot as plt
15- from spatialmath .base .transforms3d import tranimate , trplot
13+ import numpy as np
14+ from scipy .interpolate import BSpline , CubicSpline
15+ from scipy .spatial .transform import Rotation , RotationSpline
16+
17+ from spatialmath import SE3 , SO3 , Twist3
18+ from spatialmath .base .transforms3d import tranimate
19+
20+
21+ class SplineSE3 (ABC ):
22+ def __init__ (self ) -> None :
23+ self .control_poses : SE3
24+
25+ @abstractmethod
26+ def __call__ (self , t : float ) -> SE3 :
27+ pass
28+
29+ def visualize (
30+ self ,
31+ sample_times : List [float ],
32+ input_trajectory : Optional [List [SE3 ]] = None ,
33+ pose_marker_length : float = 0.2 ,
34+ animate : bool = False ,
35+ repeat : bool = True ,
36+ ax : Optional [plt .Axes ] = None ,
37+ ) -> None :
38+ """Displays an animation of the trajectory with the control poses against an optional input trajectory.
39+
40+ Args:
41+ sample_times: which times to sample the spline at and plot
42+ """
43+ if ax is None :
44+ fig = plt .figure (figsize = (10 , 10 ))
45+ ax = fig .add_subplot (projection = "3d" )
46+
47+ samples = [self (t ) for t in sample_times ]
48+ if not animate :
49+ pos = np .array ([pose .t for pose in samples ])
50+ ax .plot (
51+ pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "c" , linewidth = 1.0
52+ ) # plot spline fit
53+
54+ pos = np .array ([pose .t for pose in self .control_poses ])
55+ ax .plot (pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "r*" ) # plot control_poses
56+
57+ if input_trajectory is not None :
58+ pos = np .array ([pose .t for pose in input_trajectory ])
59+ ax .plot (
60+ pos [:, 0 ], pos [:, 1 ], pos [:, 2 ], "go" , fillstyle = "none"
61+ ) # plot compare to input poses
62+
63+ if animate :
64+ tranimate (
65+ samples , length = pose_marker_length , wait = True , repeat = repeat
66+ ) # animate pose along trajectory
67+ else :
68+ plt .show ()
69+
70+
71+ class InterpSplineSE3 (SplineSE3 ):
72+ """Class for an interpolated trajectory in SE3, as a function of time, through control_poses with a cubic spline.
73+
74+ A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic)
75+ under the hood.
76+ """
77+
78+ _e = 1e-12
79+
80+ def __init__ (
81+ self ,
82+ timepoints : List [float ],
83+ control_poses : List [SE3 ],
84+ * ,
85+ normalize_time : bool = False ,
86+ bc_type : str = "not-a-knot" , # not-a-knot is scipy default; None is invalid
87+ ) -> None :
88+ """Construct a InterpSplineSE3 object
89+
90+ Extends the scipy CubicSpline object
91+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline
92+
93+ Args :
94+ timepoints : list of times corresponding to provided poses
95+ control_poses : list of SE3 objects that govern the shape of the spline.
96+ normalize_time : flag to map times into the range [0, 1]
97+ bc_type : boundary condition provided to scipy CubicSpline backend.
98+ string options: ["not-a-knot" (default), "clamped", "natural", "periodic"].
99+ For tuple options and details see the scipy docs link above.
100+ """
101+ super ().__init__ ()
102+ self .control_poses = control_poses
103+ self .timepoints = np .array (timepoints )
104+
105+ if self .timepoints [- 1 ] < self ._e :
106+ raise ValueError (
107+ "Difference between start and end timepoints is less than {self._e}"
108+ )
109+
110+ if len (self .control_poses ) != len (self .timepoints ):
111+ raise ValueError ("Length of control_poses and timepoints must be equal." )
112+
113+ if len (self .timepoints ) < 2 :
114+ raise ValueError ("Need at least 2 data points to make a trajectory." )
115+
116+ if normalize_time :
117+ self .timepoints = self .timepoints - self .timepoints [0 ]
118+ self .timepoints = self .timepoints / self .timepoints [- 1 ]
119+
120+ self .spline_xyz = CubicSpline (
121+ self .timepoints ,
122+ np .array ([pose .t for pose in self .control_poses ]),
123+ bc_type = bc_type ,
124+ )
125+ self .spline_so3 = RotationSpline (
126+ self .timepoints ,
127+ Rotation .from_matrix (np .array ([(pose .R ) for pose in self .control_poses ])),
128+ )
129+
130+ def __call__ (self , t : float ) -> SE3 :
131+ """Compute function value at t.
132+ Return:
133+ pose: SE3
134+ """
135+ return SE3 .Rt (t = self .spline_xyz (t ), R = self .spline_so3 (t ).as_matrix ())
136+
137+ def derivative (self , t : float ) -> Twist3 :
138+ linear_vel = self .spline_xyz .derivative ()(t )
139+ angular_vel = self .spline_so3 (
140+ t , 1
141+ ) # 1 is angular rate, 2 is angular acceleration
142+ return Twist3 (linear_vel , angular_vel )
143+
144+
145+ class SplineFit :
146+ """A general class to fit various SE3 splines to data."""
147+
148+ def __init__ (
149+ self ,
150+ time_data : List [float ],
151+ pose_data : List [SE3 ],
152+ ) -> None :
153+ self .time_data = time_data
154+ self .pose_data = pose_data
155+ self .spline : Optional [SplineSE3 ] = None
156+
157+ def stochastic_downsample_interpolation (
158+ self ,
159+ epsilon_xyz : float = 1e-3 ,
160+ epsilon_angle : float = 1e-1 ,
161+ normalize_time : bool = True ,
162+ bc_type : str = "not-a-knot" ,
163+ check_type : str = "local"
164+ ) -> Tuple [InterpSplineSE3 , List [int ]]:
165+ """
166+ Uses a random dropout to downsample a trajectory with an interpolated spline. Keeps the start and
167+ end points of the trajectory. Takes a random order of the remaining indices, and then checks the error bound
168+ of just that point if check_type=="local", checks the error of the whole trajectory is check_type=="global".
169+ Local is **much** faster.
170+
171+ Return:
172+ downsampled interpolating spline,
173+ list of removed indices from input data
174+ """
175+
176+ interpolation_indices = list (range (len (self .pose_data )))
177+
178+ # randomly attempt to remove poses from the trajectory
179+ # always keep the start and end
180+ removal_choices = interpolation_indices .copy ()
181+ removal_choices .remove (0 )
182+ removal_choices .remove (len (self .pose_data ) - 1 )
183+ np .random .shuffle (removal_choices )
184+ for candidate_removal_index in removal_choices :
185+ interpolation_indices .remove (candidate_removal_index )
186+
187+ self .spline = InterpSplineSE3 (
188+ [self .time_data [i ] for i in interpolation_indices ],
189+ [self .pose_data [i ] for i in interpolation_indices ],
190+ normalize_time = normalize_time ,
191+ bc_type = bc_type ,
192+ )
193+
194+ sample_time = self .time_data [candidate_removal_index ]
195+ if check_type is "local" :
196+ angular_error = SO3 (self .pose_data [candidate_removal_index ]).angdist (
197+ SO3 (self .spline .spline_so3 (sample_time ).as_matrix ())
198+ )
199+ euclidean_error = np .linalg .norm (
200+ self .pose_data [candidate_removal_index ].t - self .spline .spline_xyz (sample_time )
201+ )
202+ elif check_type is "global" :
203+ angular_error = self .max_angular_error ()
204+ euclidean_error = self .max_euclidean_error ()
205+ else :
206+ raise ValueError (f"check_type must be 'local' of 'global', is { check_type } ." )
207+
208+ if (angular_error > epsilon_angle ) or (euclidean_error > epsilon_xyz ):
209+ interpolation_indices .append (candidate_removal_index )
210+ interpolation_indices .sort ()
16211
212+ self .spline = InterpSplineSE3 (
213+ [self .time_data [i ] for i in interpolation_indices ],
214+ [self .pose_data [i ] for i in interpolation_indices ],
215+ normalize_time = normalize_time ,
216+ bc_type = bc_type ,
217+ )
218+
219+ return self .spline , interpolation_indices
220+
221+ def max_angular_error (self ) -> float :
222+ return np .max (self .angular_errors ())
223+
224+ def angular_errors (self ) -> List [float ]:
225+ return [
226+ pose .angdist (self .spline (t ))
227+ for pose , t in zip (self .pose_data , self .time_data )
228+ ]
229+
230+ def max_euclidean_error (self ) -> float :
231+ return np .max (self .euclidean_errors ())
17232
18- class BSplineSE3 :
233+ def euclidean_errors (self ) -> List [float ]:
234+ return [
235+ np .linalg .norm (pose .t - self .spline (t ).t )
236+ for pose , t in zip (self .pose_data , self .time_data )
237+ ]
238+
239+
240+ class BSplineSE3 (SplineSE3 ):
19241 """A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
20242
21243 The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
@@ -39,9 +261,9 @@ def __init__(
39261 - degree: int that controls degree of the polynomial that governs any given point on the spline.
40262 - knots: list of floats that govern which control points are active during evaluating the spline
41263 at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
42- degree of spline.
264+ degree of spline on the range [0,1] .
43265 """
44-
266+ super (). __init__ ()
45267 self .control_poses = control_poses
46268
47269 # a matrix where each row is a control pose as a twist
@@ -74,32 +296,3 @@ def __call__(self, t: float) -> SE3:
74296 """
75297 twist = np .hstack ([spline (t ) for spline in self .splines ])
76298 return SE3 .Exp (twist )
77-
78- def visualize (
79- self ,
80- num_samples : int ,
81- length : float = 1.0 ,
82- repeat : bool = False ,
83- ax : Optional [plt .Axes ] = None ,
84- kwargs_trplot : Dict [str , Any ] = {"color" : "green" },
85- kwargs_tranimate : Dict [str , Any ] = {"wait" : True },
86- kwargs_plot : Dict [str , Any ] = {},
87- ) -> None :
88- """Displays an animation of the trajectory with the control poses."""
89- out_poses = [self (t ) for t in np .linspace (0 , 1 , num_samples )]
90- x = [pose .x for pose in out_poses ]
91- y = [pose .y for pose in out_poses ]
92- z = [pose .z for pose in out_poses ]
93-
94- if ax is None :
95- fig = plt .figure (figsize = (10 , 10 ))
96- ax = fig .add_subplot (projection = "3d" )
97-
98- trplot (
99- [np .array (self .control_poses )], ax = ax , length = length , ** kwargs_trplot
100- ) # plot control points
101- ax .plot (x , y , z , ** kwargs_plot ) # plot x,y,z trajectory
102-
103- tranimate (
104- out_poses , repeat = repeat , length = length , ** kwargs_tranimate
105- ) # animate pose along trajectory
0 commit comments