|
9 | 9 |
|
10 | 10 | from typing import TYPE_CHECKING
|
11 | 11 |
|
12 |
| -from omc3.optics_measurements.constants import DISPERSION |
| 12 | +import numpy as np |
| 13 | +from pandas import Series |
| 14 | + |
| 15 | +from omc3.optics_measurements.constants import DISPERSION, MOMENTUM_DISPERSION |
13 | 16 | from omc3.segment_by_segment import math as sbs_math
|
14 | 17 | from omc3.segment_by_segment.propagables.abstract import Propagable
|
15 | 18 | from omc3.segment_by_segment.propagables.phase import Phase
|
@@ -119,3 +122,97 @@ def _compute_elements(self, plane: str, seg_model: pd.DataFrame, forward: bool):
|
119 | 122 | model_phase = Phase.get_segment_phase(seg_model, plane, forward)
|
120 | 123 | propagated_err = sbs_math.propagate_error_dispersion(model_disp, model_phase, init_condition)
|
121 | 124 | return model_disp, propagated_err
|
| 125 | + |
| 126 | + |
| 127 | +class MomentumDispersion(Propagable): |
| 128 | + |
| 129 | + _init_pattern = "dp{}_{}" # format(plane, ini/end) |
| 130 | + columns: PropagableColumns = PropagableColumns(MOMENTUM_DISPERSION) |
| 131 | + |
| 132 | + @classmethod |
| 133 | + def get_at(cls, names: IndexType, meas: OpticsMeasurement, plane: str) -> ValueErrorType: |
| 134 | + c = cls.columns.planed(plane) |
| 135 | + momentum_dispersion = meas.dispersion[plane].loc[names, c.column] |
| 136 | + model_momentum_dispersion_err = np.nan # No error in the measured dpx,dpy |
| 137 | + return momentum_dispersion, model_momentum_dispersion_err |
| 138 | + |
| 139 | + @classmethod |
| 140 | + def in_measurement(cls, meas: OpticsMeasurement) -> bool: |
| 141 | + """ Check if the dispersion is in the measurement data. """ |
| 142 | + try: |
| 143 | + meas.dispersion_x |
| 144 | + meas.dispersion_y |
| 145 | + except FileNotFoundError: |
| 146 | + return False |
| 147 | + return True |
| 148 | + |
| 149 | + def init_conditions_dict(self): |
| 150 | + # needs to be inverted for backward propagation, i.e. the end-init |
| 151 | + init_cond = super().init_conditions_dict() |
| 152 | + for key, value in init_cond.items(): |
| 153 | + if "end" in key: |
| 154 | + init_cond[key] = -value |
| 155 | + return init_cond |
| 156 | + |
| 157 | + def add_differences(self, segment_diffs: SegmentDiffs): |
| 158 | + dfs = self.get_difference_dataframes() |
| 159 | + for plane, df in dfs.items(): |
| 160 | + # save to diffs/write to file (if allow_write is set) |
| 161 | + segment_diffs.momentum_dispersion[plane] = df |
| 162 | + |
| 163 | + def _compute_measured(self, |
| 164 | + plane: str, |
| 165 | + seg_model: TfsDataFrame, |
| 166 | + forward: bool |
| 167 | + ) -> tuple[pd.Series, pd.Series]: |
| 168 | + """ Compute the momentum dispersion difference between the given segment model and the measured values.""" |
| 169 | + |
| 170 | + # get the measured values |
| 171 | + names = self.get_segment_observation_points(plane) |
| 172 | + momentum_dispersion, err_pdisp = self.get_at(names, self._meas, plane) |
| 173 | + |
| 174 | + # get the propagated values |
| 175 | + model_momentum_dispersion = seg_model.loc[names, f"{MOMENTUM_DISPERSION}{plane}"] |
| 176 | + |
| 177 | + if not forward: |
| 178 | + model_momentum_dispersion = -model_momentum_dispersion |
| 179 | + |
| 180 | + # calculate difference |
| 181 | + momentum_dispersion_diff = momentum_dispersion - model_momentum_dispersion |
| 182 | + |
| 183 | + propagated_err = Series(np.nan, index=model_momentum_dispersion.index) # No error propagation for now |
| 184 | + return momentum_dispersion_diff, propagated_err |
| 185 | + |
| 186 | + def _compute_correction( |
| 187 | + self, |
| 188 | + plane: str, |
| 189 | + seg_model: pd.DataFrame, |
| 190 | + seg_model_corr: pd.DataFrame, |
| 191 | + forward: bool, |
| 192 | + ) -> tuple[pd.Series, pd.Series]: |
| 193 | + """Compute the momentum dispersion differennce between the nominal and the corrected model.""" |
| 194 | + |
| 195 | + model_momentum_dispersion = seg_model.loc[:, f"{MOMENTUM_DISPERSION}{plane}"] |
| 196 | + corrected_momentum_dispersion = seg_model_corr.loc[:, f"{MOMENTUM_DISPERSION}{plane}"] |
| 197 | + |
| 198 | + momentum_dispersion_diff = corrected_momentum_dispersion - model_momentum_dispersion |
| 199 | + if not forward: |
| 200 | + momentum_dispersion_diff = -momentum_dispersion_diff |
| 201 | + |
| 202 | + propagated_err = Series(np.nan, index=model_momentum_dispersion.index) # No error propagation for now |
| 203 | + return momentum_dispersion_diff, propagated_err |
| 204 | + |
| 205 | + def _compute_elements(self, plane: str, seg_model: pd.DataFrame, forward: bool): |
| 206 | + """ Compute get the propagated momentum dispersion values from the segment model and calculate the propagated error. """ |
| 207 | + |
| 208 | + model_momentum_dispersion = seg_model.loc[:, f"{MOMENTUM_DISPERSION}{plane}"] |
| 209 | + |
| 210 | + pderror = Series(np.nan, index=model_momentum_dispersion.index) # No error propagation for now |
| 211 | + return model_momentum_dispersion, pderror |
| 212 | + |
| 213 | + def get_segment_observation_points(self, plane: str): |
| 214 | + """ Return the measurement points for the given plane, that are in the segment. """ |
| 215 | + return common_indices( |
| 216 | + self.segment_models.forward.index, |
| 217 | + self._meas.dispersion[plane].index |
| 218 | + ) |
0 commit comments