diff --git a/parallax/__init__.py b/parallax/__init__.py index 84d8daad..00912fd9 100644 --- a/parallax/__init__.py +++ b/parallax/__init__.py @@ -4,7 +4,7 @@ import os -__version__ = "1.12.0" +__version__ = "1.13.0" # allow multiple OpenMP instances os.environ["KMP_DUPLICATE_LIB_OK"] = "True" diff --git a/parallax/control_panel/probe_calibration_handler.py b/parallax/control_panel/probe_calibration_handler.py index 725d7494..22e00906 100644 --- a/parallax/control_panel/probe_calibration_handler.py +++ b/parallax/control_panel/probe_calibration_handler.py @@ -13,6 +13,8 @@ from parallax.probe_calibration.probe_calibration import ProbeCalibration from parallax.handlers.calculator import Calculator from parallax.handlers.reticle_metadata import ReticleMetadata +from parallax.utils.coords_converter import get_transMs_bregma_to_local +from parallax.utils.probe_angles import find_probe_angle, find_probe_angles_dict logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -25,6 +27,9 @@ class StageCalibrationInfo: """ detection_status: str = "default" # options: default, process, accepted transM: Optional[np.ndarray] = None + transM_bregma: Optional[dict] = None + arc_angle_global: Optional[tuple] = None + arc_angle_bregma: Optional[dict] = None L2_err: Optional[float] = None dist_travel: Optional[np.ndarray] = None status_x: Optional[str] = None @@ -500,6 +505,10 @@ def probe_detect_accepted_status(self, switch_probe=False): return self.probe_detection_status = "accepted" + + # Update reticle selector + self.reticle_metadata.load_metadata_from_file() + # Update into model self.update_stage_info_to_model(self.selected_stage_id) self.model.set_calibration_status(self.selected_stage_id, True) @@ -536,9 +545,6 @@ def probe_detect_accepted_status(self, switch_probe=False): self.filter = "no_filter" logger.debug(f"filter: {self.filter}") - # Update reticle selector - self.reticle_metadata.load_metadata_from_file() - def update_probe_calib_status_transM(self, transformation_matrix): """ @@ -607,6 +613,7 @@ def display_probe_calib_status(self, transM, L2_err, dist_travel): def update_probe_calib_status(self, moving_stage_id, transM, L2_err, dist_travel): """ + Handler for the signal emitted when the probe calibration. (transM_info) Updates the probe calibration status based on the moving stage ID and the provided calibration data. If the selected stage matches the moving stage, the calibration data is displayed on the UI. """ @@ -786,6 +793,14 @@ def update_stage_info_to_model(self, stage_id) -> None: stage_info.status_y = self.calib_status_y stage_info.status_z = self.calib_status_z + # Update transM from bregma if available + transMbs = get_transMs_bregma_to_local(self.model, stage_id) + stage_info.transM_bregma = transMbs + + # Get 3D angle + stage_info.arc_angle_global = find_probe_angle(self.transM) + stage_info.arc_angle_bregma = find_probe_angles_dict(transMbs) + def update_stage_info(self, info): if isinstance(info, StageCalibrationInfo): self.transM = info.transM diff --git a/parallax/handlers/calculator.py b/parallax/handlers/calculator.py index 818439c4..c51d7cbc 100644 --- a/parallax/handlers/calculator.py +++ b/parallax/handlers/calculator.py @@ -11,7 +11,7 @@ from PyQt6.uic import loadUi from PyQt6.QtCore import Qt -from parallax.utils.coords_converter import CoordsConverter +from parallax.utils.coords_converter import local_to_global, global_to_local from parallax.stages.stage_controller import StageController from parallax.config.config_path import ui_dir @@ -167,11 +167,11 @@ def _convert(self, sn): logger.debug(f"User Input (Local): {self.reticle}") trans_type, local_pts, global_pts = self._get_transform_type(globalX, globalY, globalZ, localX, localY, localZ) if trans_type == "global_to_local": - local_pts_ret = CoordsConverter.global_to_local(self.model, sn, global_pts, self.reticle) + local_pts_ret = global_to_local(self.model, sn, global_pts, self.reticle) if local_pts_ret is not None: self._show_local_pts_result(sn, local_pts_ret) elif trans_type == "local_to_global": - global_pts_ret = CoordsConverter.local_to_global(self.model, sn, local_pts, self.reticle) + global_pts_ret = local_to_global(self.model, sn, local_pts, self.reticle) if global_pts_ret is not None: self._show_global_pts_result(sn, global_pts_ret) else: @@ -434,8 +434,8 @@ def _is_z_safe_pos(self, stage_sn, x, y, z): continue try: # Apply transformations to get global points for Z=15 and Z=0 - global_pts_z15 = CoordsConverter.local_to_global(self.model, stage_sn, local_pts_z15) - global_pts_z0 = CoordsConverter.local_to_global(self.model, stage_sn, local_pts_z0) + global_pts_z15 = local_to_global(self.model, stage_sn, local_pts_z15) + global_pts_z0 = local_to_global(self.model, stage_sn, local_pts_z0) if global_pts_z15 is None or global_pts_z0 is None: return False # Transformation failed, return False diff --git a/parallax/handlers/reticle_metadata.py b/parallax/handlers/reticle_metadata.py index 7e77161f..0f2f966c 100644 --- a/parallax/handlers/reticle_metadata.py +++ b/parallax/handlers/reticle_metadata.py @@ -72,6 +72,7 @@ def __init__(self, model, reticle_selector): self.ui.add_btn.clicked.connect(self._add_groupbox) self.ui.update_btn.clicked.connect(self._update_reticle_info) + # Update reticle selector self.model.add_reticle_metadata_instance(self) def load_metadata_from_file(self): @@ -87,7 +88,7 @@ def load_metadata_from_file(self): Exception: If there is an error reading the metadata file, logs the error. """ if not os.path.exists(reticle_metadata_file): - logger.info("No existing metadata file found. Starting fresh.") + logger.debug("No existing metadata file found. Starting fresh.") return try: @@ -98,7 +99,8 @@ def load_metadata_from_file(self): for group_box in self.groupboxes.values(): self._update_reticles(group_box) self._update_to_reticle_selector() - + else: + logger.debug("Metadata json file is empty. Starting fresh.") except Exception as e: logger.error(f"Error reading metadata file: {e}") @@ -111,6 +113,7 @@ def _create_groupbox_from_metadata(self, reticle_data): return # Do not add a new groupbox if it already exists self._populate_groupbox(name, reticle_info) + logger.debug(f"Created groupbox for reticle: {name}") def _add_groupbox(self): """This method creates new groupboxes with an alphabet name.""" diff --git a/parallax/model.py b/parallax/model.py index ac4d34cc..3982e9c4 100755 --- a/parallax/model.py +++ b/parallax/model.py @@ -360,7 +360,7 @@ def get_reticle_metadata(self, reticle_name): Returns: dict: Metadata information for the reticle. """ - return self.reticle_metadata.get(reticle_name) + return self.reticle_metadata.get(reticle_name, None) def remove_reticle_metadata(self, reticle_name): """Remove reticle metadata. diff --git a/parallax/probe_calibration/coords_transformation.py b/parallax/probe_calibration/coords_transformation.py deleted file mode 100644 index 4f2263b5..00000000 --- a/parallax/probe_calibration/coords_transformation.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -This module provides functionality for performing 3D transformations, specifically roll, pitch, -and yaw rotations. It also includes methods for fitting transformation parameters to align measured -points to global points using least squares optimization. - -Classes: - - RotationTransformation: Handles 3D rotations and optimization of transformation parameters - (rotation, translation, and scaling) to fit measured points to global points. -""" -import numpy as np -from scipy.optimize import leastsq - - -class RotationTransformation: - """ - This class provides methods for performing 3D rotations (roll, pitch, and yaw), - extracting angles from a rotation matrix, combining angles into a rotation matrix, - and fitting parameters for transforming measured points to global points through - optimization. - """ - - def __init__(self): - """Initialize the RotationTransformation class.""" - pass - - def roll(self, inputMat, g): # rotation around x axis (bank angle) - """ - Performs a rotation around the x-axis (roll or bank angle). - - Args: - inputMat (numpy.ndarray): The input matrix to be rotated. - g (float): The roll angle in radians. - - Returns: - numpy.ndarray: The resulting matrix after applying the roll rotation. - """ - rollMat = np.array([[1, 0, 0], - [0, np.cos(g), -np.sin(g)], - [0, np.sin(g), np.cos(g)]]) - return np.dot(inputMat, rollMat) - - def pitch(self, inputMat, b): # rotation around y axis (elevation angle) - """ - Performs a rotation around the y-axis (pitch or elevation angle). - - Args: - inputMat (numpy.ndarray): The input matrix to be rotated. - b (float): The pitch angle in radians. - - Returns: - numpy.ndarray: The resulting matrix after applying the pitch rotation. - """ - pitchMat = np.array([[np.cos(b), 0, np.sin(b)], - [0, 1, 0], - [-np.sin(b), 0, np.cos(b)]]) - return np.dot(inputMat, pitchMat) - - def yaw(self, inputMat, a): # rotation around z axis (heading angle) - """ - Performs a rotation around the z-axis (yaw or heading angle). - - Args: - inputMat (numpy.ndarray): The input matrix to be rotated. - a (float): The yaw angle in radians. - - Returns: - numpy.ndarray: The resulting matrix after applying the yaw rotation. - """ - yawMat = np.array([[np.cos(a), -np.sin(a), 0], - [np.sin(a), np.cos(a), 0], - [0, 0, 1]]) - return np.dot(inputMat, yawMat) - - def extractAngles(self, mat): - """ - Extracts roll, pitch, and yaw angles from a given rotation matrix. - - Args: - mat (numpy.ndarray): A 3x3 rotation matrix. - - Returns: - tuple: The roll (x), pitch (y), and yaw (z) angles in radians. - """ - x = np.arctan2(mat[2, 1], mat[2, 2]) - y = np.arctan2(-mat[2, 0], np.sqrt(pow(mat[2, 1], 2) + pow(mat[2, 2], 2))) - z = np.arctan2(mat[1, 0], mat[0, 0]) - return x, y, z - - def combineAngles(self, x, y, z, reflect_z=False): - """ - Combines roll, pitch, and yaw angles into a single rotation matrix. - - Args: - x (float): Roll angle in radians. - y (float): Pitch angle in radians. - z (float): Yaw angle in radians. - reflect_z (bool, optional): If True, applies a reflection along the z-axis. Defaults to False. - - Returns: - numpy.ndarray: The combined 3x3 rotation matrix. - """ - eye = np.identity(3) - R = self.roll( - self.pitch( - self.yaw(eye, z), y), x) - - if reflect_z: - reflection_matrix = np.array([[1, 0, 0], - [0, 1, 0], - [0, 0, -1]]) - R = R @ reflection_matrix - return R - - def func(self, x, measured_pts, global_pts, reflect_z=False): - """ - Defines an error function for optimization, calculating the difference between transformed - global points and measured points. - - Args: - x (numpy.ndarray): The parameters to optimize (angles, translation, and scaling factors). - measured_pts (numpy.ndarray): The measured points (local coordinates). - global_pts (numpy.ndarray): The global points (target coordinates). - reflect_z (bool, optional): If True, applies a reflection along the z-axis. Defaults to False. - - Returns: - numpy.ndarray: The error values for each point. - """ - R = self.combineAngles(x[2], x[1], x[0], reflect_z=reflect_z) - origin = np.array([x[3], x[4], x[5]]).T - - error_values = np.zeros(len(global_pts) * 3) - for i in range(len(global_pts)): - global_pt = global_pts[i, :].T - measured_pt = measured_pts[i, :].T - global_pt_exp = R @ measured_pt + origin - error_values[i * 3: (i + 1) * 3] = global_pt - global_pt_exp - - return error_values - - def avg_error(self, x, measured_pts, global_pts, reflect_z=False): - """ - Calculates the total error (L2 norm) for the optimization. - - Args: - x (numpy.ndarray): The parameters to optimize. - measured_pts (numpy.ndarray): The measured points (local coordinates). - global_pts (numpy.ndarray): The global points (target coordinates). - reflect_z (bool, optional): If True, applies a reflection along the z-axis. Defaults to False. - - Returns: - float: The average L2 error across all points. - """ - error_values = self.func(x, measured_pts, global_pts, reflect_z) - - # Calculate the L2 error for each point - l2_errors = np.zeros(len(global_pts)) - for i in range(len(global_pts)): - error_vector = error_values[i * 3: (i + 1) * 3] - l2_errors[i] = np.linalg.norm(error_vector) - - # Calculate the average L2 error - average_l2_error = np.mean(l2_errors) - - return average_l2_error - - def fit_params(self, measured_pts, global_pts): - """ - Fits the transformation parameters (angles, translation, and scaling) to minimize the error - between measured points and global points using least squares optimization. - - Args: - measured_pts (numpy.ndarray): The measured points (local coordinates). - global_pts (numpy.ndarray): The global points (target coordinates). - - Returns: - tuple: A tuple containing the translation vector (origin), rotation matrix (R), and the average error (avg_err). - """ - x0 = np.array([0, 0, 0, 0, 0, 0]) - - if len(measured_pts) <= 3 or len(global_pts) <= 3: - raise ValueError("At least three points are required for optimization.") - - # Optimize without reflection - res1 = leastsq(self.func, x0, args=(measured_pts, global_pts, False), maxfev=5000) - avg_error1 = self.avg_error(res1[0], measured_pts, global_pts, False) - - # Optimize with reflection - res2 = leastsq(self.func, x0, args=(measured_pts, global_pts, True), maxfev=5000) - avg_error2 = self.avg_error(res2[0], measured_pts, global_pts, True) - - # Select the transformation with the smaller total error - if avg_error1 < avg_error2: - rez = res1[0] - R = self.combineAngles(rez[2], rez[1], rez[0], reflect_z=False) - avg_err = avg_error1 - else: - rez = res2[0] - R = self.combineAngles(rez[2], rez[1], rez[0], reflect_z=True) - avg_err = avg_error1 - - origin = rez[3:6] - return origin, R, avg_err # translation vector, rotation matrix, and scaling factors diff --git a/parallax/probe_calibration/probe_calibration.py b/parallax/probe_calibration/probe_calibration.py index b388a5af..32bab414 100644 --- a/parallax/probe_calibration/probe_calibration.py +++ b/parallax/probe_calibration/probe_calibration.py @@ -11,7 +11,8 @@ import pandas as pd from pathlib import Path from PyQt6.QtCore import QObject, pyqtSignal -from .coords_transformation import RotationTransformation +#from .coords_transformation import RotationTransformation +from .transforms import fit_params from .bundle_adjustment import BALProblem, BALOptimizer from parallax.handlers.point_mesh import PointMesh from parallax.config.config_path import stages_dir @@ -27,7 +28,7 @@ class ProbeCalibration(QObject): by transforming local stage coordinates to global reticle coordinates. Signals: - calib_complete (str, object, np.ndarray): Signal emitted when the full calibration is complete. + calib_complete: Signal emitted when the full calibration is complete. transM_info (str, object, float, object): Signal emitted with transformation matrix information. """ calib_complete = pyqtSignal() @@ -53,7 +54,6 @@ def __init__(self, model, stage_listener): stage_listener (QObject): The stage listener object for receiving stage-related events. """ super().__init__() - self.transformer = RotationTransformation() self.model = model self.stage_listener = stage_listener self.stage_listener.probeCalibRequest.connect(self.update) @@ -182,7 +182,7 @@ def _get_local_global_points(self, df): return local_points, global_points - def _get_l2_distance(self, local_points, global_points): + def _get_l2_distance_deprecated(self, local_points, global_points): """ Compute the L2 distance between the expected global points and the actual global points. @@ -193,7 +193,7 @@ def _get_l2_distance(self, local_points, global_points): Returns: numpy.ndarray: The L2 distance between the points. """ - R, t= self.R, self.origin + R, t = self.R, self.origin global_coords_exp = R @ local_points.T + t.reshape(-1, 1) global_coords_exp = global_coords_exp.T @@ -204,6 +204,28 @@ def _get_l2_distance(self, local_points, global_points): return l2_distance + def _get_l2_distance(self, local_points, global_points): + """ + Compute the L2 distance between the expected local points and the actual local points. + + Args: + local_points (numpy.ndarray): The local points. + global_points (numpy.ndarray): The global points. + + Returns: + numpy.ndarray: The L2 distance between the points. + """ + R, t = self.R, self.origin + local_coords_exp = R @ global_points.T + t.reshape(-1, 1) + local_coords_exp = local_coords_exp.T + + l2_distance = np.linalg.norm(local_points - local_coords_exp, axis=1) + mean_l2_distance = np.mean(l2_distance) + std_l2_distance = np.std(l2_distance) + logger.debug(f"mean_l2_distance: {mean_l2_distance}, std_l2_distance: {std_l2_distance}") + + return l2_distance + def _remove_outliers(self, df, threshold=30): """ Remove outliers based on L2 distance threshold. @@ -245,7 +267,7 @@ def _get_transM_LR_orthogonal(self, local_points, global_points, remove_noise=Tr if len(local_points) <= 3 or len(global_points) <= 3: logger.warning("Not enough points for calibration.") return None - self.origin, self.R, self.avg_err = self.transformer.fit_params(local_points, global_points) + self.origin, self.R, self.avg_err = fit_params(local_points, global_points) transformation_matrix = np.hstack([self.R, self.origin.reshape(-1, 1)]) transformation_matrix = np.vstack([transformation_matrix, [0, 0, 0, 1]]) @@ -261,7 +283,8 @@ def _get_transM(self, df): logger.warning("Not enough points for calibration.") return None - self.origin, self.R, self.avg_err = self.transformer.fit_params(local_points, global_points) + # local = R @ global + t, where local shape and global shape are 3xN. + self.origin, self.R, self.avg_err = fit_params(local_points, global_points) transformation_matrix = np.hstack([self.R, self.origin.reshape(-1, 1)]) transformation_matrix = np.vstack([transformation_matrix, [0, 0, 0, 1]]) @@ -428,14 +451,17 @@ def _is_criteria_met_points_min_max(self, sn): def _apply_transformation(self): """ Applies the calculated transformation matrix to convert a local point to global coordinates. - + local = R @ global + t, where local shape and global shape are {3x1}. + To get local from global: + global = R.T @ (local - t), local, t, and global are {3x1} vectors. + global = (local - t) @ R, local, t, and global are {1x3} vectors. Returns: np.array: The transformed global point. """ local_point = np.array([self.stage.stage_x, self.stage.stage_y, self.stage.stage_z]) - local_point = np.append(local_point, 1) - global_point = np.dot(self.transM_LR, local_point) - return global_point[:3] + t = self.origin + global_point = (local_point - t) @ self.R + return global_point def _update_l2_error_current_point(self): """ @@ -444,7 +470,7 @@ def _update_l2_error_current_point(self): if self.transM_LR is None: return None - transformed_point = self._apply_transformation() + transformed_point_global = self._apply_transformation() global_point = np.array( [ self.stage.stage_x_global, @@ -452,7 +478,7 @@ def _update_l2_error_current_point(self): self.stage.stage_z_global, ] ) - self.LR_err_L2_current = np.linalg.norm(transformed_point - global_point) + self.LR_err_L2_current = np.linalg.norm(transformed_point_global - global_point) return def _is_enough_points(self, df): @@ -567,17 +593,6 @@ def _save_transM_to_csv(self, file_name): # Add the transformation matrix columns to the DataFrame self._save_df_to_csv(df, file_name) - def reshape_array(self): - """ - Reshapes arrays of local and global points for processing. - - Returns: - tuple: Reshaped local and global points arrays. - """ - local_points = np.array(self.local_points) - global_points = np.array(self.global_points) - return local_points.reshape(-1, 1, 3), global_points.reshape(-1, 1, 3) - def _print_formatted_transM(self): """ Prints the transformation matrix in a formatted way, including the rotation matrix, diff --git a/parallax/probe_calibration/transforms.py b/parallax/probe_calibration/transforms.py new file mode 100644 index 00000000..c3a1ab0b --- /dev/null +++ b/parallax/probe_calibration/transforms.py @@ -0,0 +1,123 @@ +import numpy as np +from scipy.optimize import leastsq + +# ---------- Rotation utils ---------- +def _Rx(a): # Roll + c, s = np.cos(a), np.sin(a) + return np.array([[1, 0, 0], + [0, c,-s], + [0, s, c]], float) + +def _Ry(a): # Pitch + c, s = np.cos(a), np.sin(a) + return np.array([[ c, 0, s], + [ 0, 1, 0], + [-s, 0, c]], float) + +def _Rz(a): # Yaw + c, s = np.cos(a), np.sin(a) + return np.array([[ c,-s, 0], + [ s, c, 0], + [ 0, 0, 1]], float) + +def _euler_zyx_to_R(roll_x, pitch_y, yaw_z): + """R = Rz(yaw) @ Ry(pitch) @ Rx(roll) (ZYX / yaw-pitch-roll)""" + return _Rz(yaw_z) @ _Ry(pitch_y) @ _Rx(roll_x) + +def _R_to_euler_zyx(R): + """Inverse of euler_zyx_to_R; returns (roll_x, pitch_y, yaw_z).""" + # Handles standard range; watch for gimbal near |pitch|=pi/2 + sy = -R[2, 0] + cy = np.sqrt(R[2, 1]**2 + R[2, 2]**2) + pitch_y = np.arctan2(sy, cy) + roll_x = np.arctan2(R[2, 1], R[2, 2]) + yaw_z = np.arctan2(R[1, 0], R[0, 0]) + return roll_x, pitch_y, yaw_z + +def _reflect_z(R): + """Reflect along z axis: (x,y,z) -> (x,y,-z).""" + return np.diag([1, 1, -1]) @ R + +def _combineAngles(roll_x, pitch_y, yaw_z, reflect_z=False): + if not reflect_z: + return _euler_zyx_to_R(roll_x, pitch_y, yaw_z) + else: + return _reflect_z(_euler_zyx_to_R(roll_x, pitch_y, yaw_z)) + +def _func(x, measured_pts, global_pts, reflect_z=False): + """ + Defines an error function for optimization, calculating the difference between transformed + global points and measured points. + Args: + x (numpy.ndarray): The parameters to optimize (angles, translation). + measured_pts (numpy.ndarray): The measured points (local coordinates). + global_pts (numpy.ndarray): The global points (target coordinates). + reflect_z (bool, optional): If True, applies a reflection along the z-axis. Defaults to False. + Returns: + numpy.ndarray: The error values for each point. + """ + R = _combineAngles(x[2], x[1], x[0], reflect_z=reflect_z) + origin = np.array([x[3], x[4], x[5]]).T + error_values = np.zeros(len(global_pts) * 3) + for i in range(len(global_pts)): + global_pt = global_pts[i, :].T # Shape: (3, 1) + measured_pt = measured_pts[i, :].T # Shape: (3, 1) + measured_pt_exp = R @ global_pt + origin + error_values[i * 3: (i + 1) * 3] = measured_pt - measured_pt_exp + return error_values + +def avg_error(x, measured_pts, global_pts, reflect_z=False): + """ + Calculates the total error (L2 norm) for the optimization. + Args: + x (numpy.ndarray): The parameters to optimize. + measured_pts (numpy.ndarray): The measured points (local coordinates). + global_pts (numpy.ndarray): The global points (target coordinates). + reflect_z (bool, optional): If True, applies a reflection along the z-axis. Defaults to False. + Returns: + float: The average L2 error across all points. + """ + error_values = _func(x, measured_pts, global_pts, reflect_z) + # Calculate the L2 error for each point + l2_errors = np.zeros(len(global_pts)) + for i in range(len(global_pts)): + error_vector = error_values[i * 3: (i + 1) * 3] + l2_errors[i] = np.linalg.norm(error_vector) + # Calculate the average L2 error + average_l2_error = np.mean(l2_errors) + return average_l2_error + +def fit_params(measured_pts, global_pts): + """ + local = R @ global + t, where local shape and global shape are 3xN. + Fits the transformation parameters (angles, translation) to minimize the error + between measured points and global points using least squares optimization. + Args: + measured_pts (numpy.ndarray): The measured points (local coordinates). rows vector (N,3) + global_pts (numpy.ndarray): The global points (target coordinates). rows vector (N,3) + Returns: + tuple: A tuple containing the translation vector (origin), rotation matrix (R), and the average error (avg_err). + """ + x0 = np.array([0, 0, 0, 0, 0, 0]) + if len(measured_pts) <= 3 or len(global_pts) <= 3: + raise ValueError("At least three points are required for optimization.") + + # Optimize without reflection + res1 = leastsq(_func, x0, args=(measured_pts, global_pts, False), maxfev=5000) + avg_error1 = avg_error(res1[0], measured_pts, global_pts, False) + + # Optimize with reflection + res2 = leastsq(_func, x0, args=(measured_pts, global_pts, True), maxfev=5000) + avg_error2 = avg_error(res2[0], measured_pts, global_pts, True) + + # Select the transformation with the smaller total error + if avg_error1 < avg_error2: + rez = res1[0] + R = _combineAngles(rez[2], rez[1], rez[0], reflect_z=False) + avg_err = avg_error1 + else: + rez = res2[0] + R = _combineAngles(rez[2], rez[1], rez[0], reflect_z=True) + avg_err = avg_error2 + origin = rez[3:6] + return origin, R, avg_err # translation vector, rotation matrix, and scaling factors diff --git a/parallax/stages/stage_controller.py b/parallax/stages/stage_controller.py index e0a3b4ab..86d6df65 100644 --- a/parallax/stages/stage_controller.py +++ b/parallax/stages/stage_controller.py @@ -20,7 +20,7 @@ import numpy as np from typing import Optional from PyQt6.QtCore import QObject, QTimer -from parallax.utils.coords_converter import CoordsConverter +from parallax.utils.coords_converter import global_to_local # Set logger name logger = logging.getLogger(__name__) @@ -235,7 +235,7 @@ def _move_request(self, command: dict) -> None: if command.get("world", None) == "global": # coords_converter unit is um, so convert mm to µm global_pts_um = np.array([x*1000, y*1000, z*1000], dtype=float) - local_pts_um = CoordsConverter.global_to_local(self.model, stage_sn, global_pts_um) + local_pts_um = global_to_local(self.model, stage_sn, global_pts_um) if local_pts_um is None: logger.warning(f"Failed to convert global coordinates to local for stage {stage_sn}.") return diff --git a/parallax/stages/stage_listener.py b/parallax/stages/stage_listener.py index 809c3ccd..dd06379c 100644 --- a/parallax/stages/stage_listener.py +++ b/parallax/stages/stage_listener.py @@ -14,7 +14,7 @@ from PyQt6.QtCore import QObject, QThread, QTimer, pyqtSignal from PyQt6.QtWidgets import QFileDialog -from parallax.utils.coords_converter import CoordsConverter +from parallax.utils.coords_converter import local_to_global, apply_reticle_adjustments, local_to_bregma # Set logger name logger = logging.getLogger(__name__) @@ -63,6 +63,7 @@ class Stage: stage_x_global: Optional[float] = None stage_y_global: Optional[float] = None stage_z_global: Optional[float] = None + stage_bregma: Optional[dict] = None stage_x_offset: float = 0.0 stage_y_offset: float = 0.0 stage_z_offset: float = 0.0 @@ -287,6 +288,8 @@ def handleDataChange(self, probe): """ sn = probe["SerialNumber"] stage = (self.model.stages.get(sn, {}) or {}).get("obj") + is_calib = (self.model.stages.get(sn, {}) or {}).get("is_calib") + calib_info = (self.model.stages.get(sn, {}) or {}).get("calib_info") if not stage: return @@ -301,31 +304,59 @@ def handleDataChange(self, probe): stage.stage_y_offset = probe.get("Stage_YOffset", 0) * 1000 # Convert to um stage.stage_z_offset = 15000 - (probe.get("Stage_ZOffset", 0) * 1000) # Convert to um local_pts = np.array([local_x, local_y, local_z]) - global_pts = CoordsConverter.local_to_global(self.model, sn, local_pts) - if global_pts is not None: - stage.stage_x_global = global_pts[0] - stage.stage_y_global = global_pts[1] - stage.stage_z_global = global_pts[2] - + if is_calib: + global_pts = local_to_global(self.model, sn, local_pts) + if global_pts is not None: + stage.stage_x_global = global_pts[0] + stage.stage_y_global = global_pts[1] + stage.stage_z_global = global_pts[2] + + if is_calib: + bregma_pts = {} + for reticle in self.model.reticle_metadata.keys(): + bregma_pt = apply_reticle_adjustments(self.model, global_pts, reticle=reticle) + #bregma_pt_ = local_to_bregma(self.model, sn, local_pts, reticle=reticle) # for the sanity check + #print(f"{reticle}-bregma_pt: {bregma_pt}, bregma_pt_: {bregma_pt_}") + if bregma_pt is not None: + # make JSON-safe now + bregma_pts[reticle] = np.asarray(bregma_pt, dtype=float).reshape(3,).tolist() + + stage.stage_bregma = bregma_pts + + # Update stage UI # Stage is currently selected one, update into UI if sn == self.stage_ui.get_selected_stage_sn(): self.stage_ui.updateStageLocalCoords() # Update local coords into UI - if global_pts is not None: # If stage is calibrated, + if is_calib: self.stage_ui.updateStageGlobalCoords() # update global coords into UI # Update stage info - self._update_stages_info(stage) - - def _update_stages_info(self, stage): - """Update stage info. + self._update_stages_info(stage, is_calib, calib_info) - Args: - stage (Stage): Stage object. - """ - if stage is None: + def _update_stages_info(self, stage, is_calib, calib_info): + """Update stage info without clobbering existing fields and with sane conditions.""" + if stage is None or not getattr(stage, "sn", None): return - self.stages_info[stage.sn] = self._get_stage_info_json(stage) + # Start from existing info; merge in fresh stage fields instead of overwriting. + info = self.stages_info.get(stage.sn, {}).copy() + base = self._get_stage_info_json(stage) or {} + info.update(base) + + prev_is_calib = info.get("is_calibrated") + status_changed = (prev_is_calib is None) or (bool(is_calib) != prev_is_calib) + + if status_changed: + # Always keep this boolean up to date + info["is_calibrated"] = bool(is_calib) + if is_calib and calib_info is not None: + info["calib_info"] = self._get_calib_info_json(calib_info) + logger.debug(f"Stage {stage.sn} calibrated: {info['calib_info']}") + else: + info["calib_info"] = None + logger.debug(f"Stage {stage.sn} uncalibrated") + + self.stages_info[stage.sn] = info def requestUpdateGlobalDataTransformM(self, sn, transM): """ @@ -438,11 +469,29 @@ def _get_stage_info_json(self, stage): sx, sy, sz = stage.stage_x, stage.stage_y, stage.stage_z gx, gy, gz = stage.stage_x_global, stage.stage_y_global, stage.stage_z_global ox, oy, oz = stage.stage_x_offset, stage.stage_y_offset, stage.stage_z_offset + stage_bregma = stage.stage_bregma def _val_mm(v): """Convert value to mm.""" return round(v * 0.001, 4) if v is not None else None + def _vec_mm(v): + """3-vector (np/list) in µm -> [mm, mm, mm] (rounded).""" + if v is None: + return None + arr = np.asarray(v, dtype=float).reshape(-1) + if arr.size < 3: + return None + return [round(arr[0] * 0.001, 4), + round(arr[1] * 0.001, 4), + round(arr[2] * 0.001, 4)] + + def _bregma_mm(b): + """Dict of reticle -> 3-vector in µm -> mm dict.""" + if not b: + return None + return {str(k): _vec_mm(v) for k, v in b.items() if v is not None} + return { "sn": stage.sn, "name": stage.name, @@ -452,13 +501,38 @@ def _val_mm(v): "global_X": _val_mm(gx), "global_Y": _val_mm(gy), "global_Z": _val_mm(gz), - "relative_X": _val_mm(sx - ox), - "relative_Y": _val_mm(sy - oy), - "relative_Z": _val_mm(sz - oz), + "bregma": _bregma_mm(stage_bregma), + "relative_X": _val_mm(sx - ox) if sx is not None and ox is not None else None, + "relative_Y": _val_mm(sy - oy) if sy is not None and oy is not None else None, + "relative_Z": _val_mm(sz - oz) if sz is not None and oz is not None else None, "yaw": stage.yaw, "pitch": stage.pitch, "roll": stage.roll, - "shank_cnt": stage.shank_cnt, + } + + def _get_calib_info_json(self, calib_info): + def _to_list(x): return None if x is None else np.asarray(x).tolist() + + def _to_mm(M): + if M is None: return None + A = np.asarray(M, float).copy() + A[:3, 3] /= 1000.0 # µm -> mm # TODO replace to mm in entire Parallax model + return A.tolist() + + transM_mm = _to_mm(calib_info.transM) + bregma_mm = {k: _to_mm(v) for k, v in (calib_info.transM_bregma or {}).items()} or None + + return { + "detection_status": calib_info.detection_status, + "transM_global_to_local": transM_mm, + "L2_error": calib_info.L2_err, + "distance_travelled": _to_list(calib_info.dist_travel), + "status_x": calib_info.status_x, + "status_y": calib_info.status_y, + "status_z": calib_info.status_z, + "transM_bregma_to_local": bregma_mm, + "arc_angle_global": calib_info.arc_angle_global, + "arc_angle_bregma": calib_info.arc_angle_bregma, } def _snapshot_stage(self): diff --git a/parallax/utils/coords_converter.py b/parallax/utils/coords_converter.py index ef5091f2..afeafcf0 100644 --- a/parallax/utils/coords_converter.py +++ b/parallax/utils/coords_converter.py @@ -1,7 +1,19 @@ """ -This module provides a class for converting between local and global coordinates -using transformation matrices. +This module provides helpers for converting between local, global, and bregma +coordinates using rigid transformation matrices. + +Conventions +----------- +- Canonical (column-vector) definition used to DEFINE R and t: + local_col = R @ global_col + t # R: (3,3), t: (3,) + +- Row-vector form IMPLEMENTED in this module (all inputs/outputs are row 1x3): + local_row = global_row @ R.T + t + +- Inverse (row-vector): + global_row = (local_row - t) @ R """ + import logging import numpy as np from typing import Optional @@ -11,143 +23,362 @@ logger.setLevel(logging.WARNING) -class CoordsConverter: +def local_to_global(model, sn: str, local_pts: np.ndarray, reticle: Optional[str] = None) -> Optional[np.ndarray]: """ - Converts between local and global coordinates using transformation matrices. It also applies reticle adjustments for specific reticles. + Convert local (1x3 row) -> global (1x3 row) using the stage's transform. + + Canonical (column) definition for reference: + local = R @ global + t + + Row-vector form we compute: + global = (local - t) @ R + + Here, the stage supplies T = [[R, t], [0, 1]] that maps GLOBAL→LOCAL. + We invert that mapping for a single row vector via the row-form above. + + Parameters + ---------- + model : object + Provides `is_calibrated(sn)` and `get_transform(sn)` returning a 4x4 T. + sn : str + Stage serial number. + local_pts : np.ndarray + Local coordinates (µm). Expected shape (3,) or (1,3). Interpreted as row-vector. + reticle : str, optional + If provided, apply per-reticle rotation/offset to the computed GLOBAL coords. + + Returns + ------- + np.ndarray or None + Rounded GLOBAL coordinates (1x3). None if the stage/transform is unavailable. """ + if model.is_calibrated(sn): + T = model.get_transform(sn) # T = [[R, t],[0,1]] for GLOBAL→LOCAL + else: + return None + if T is None: + logger.debug(f"TransM not found for {sn}") + return None - @staticmethod - def local_to_global(model, sn: str, local_pts: np.ndarray, reticle: Optional[str] = None) -> Optional[np.ndarray]: - """ - Converts local coordinates to global coordinates using the transformation matrix. - Args: - sn (str): The serial number of the stage. - local_pts (ndarray): The local coordinates (µm) to convert. - reticle (str, optional): The name of the reticle to apply adjustments for. Defaults to None. - Returns: - ndarray: The global coordinates (µm). - """ - if model.is_calibrated(sn): - transM = model.get_transform(sn) - else: - return None + global_pts = apply_inverse_rigid_transform(T, local_pts) # (local - t) @ R - if transM is None: - logger.debug(f"TransM not found for {sn}") - return None + logger.debug(f"global_to_local {global_pts} -> {local_pts}") + # Optional: reticle adjustment maps GLOBAL ↔ BREGMA for a named reticle + if reticle is not None: + global_pts = apply_reticle_adjustments(model, global_pts, reticle) + return np.round(global_pts, 1) - # Apply transM, convert to homogeneous coordinates, and transform - global_pts = np.dot(transM, np.append(local_pts, 1)) - - logger.debug(f"local_to_global: {local_pts} -> {global_pts[:3]}") - logger.debug(f"R: {transM[:3, :3]}\nT: {transM[:3, 3]}") - - if reticle is not None: - # Apply the reticle offset and rotation adjustment - global_pts = CoordsConverter._apply_reticle_adjustments(model, global_pts[:3], reticle) - - return np.round(global_pts[:3], 1) - - @staticmethod - def global_to_local(model, sn: str, global_pts: np.ndarray, reticle: Optional[str] = None) -> Optional[np.ndarray]: - """ - Applies the inverse transformation to convert global coordinates to local coordinates. - - Args: - sn (str): The serial number of the stage. - global_pts (ndarray): The global coordinates (µm) to convert. - reticle (str, optional): The name of the reticle to apply adjustments for. Defaults to None. - Returns: - ndarray: The transformed local coordinates (µm). - """ - if model.is_calibrated(sn): - transM = model.get_transform(sn) - else: - logger.warning(f"Stage {sn} is not calibrated. Cannot convert global to local coordinates.") - return None - if transM is None: - logger.warning(f"Transformation matrix not found for {sn}") +def apply_inverse_rigid_transform(transM: np.ndarray, local_pts: np.ndarray) -> np.ndarray: + """ + Apply the inverse of a rigid transform to get GLOBAL from LOCAL (row-vector form). + + Canonical (column) : global = R.T @ (local - t) + Row-vector form : global = (local - t) @ R + + Where transM = [[R, t], + [0, 1]] maps GLOBAL→LOCAL in the canonical column form. + + Parameters + ---------- + transM : np.ndarray + 4x4 homogeneous transform (GLOBAL→LOCAL). + local_pts : np.ndarray + Local point as row-vector (3,) or (1,3). + + Returns + ------- + np.ndarray + GLOBAL point as (1,3) row-vector (same math, row form). + """ + assert transM.shape == (4, 4), "transM must be 4x4" + R = transM[:3, :3] + t_row = transM[:3, 3].T # shape (3,) + global_row = (local_pts - t_row) @ R # row-vector inverse + return global_row + + +def global_to_local(model, sn: str, global_pts: np.ndarray, reticle: Optional[str] = None) -> Optional[np.ndarray]: + """ + Convert global (1x3 row) -> local (1x3 row) using the stage's transform. + + Canonical (column) mapping carried by the stage: + local = R @ global + t + + Row-vector implementation: + local = global @ R.T + t + + If a reticle is specified (and not "Global coords"), first undo its + rotation/offset on the incoming GLOBAL point, then apply the stage mapping. + + Parameters + ---------- + model : object + Provides `is_calibrated(sn)` and `get_transform(sn)` returning a 4x4 T. + sn : str + Stage serial number. + global_pts : np.ndarray + Global coordinates (µm). Expected shape (3,) or (1,3). Interpreted as row-vector. + reticle : str, optional + If provided (and not "Global coords"), apply the inverse of the reticle's + rotation/offset to the incoming GLOBAL coords before mapping to LOCAL. + + Returns + ------- + np.ndarray or None + Rounded LOCAL coordinates (1x3). None if the stage/transform is unavailable. + """ + if model.is_calibrated(sn): + T = model.get_transform(sn) # T = [[R, t],[0,1]] for GLOBAL→LOCAL + else: + logger.warning(f"Stage {sn} is not calibrated. Cannot convert global to local coordinates.") + return None + if T is None: + logger.warning(f"Transformation matrix not found for {sn}") + return None + if reticle and reticle != "Global coords": + global_pts = apply_reticle_adjustments_inverse(model, global_pts, reticle) + local_row4 = apply_rigid_transform(T, global_pts) # returns homogeneous 4-vector; see its docstring + return np.round(local_row4[:3], 1) + + +def apply_rigid_transform(transM: np.ndarray, global_pts: np.ndarray) -> np.ndarray: + """ + Apply a rigid transform to map GLOBAL → LOCAL. + + Canonical (column) : local = R @ global + t + Row-vector form : local = global @ R.T + t + + This function uses homogeneous multiplication directly: + [local, 1] = transM @ [global, 1] + + Parameters + ---------- + transM : np.ndarray + 4x4 homogeneous transform [[R, t],[0,1]] mapping GLOBAL→LOCAL. + global_pts : np.ndarray + GLOBAL point as row-vector (3,) or (1,3). + + Returns + ------- + np.ndarray + Homogeneous local vector length-4: [local_x, local_y, local_z, 1]. + (Caller typically slices [:3] to get (1x3) LOCAL.) + """ + assert transM.shape == (4, 4), "transM must be 4x4" + # np.dot(A, b) and A @ b are equivalent for NumPy arrays. + local_h = np.dot(transM, np.append(global_pts, 1)) + return local_h + + +def apply_reticle_adjustments_inverse(model, reticle_global_pts: np.ndarray, reticle: str) -> np.ndarray: + """ + Apply the INVERSE of a reticle's rotation/offset to a GLOBAL point. + + Reticle mapping (canonical column definitions): + bregma = Rm @ global + tm + + Row-vector equivalents: + bregma = global @ Rm.T + tm + global = (bregma - tm) @ Rm + + Here we invert the reticle mapping on a GLOBAL point that was tagged + as 'reticle-global', i.e., we compute: + global = (bregma - tm) @ Rm + where 'bregma' is represented by the input reticle_global_pts. + + Parameters + ---------- + model : object + Provides `get_reticle_metadata(reticle)` with 'rotmat' and offsets. + reticle_global_pts : np.ndarray + GLOBAL coordinates (1x3) but already offset/rotated by reticle metadata. + reticle : str + Reticle name. + + Returns + ------- + np.ndarray + GLOBAL coordinates (1x3) with the reticle's rotation/offset removed. + """ + reticle_global_pts = np.array(reticle_global_pts) + md = model.get_reticle_metadata(reticle) + if not md: + logger.warning(f"Warning: No metadata found for reticle '{reticle}'. Returning original points.") + return np.array([reticle_global_pts[0], reticle_global_pts[1], reticle_global_pts[2]]) + Rm = md.get("rotmat", np.eye(3)) + tm = np.array([ + md.get("offset_x", 0), + md.get("offset_y", 0), + md.get("offset_z", 0) + ]) + # Row-form inverse: global = (bregma - tm) @ Rm + global_row = (reticle_global_pts - tm) @ Rm + return np.array(global_row) + + +def apply_reticle_adjustments(model, global_pts: np.ndarray, reticle: str) -> np.ndarray: + """ + Apply a reticle's rotation/offset to a GLOBAL point. + + Reticle mapping (canonical column): + bregma = Rm @ global + tm + + Row-vector equivalent implemented here: + bregma = global @ Rm.T + tm + + Parameters + ---------- + model : object + Provides `get_reticle_metadata(reticle)` with 'rotmat' and offsets. + global_pts : np.ndarray + GLOBAL coordinates (1x3). + reticle : str + Reticle name. + + Returns + ------- + np.ndarray + Adjusted coordinates (1x3), rounded. + """ + md = model.get_reticle_metadata(reticle) + if not md: + logger.warning(f"Warning: No metadata found for reticle '{reticle}'. Returning original points.") + return np.array([global_pts[0], global_pts[1], global_pts[2]]) + reticle_rot = md.get("rot", 0) # scalar degrees flag used by caller's convention + Rm = md.get("rotmat", np.eye(3)) # (3,3) + tm = np.array([ + md.get("offset_x", 0), + md.get("offset_y", 0), + md.get("offset_z", 0) + ]) + # If metadata says a nonzero 'rot' is present, apply row-vector rotation + # using Rm.T (since local = global @ R.T + t). + if reticle_rot != 0: + global_pts = global_pts @ Rm.T + global_pts = global_pts + tm + return np.round(global_pts, 1) + + +def get_transM_bregma_to_local(model, transM: np.ndarray, reticle: str) -> np.ndarray: + """ + Build Tb (bregma→local) from stage T (global→local) and reticle (Rm, tm). + + Known: + Stage mapping (canonical column): local = R @ global + t + Reticle (canonical column): bregma = Rm @ global + tm + + Compose in row form: + global = (bregma - tm) @ Rm + local = ((bregma - tm) @ Rm) @ R.T + t + = bregma @ (Rm @ R.T) + (t - tm @ Rm @ R.T) + + Identify with local = bregma @ Rb.T + tb: + Rb.T = Rm @ R.T ⇒ Rb = R @ Rm.T + tb = t - tm @ Rm @ R.T + + Returns a 4x4 homogeneous Tb = [[Rb, tb],[0,1]] mapping BREGMA→LOCAL. + + Note + ---- + The “To use it …” example in the original snippet used a column-style multiply + to show the idea. In this module we consistently use row vectors and the + explicit row formulas in other helpers. + """ + md = model.get_reticle_metadata(reticle) + if not md: + logger.warning(f"Warning: No metadata found for reticle '{reticle}'. Returning original points.") + return None + Rm = md.get("rotmat", np.eye(3)) + tm = np.array([ + md.get("offset_x", 0.0), + md.get("offset_y", 0.0), + md.get("offset_z", 0.0) + ], dtype=float) + + # Stage T is GLOBAL→LOCAL in the canonical column view: + # local = R @ global + t + R = transM[:3, :3] + t_row = transM[:3, 3].T # (3,) + Rb = R @ Rm.T + tb = t_row - np.dot(Rb, tm) # tb = t - Rb @ tm + + Tb = np.eye(4, dtype=float) + Tb[:3, :3] = Rb + Tb[:3, 3] = tb + return Tb + +def get_transMs_bregma_to_local(model, sn: str) -> np.ndarray: + """ + Generate per-reticle Tb (bregma→local) 4x4 matrices for a calibrated stage. + + Returns + ------- + dict[str, list] or None + Keys are reticle names, values are 4x4 matrices as nested lists + (JSON-serializable). None if the stage/transform is unavailable. + """ + if not model.is_calibrated(sn): + return None + + T = model.get_transform(sn) + if T is None: + return None + + bregma_to_local_transMs: dict[str, list] = {} + for reticle in model.reticle_metadata.keys(): + Tb = get_transM_bregma_to_local(model, T, reticle) + if Tb is not None: + bregma_to_local_transMs[reticle] = np.asarray(Tb, dtype=float).tolist() + return bregma_to_local_transMs + + +def local_to_bregma(model, sn: str, local_pts: np.ndarray, reticle: Optional[str] = None) -> Optional[np.ndarray]: + """ + Convert local (1x3 row) → bregma (1x3 row) using per-reticle Tb (bregma→local). + + For a given reticle, Tb maps BREGMA→LOCAL (canonical column). In row form, + we invert it with: + bregma = (local - tb) @ Rb + + The function retrieves Tb from the model (either a single matrix or a dict + keyed by reticle), checks its shape, and applies the row-vector inverse. + + Returns rounded (1x3) bregma coordinates or None if unavailable. + """ + calib_info = (model.stages.get(sn, {}) or {}).get("calib_info") + if calib_info is None: + logger.warning(f"Stage {sn} is not calibrated.") + return None + transMbs = getattr(calib_info, "transM_bregma", None) + if transMbs is None: + logger.warning(f"No transM_bregma on stage {sn}.") + return None + if isinstance(transMbs, dict): + if reticle is None: + logger.warning("reticle must be provided when transM_bregma is a dict.") return None + Tb = transMbs.get(reticle) + if Tb is None: + logger.warning(f"No transM_bregma for reticle '{reticle}'.") + return None + else: + Tb = transMbs + + Tb = np.asarray(Tb, dtype=float) + if Tb.shape != (4,4): + logger.warning(f"transMb must be 4x4, got {Tb.shape}.") + return None + + # Option 1 (helper): inverse via row-form helper: global = (local - t) @ R + bregma_pts = apply_inverse_rigid_transform(Tb, local_pts) + + # Explicit row-form (documented) — kept here for clarity with the same result: + Rb = Tb[:3, :3] + tb_row = Tb[:3, 3].T # (3,) + bregma_pts = (local_pts - tb_row) @ Rb + + return np.round(bregma_pts, 1) + + - if reticle and reticle != "Global coords": - global_pts = CoordsConverter._apply_reticle_adjustments_inverse(model, global_pts, reticle) - - # Transpose the 3x3 rotation part - R_T = transM[:3, :3].T - local_pts = np.dot(R_T, global_pts - transM[:3, 3]) - logger.debug(f"global_to_local {global_pts} -> {local_pts}") - logger.debug(f"R.T: {R_T}\nT: {transM[:3, 3]}") - - return np.round(local_pts, 1) - - @staticmethod - def _apply_reticle_adjustments_inverse(model, reticle_global_pts: np.ndarray, reticle: str) -> np.ndarray: - """ - Applies the inverse of the selected reticle's adjustments (rotation and offsets) - to the given global coordinates. - - Args: - global_pts (ndarray): The global coordinates to adjust. - reticle (str): The name of the reticle to apply adjustments for. - Returns: - np.ndarray: The adjusted global coordinates. - """ - # Convert global_point to numpy array if it's not already - reticle_global_pts = np.array(reticle_global_pts) - - # Get the reticle metadata - reticle_metadata = model.get_reticle_metadata(reticle) - - if not reticle_metadata: # Prevent applying adjustments with missing metadata - logger.warning(f"Warning: No metadata found for reticle '{reticle}'. Returning original points.") - return np.array([reticle_global_pts[0], reticle_global_pts[1], reticle_global_pts[2]]) - - # Get rotation matrix (default to identity if not found) - reticle_rotmat = reticle_metadata.get("rotmat", np.eye(3)) - - # Get offset values, default to global point coordinates if not found - reticle_offset = np.array([ - reticle_metadata.get("offset_x", 0), # Default to 0 if no offset is provided - reticle_metadata.get("offset_y", 0), - reticle_metadata.get("offset_z", 0) - ]) - - # Subtract the reticle offset - global_point = reticle_global_pts - reticle_offset - - # Undo the rotation - global_point = np.dot(global_point, reticle_rotmat) - - return np.array(global_point) - - @staticmethod - def _apply_reticle_adjustments(model, global_pts: np.ndarray, reticle: str) -> np.ndarray: - """ - Applies the selected reticle's adjustments (rotation and offsets) to the given global coordinates. - Args: - global_pts (ndarray): The global coordinates to adjust. - reticle (str): The name of the reticle to apply adjustments for. - Returns: - tuple: The adjusted global coordinates (x, y, z). - """ - reticle_metadata = model.get_reticle_metadata(reticle) - - if not reticle_metadata: # Prevent applying adjustments with missing metadata - logger.warning(f"Warning: No metadata found for reticle '{reticle}'. Returning original points.") - return np.array([global_pts[0], global_pts[1], global_pts[2]]) - - reticle_rot = reticle_metadata.get("rot", 0) - reticle_rotmat = reticle_metadata.get("rotmat", np.eye(3)) # Default to identity matrix if not found - reticle_offset = np.array([ - reticle_metadata.get("offset_x", 0), - reticle_metadata.get("offset_y", 0), - reticle_metadata.get("offset_z", 0) - ]) - - if reticle_rot != 0: - # Transpose because points are row vectors - global_pts = global_pts @ reticle_rotmat.T - global_pts = global_pts + reticle_offset - - return np.round(global_pts, 1) diff --git a/parallax/utils/probe_angles.py b/parallax/utils/probe_angles.py new file mode 100644 index 00000000..3ba30dc1 --- /dev/null +++ b/parallax/utils/probe_angles.py @@ -0,0 +1,99 @@ +import logging +import numpy as np +from typing import Any, Optional, Dict +import math + +# Set logger name +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +def find_probe_angles_dict(transM_dict: dict[str, np.ndarray]) -> Optional[dict[str, dict[str, float]]]: + """ + Compute arc angles per reticle. + + Returns + ------- + dict[str, dict[str, float]] | None + {"reticleA": {"rx": , "ry": }, ...} or None if empty input. + """ + if not transM_dict: + return None + + angles_dict: dict[str, dict[str, float]] = {} + for reticle, transM in transM_dict.items(): + angles = find_probe_angle(transM) # -> {"rx":..., "ry":...} | None + if angles is not None: + angles_dict[reticle] = angles + return angles_dict or None + + +def find_probe_angle(transM: Optional[np.ndarray]) -> Optional[dict[str, float]]: + """ + transM: 4x4 transformation matrix from global or bregma to coordinates. + Depending on the context, the result is expressed in that coordinate system. + + Returns + ------- + dict[str, float] | None + {"rx": , "ry": } or None if transM is None/invalid. + """ + z_axis = _find_probe_insertion_vector(transM) + return _vector_to_arc_angles(z_axis) + +def _find_probe_insertion_vector(transM: Optional[np.ndarray]) -> Optional[np.ndarray]: + """Return the probe direction as a 3-vector (GLOBAL/BREGMA frame), or None.""" + if transM is None: + return None + + T = np.asarray(transM, dtype=float) + if T.shape != (4, 4): + logger.warning(f"transM must be 4x4, got {T.shape}.") + return None + + # Third ROW (row-vector convention) equals ez^T @ R + R = T[:3, :3] + vec = R[2, :] # shape (3,) + return vec + + +def _vector_to_arc_angles( + vec: Optional[np.ndarray], + degrees: bool = True, + invert_AP: bool = True, +) -> Optional[dict[str, float]]: + """ + Calculate arc angles for a given 3D direction vector in RAS (x=ML, y=AP, z=DV). + + Returns + ------- + dict[str, float] | None + {"rx": , "ry": } where: + - rx: rotation about x (ML), tilt in AP–DV plane [pitch-like] + - ry: rotation about y (AP), tilt in ML–DV plane [yaw-like] + Returns None if vec is None or zero. + """ + if vec is None: + return None + + v = np.asarray(vec, dtype=float) + if np.linalg.norm(v) == 0: + return None + + # Keep to upper hemisphere so |rx| <= 90° + if np.dot(v, [0.0, 0.0, 1.0]) < 0: + v = -v + + nv = v / np.linalg.norm(v) + + # From vertical: + rx = -np.arcsin(nv[1]) # depends on AP component (rotation about x) + ry = np.arctan2(nv[0], nv[2]) # ML vs DV (rotation about y) + + if degrees: + rx = math.degrees(rx) + ry = math.degrees(ry) + if invert_AP: + rx = -rx + + # JSON-friendly dict + return {"rx": float(rx), "ry": float(ry)} diff --git a/tests/test_coords_converter.py b/tests/test_coords_converter.py index f6b6ce6a..d3af315a 100644 --- a/tests/test_coords_converter.py +++ b/tests/test_coords_converter.py @@ -2,7 +2,12 @@ import numpy as np import pytest -from parallax.utils.coords_converter import CoordsConverter +from parallax.utils.coords_converter import ( + local_to_global, + global_to_local, + apply_reticle_adjustments, + apply_reticle_adjustments_inverse +) class StubModel: @@ -46,7 +51,7 @@ def test_local_to_global_identity_rounding(): T = make_T() model = StubModel(calibrated=True, transM=T) local = np.array([1.234, -5.678, 9.876]) - out = CoordsConverter.local_to_global(model, "SN", local, reticle=None) + out = local_to_global(model, "SN", local, reticle=None) # Should match input rounded to 1 decimal np.testing.assert_allclose(out, np.round(local, 1)) @@ -71,18 +76,18 @@ def test_local_to_global_with_reticle_rotation_and_offset(): base_global = local + np.array([10.0, 20.0, 30.0]) # Apply reticle: row-vector convention uses @ Rz.T; 90° CCW maps (x,y)->(x',y')=(x*0 + y*1, -x*1 + y*0) = (0, -x) # For base_global (110,20,30) -> (20, -110, 30); then add offsets (1,2,3) => (21, -108, 33) - out = CoordsConverter.local_to_global(model, "SN", local, reticle="R1") + out = local_to_global(model, "SN", local, reticle="R1") np.testing.assert_allclose(out, np.array([-19.0, 112.0, 33.0])) def test_local_to_global_not_calibrated_returns_none(): model = StubModel(calibrated=False, transM=make_T()) - assert CoordsConverter.local_to_global(model, "SN", np.array([0, 0, 0])) is None + assert local_to_global(model, "SN", np.array([0, 0, 0])) is None def test_local_to_global_missing_transform_returns_none(): model = StubModel(calibrated=True, transM=None) - assert CoordsConverter.local_to_global(model, "SN", np.array([0, 0, 0])) is None + assert local_to_global(model, "SN", np.array([0, 0, 0])) is None def test_global_to_local_inverse_no_reticle(): @@ -94,7 +99,7 @@ def test_global_to_local_inverse_no_reticle(): # Pick a local point, map forward manually, then invert via API local = np.array([10.0, 4.0, -3.0]) global_fwd = Rz @ local + t - back = CoordsConverter.global_to_local(model, "SN", global_fwd, reticle="Global coords") + back = global_to_local(model, "SN", global_fwd, reticle="Global coords") np.testing.assert_allclose(back, np.round(local, 1)) @@ -121,24 +126,24 @@ def test_global_to_local_with_inverse_reticle(): rotated = pre @ Rz.T adjusted = rotated + np.array([10.0, 0.0, -5.0]) - out = CoordsConverter.global_to_local(model, "SN", adjusted, reticle="R1") + out = global_to_local(model, "SN", adjusted, reticle="R1") np.testing.assert_allclose(out, np.round(pre, 1)) def test_global_to_local_not_calibrated_returns_none(): model = StubModel(calibrated=False, transM=make_T()) - assert CoordsConverter.global_to_local(model, "SN", np.array([0, 0, 0])) is None + assert global_to_local(model, "SN", np.array([0, 0, 0])) is None def test_apply_reticle_adjustments_missing_metadata_returns_original(): model = StubModel(calibrated=True, transM=make_T(), reticle_meta={}) original = np.array([7.0, -8.0, 9.0]) - out = CoordsConverter._apply_reticle_adjustments(model, original, "Unknown") + out = apply_reticle_adjustments(model, original, "Unknown") np.testing.assert_allclose(out, np.round(original, 1)) def test_apply_reticle_adjustments_inverse_missing_metadata_returns_original(): model = StubModel(calibrated=True, transM=make_T(), reticle_meta={}) original = np.array([1.0, 2.0, 3.0]) - out = CoordsConverter._apply_reticle_adjustments_inverse(model, original, "Unknown") + out = apply_reticle_adjustments_inverse(model, original, "Unknown") np.testing.assert_allclose(out, original) diff --git a/tests/test_coords_transformation.py b/tests/test_coords_transformation.py index 3d260196..650b0ea1 100644 --- a/tests/test_coords_transformation.py +++ b/tests/test_coords_transformation.py @@ -1,56 +1,53 @@ import numpy as np import pytest -from parallax.probe_calibration.coords_transformation import RotationTransformation +from parallax.probe_calibration.transforms import fit_params, _roll, _pitch, _yaw, _R_to_euler_zyx, _combineAngles -@pytest.fixture -def transformer(): - return RotationTransformation() -def test_roll(transformer): +def test_roll(): # Test roll rotation around the x-axis input_matrix = np.identity(3) roll_angle = np.pi / 4 # 45 degrees expected_output = np.array([[1, 0, 0], [0, np.sqrt(2) / 2, -np.sqrt(2) / 2], [0, np.sqrt(2) / 2, np.sqrt(2) / 2]]) - output = transformer.roll(input_matrix, roll_angle) + output = _roll(input_matrix, roll_angle) assert np.allclose(output, expected_output), "Roll transformation failed." -def test_pitch(transformer): +def test_pitch(): # Test pitch rotation around the y-axis input_matrix = np.identity(3) pitch_angle = np.pi / 6 # 30 degrees expected_output = np.array([[np.sqrt(3) / 2, 0, 0.5], [0, 1, 0], [-0.5, 0, np.sqrt(3) / 2]]) - output = transformer.pitch(input_matrix, pitch_angle) + output = _pitch(input_matrix, pitch_angle) assert np.allclose(output, expected_output), "Pitch transformation failed." -def test_yaw(transformer): +def test_yaw(): # Test yaw rotation around the z-axis input_matrix = np.identity(3) yaw_angle = np.pi / 3 # 60 degrees expected_output = np.array([[0.5, -np.sqrt(3) / 2, 0], [np.sqrt(3) / 2, 0.5, 0], [0, 0, 1]]) - output = transformer.yaw(input_matrix, yaw_angle) + output = _yaw(input_matrix, yaw_angle) assert np.allclose(output, expected_output), "Yaw transformation failed." -def test_extract_angles(transformer): +def test_extract_angles(): # Test extraction of roll, pitch, yaw from rotation matrix - rotation_matrix = transformer.combineAngles(np.pi / 4, np.pi / 6, np.pi / 3) - roll, pitch, yaw = transformer.extractAngles(rotation_matrix) + rotation_matrix = _combineAngles(np.pi / 4, np.pi / 6, np.pi / 3) + roll, pitch, yaw = _R_to_euler_zyx(rotation_matrix) assert np.isclose(roll, np.pi / 4), f"Expected roll to be {np.pi / 4}, got {roll}" assert np.isclose(pitch, np.pi / 6), f"Expected pitch to be {np.pi / 6}, got {pitch}" assert np.isclose(yaw, np.pi / 3), f"Expected yaw to be {np.pi / 3}, got {yaw}" -def test_fit_params(transformer): +def test_fit_params(): # Test fitting parameters for transformation measured_pts = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) global_pts = np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10], [11, 12, 13]]) - origin, rotation_matrix, avg_err = transformer.fit_params(measured_pts, global_pts) + origin, rotation_matrix, avg_err = fit_params(measured_pts, global_pts) # Expected values based on the simplified test data expected_origin = np.array([1, 1, 1])