diff --git a/trnsysGUI/MassFlowVisualizer.py b/trnsysGUI/MassFlowVisualizer.py index 2ed59f2a..9f292d9e 100644 --- a/trnsysGUI/MassFlowVisualizer.py +++ b/trnsysGUI/MassFlowVisualizer.py @@ -2,12 +2,13 @@ # type: ignore import datetime as _dt -import itertools as _it import PyQt5.QtCore as _qtc import PyQt5.QtWidgets as _qtw +import itertools as _it import numpy as _np import pandas as _pd +import typing as _tp import trnsysGUI.TVentil as _tv import trnsysGUI.connection.connectionBase as _cb @@ -17,6 +18,8 @@ import trnsysGUI.massFlowSolver.names as _mnames import trnsysGUI.massFlowSolver.networkModel as _mfn +_MAX_HEADER_LENGTH = 25 + class MassFlowVisualizer(_qtw.QDialog): def __init__(self, parent, mfrFile, tempFile): @@ -167,11 +170,13 @@ def showMassBtn(self): def loadFile(self): if not self.loadedFile: self.massFlowData = _pd.read_csv(self.dataFilePath, sep="\t").rename(columns=lambda x: x.strip()) + _truncateColumnNames(self.massFlowData) self.loadedFile = True def loadTempFile(self): if not self.tempLoadedFile: self.tempMassFlowData = _pd.read_csv(self.tempDatafilePath, sep="\t").rename(columns=lambda x: x.strip()) + _truncateColumnNames(self.tempMassFlowData) self.tempLoadedFile = True def start(self): @@ -256,11 +261,13 @@ def advance(self): return def _getMassFlow(self, mfrVariableName: str, timeStep: int) -> float: - mass = self.massFlowData[mfrVariableName[:25]].iloc[timeStep] + truncatedMfrVariableName = _truncateName(mfrVariableName) + mass = self.massFlowData[truncatedMfrVariableName].iloc[timeStep] return mass def _getTemperature(self, temperatureVariableName: str, timeStep: int) -> float: - return self.tempMassFlowData[temperatureVariableName[:25]].iloc[timeStep] + truncatedTemperatureVariableName = temperatureVariableName + return self.tempMassFlowData[truncatedTemperatureVariableName].iloc[timeStep] def pauseVis(self): self.paused = True @@ -460,3 +467,29 @@ def keyPressEvent(self, e): elif e.key() == _qtc.Qt.Key_Down: self.logger.debug("Down is pressed") self.decreaseValue() + + +def _truncateColumnNames(df: _pd.DataFrame) -> None: + _ensureNamesDontCollideAfterTruncating(df.columns) + df.columns = [_truncateName(n) for n in df.columns] + + +def _ensureNamesDontCollideAfterTruncating(columnNames: _tp.Sequence[str]) -> None: + sortedColumnNames = sorted(columnNames) + groupedNames = [list(g) for _, g in _it.groupby(sortedColumnNames, key=_truncateName)] + collidingNames = _flatten(g for g in groupedNames if len(g) > 1) + if collidingNames: + formattedCollidingNames = "\n\t".join(collidingNames) + errorMessage = ( + f"The following column names collide after truncating them to " + f"{_MAX_HEADER_LENGTH} characters:\n\t{formattedCollidingNames}" + ) + raise ValueError(errorMessage) + + +def _truncateName(name: str): + return name[:_MAX_HEADER_LENGTH] + + +def _flatten(iterable: _tp.Iterable[_tp.Iterable[str]]) -> _tp.Sequence[str]: + return list(_it.chain.from_iterable(iterable))