Skip to content

Fix issue with physical units expressions #1194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ class SympySimpleExpressionPrinter(CppSimpleExpressionPrinter):
"""

def print_simple_expression(self, node: ASTSimpleExpression) -> str:

if node.is_numeric_literal():
return self._constant_printer.print_constant(node.get_numeric_literal())

Expand Down
4 changes: 2 additions & 2 deletions pynestml/generated/PyNestMLLexer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from PyNestMLLexer.g4 by ANTLR 4.13.2
# Generated from PyNestMLLexer.g4 by ANTLR 4.13.0
from antlr4 import *
from io import StringIO
import sys
Expand Down Expand Up @@ -430,7 +430,7 @@ class PyNestMLLexer(PyNestMLLexerBase):

def __init__(self, input=None, output:TextIO = sys.stdout):
super().__init__(input, output)
self.checkVersion("4.13.2")
self.checkVersion("4.13.0")
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None
self._predicates = None
Expand Down
297 changes: 150 additions & 147 deletions pynestml/generated/PyNestMLParser.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pynestml/generated/PyNestMLParserVisitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated from PyNestMLParser.g4 by ANTLR 4.13.2
# Generated from PyNestMLParser.g4 by ANTLR 4.13.0
from antlr4 import *
if "." in __name__:
from .PyNestMLParser import PyNestMLParser
Expand Down
15 changes: 12 additions & 3 deletions pynestml/grammars/PyNestMLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ parser grammar PyNestMLParser;
* @attribute isInf: True iff, this expression shall represent the value infinity.
**/
simpleExpression : functionCall
| BOOLEAN_LITERAL // true & false;
| (UNSIGNED_INTEGER | FLOAT) (variable)?
| BOOLEAN_LITERAL
| (UNSIGNED_INTEGER | FLOAT) unitType?
| string=STRING_LITERAL
| isInf=INF_KEYWORD
| variable;
Expand All @@ -120,6 +120,15 @@ parser grammar PyNestMLParser;
logicalOperator : logicalAnd=AND_KEYWORD
| logicalOr=OR_KEYWORD;

/**
**/
// physicalUnitExpression : leftParentheses=LEFT_PAREN term=physicalUnitExpression rightParentheses=RIGHT_PAREN
// | <assoc=right> left=physicalUnitExpression powOp=STAR_STAR right=physicalUnitExpression
// | left=physicalUnitExpression (timesOp=STAR | divOp=FORWARD_SLASH) right=physicalUnitExpression
// | physicalUnit;
//
// physicalUnit : name=NAME;

/**
* ASTVariable Provides a 'marker' AST node to identify variables used in expressions.
* @attribute name: The name of the variable without the differential order, e.g. V_m
Expand Down Expand Up @@ -201,7 +210,7 @@ parser grammar PyNestMLParser;

ifStmt : ifClause elifClause* (elseClause)?;

ifClause : IF_KEYWORD expression COLON
ifClause : IF_KEYWORD expression COLON
NEWLINE INDENT stmtsBody DEDENT;

elifClause : ELIF_KEYWORD expression COLON
Expand Down
11 changes: 6 additions & 5 deletions pynestml/meta_model/ast_node_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,16 @@ def create_ast_return_stmt(cls, expression=None, source_position=None):
return ASTReturnStmt(expression, source_position=source_position)

@classmethod
def create_ast_simple_expression(cls, function_call=None, # type: Union(ASTFunctionCall,None)
boolean_literal=None, # type: Union(bool,None)
numeric_literal=None, # type: Union(float,int)
def create_ast_simple_expression(cls, function_call=None, # type: Union[ASTFunctionCall, None]
boolean_literal=None, # type: Union[bool, None]
numeric_literal=None, # type: Union[float, int]
is_inf=False, # type: bool
unitType=None,
variable=None, # type: ASTVariable
string=None, # type: Union(str,None)
string=None, # type: Union[str, None]
source_position=None # type: ASTSourceLocation
): # type: (...) -> ASTSimpleExpression
return ASTSimpleExpression(function_call, boolean_literal, numeric_literal, is_inf, variable, string,
return ASTSimpleExpression(function_call, boolean_literal, numeric_literal, is_inf, variable, string, unitType,
source_position=source_position)

@classmethod
Expand Down
25 changes: 22 additions & 3 deletions pynestml/meta_model/ast_simple_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from typing import List, Optional, Union

from pynestml.meta_model.ast_unit_type import ASTUnitType

from pynestml.meta_model.ast_expression_node import ASTExpressionNode
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_node import ASTNode
Expand Down Expand Up @@ -52,7 +54,8 @@ class ASTSimpleExpression(ASTExpressionNode):

def __init__(self, function_call: ASTFunctionCall = None, boolean_literal: bool = None,
numeric_literal: Union[int, float] = None, is_inf: bool = False,
variable: ASTVariable = None, string: str = None, has_delay: bool = False, *args, **kwargs):
variable: ASTVariable = None, string: str = None, unitType: ASTUnitType = None,
has_delay: bool = False, *args, **kwargs):
"""
Standard constructor.

Expand Down Expand Up @@ -91,6 +94,7 @@ def __init__(self, function_call: ASTFunctionCall = None, boolean_literal: bool
self.is_inf_literal = is_inf
self.variable = variable
self.string = string
self.unitType = unitType
self.has_delay = has_delay

def clone(self):
Expand All @@ -112,15 +116,15 @@ def clone(self):
boolean_literal = True
if self.is_boolean_false:
boolean_literal = False
assert function_call_dup or (boolean_literal is not None) or (
numeric_literal_dup is not None) or self.is_inf_literal or variable_dup or self.string
assert function_call_dup or (boolean_literal is not None) or (numeric_literal_dup is not None) or self.is_inf_literal or variable_dup or self.string
dup = ASTSimpleExpression(function_call=function_call_dup,
boolean_literal=boolean_literal,
numeric_literal=numeric_literal_dup,
is_inf=self.is_inf_literal,
variable=variable_dup,
string=self.string,
has_delay=self.has_delay,
unitType=self.unitType,
# ASTNode common attributes:
source_position=self.source_position,
scope=self.scope,
Expand Down Expand Up @@ -273,6 +277,18 @@ def get_string(self):
"""
return self.string

def get_unitType(self):
"""
Returns the unitType of the simple expression.
"""
return self.unitType

def set_unitType(self, unitType):
"""
Sets the unitType of the simple expression.
"""
self.unitType = unitType

def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
Expand Down Expand Up @@ -341,4 +357,7 @@ def equals(self, other: ASTNode) -> bool:
if self.get_string() != other.get_string():
return False

if self.get_unitType() != other.get_unitType():
return False

return True
7 changes: 7 additions & 0 deletions pynestml/visitors/ast_builder_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.generated.PyNestMLParser import PyNestMLParser
from pynestml.generated.PyNestMLParserVisitor import PyNestMLParserVisitor
from pynestml.meta_model.ast_expression import ASTExpression
from pynestml.meta_model.ast_node_factory import ASTNodeFactory
Expand All @@ -34,6 +35,7 @@
from pynestml.utils.logger import Logger
from pynestml.utils.port_signal_type import PortSignalType
from pynestml.visitors.ast_data_type_visitor import ASTDataTypeVisitor
from pynestml.visitors.ast_unit_type_visitor import ASTUnitTypeVisitor
from pynestml.visitors.comment_collector_visitor import CommentCollectorVisitor


Expand Down Expand Up @@ -199,11 +201,16 @@ def visitSimpleExpression(self, ctx):
else:
numeric_literal = None
is_inf = (True if ctx.isInf is not None else False)
unitType = self.visit(ctx.unitType()) if ctx.unitType() is not None else None
if unitType is not None:
unitType.accept(ASTUnitTypeVisitor())

variable = (self.visit(ctx.variable()) if ctx.variable() is not None else None)
string = (str(ctx.string.text) if ctx.string is not None else None)
node = ASTNodeFactory.create_ast_simple_expression(function_call=function_call,
boolean_literal=boolean_literal,
numeric_literal=numeric_literal,
unitType=unitType,
is_inf=is_inf, variable=variable,
string=string,
source_position=create_source_pos(ctx))
Expand Down
93 changes: 6 additions & 87 deletions pynestml/visitors/ast_data_type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,11 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from astropy import units

from pynestml.meta_model.ast_unit_type import ASTUnitType
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.unit_type_symbol import UnitTypeSymbol
from pynestml.utils.logger import Logger
from pynestml.utils.logger import LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.utils.unit_type import UnitType
from pynestml.visitors.ast_unit_type_visitor import ASTUnitTypeVisitor
from pynestml.visitors.ast_visitor import ASTVisitor


Expand Down Expand Up @@ -67,6 +61,11 @@ def visit_data_type(self, node):
elif node.is_void:
self.symbol = PredefinedTypes.get_void_type()
node.set_type_symbol(self.symbol)
elif node.is_unit_type:
unit_type_visitor = ASTUnitTypeVisitor()
node.get_unit_type().accept(unit_type_visitor)
self.symbol = unit_type_visitor.symbol
node.set_type_symbol(self.symbol)

def endvisit_data_type(self, node):
if node.is_unit_type() and node.get_unit_type().get_type_symbol() is not None:
Expand All @@ -77,83 +76,3 @@ def endvisit_data_type(self, node):
code, message = Messages.astdatatype_type_symbol_could_not_be_derived()
Logger.log_message(None, code, message, node.get_source_position(), LoggingLevel.ERROR)
return

def visit_unit_type(self, node):
"""
Visits a single unit type element, checks for correct usage of units and builds the corresponding combined
unit.
:param node: a single unit type meta_model.
:type node: ASTUnitType
:return: a new type symbol representing this unit type.
:rtype: type_symbol
"""
if node.is_simple_unit():
type_s = PredefinedTypes.get_type(node.unit)
if type_s is None:
code, message = Messages.unknown_type(str(node.unit))
Logger.log_message(None, code, message, node.get_source_position(), LoggingLevel.ERROR)
return

node.set_type_symbol(type_s)
self.symbol = type_s

def endvisit_unit_type(self, node):
if node.is_encapsulated:
node.set_type_symbol(node.compound_unit.get_type_symbol())
elif node.is_pow:
base_symbol = node.base.get_type_symbol()
exponent = node.exponent
astropy_unit = base_symbol.astropy_unit ** exponent
res = handle_unit(astropy_unit)
node.set_type_symbol(res)
self.symbol = res
elif node.is_div:
if isinstance(node.get_lhs(), ASTUnitType): # regard that lhs can be a numeric or a unit-type
lhs = node.get_lhs().get_type_symbol().astropy_unit
else:
lhs = node.get_lhs()
rhs = node.get_rhs().get_type_symbol().astropy_unit
res = lhs / rhs
res = handle_unit(res)
node.set_type_symbol(res)
self.symbol = res
elif node.is_times:
if isinstance(node.get_lhs(), ASTUnitType): # regard that lhs can be a numeric or a unit-type
if node.get_lhs().get_type_symbol() is None or isinstance(node.get_lhs().get_type_symbol(), ErrorTypeSymbol):
node.set_type_symbol(ErrorTypeSymbol())
return
lhs = node.get_lhs().get_type_symbol().astropy_unit
else:
lhs = node.get_lhs()
rhs = node.get_rhs().get_type_symbol().astropy_unit
res = lhs * rhs
res = handle_unit(res)
node.set_type_symbol(res)
self.symbol = res
return


def handle_unit(unit_type):
"""
Handles a handed over unit by creating the corresponding unit-type, storing it in the list of predefined
units, creating a type symbol and returning it.
:param unit_type: astropy unit object
:type unit_type: astropy.units.core.Unit
:return: a new type symbol
:rtype: TypeSymbol
"""
# first ensure that it does not already exists, if not create it and register it in the set of predefined units
# first clean up the unit of not required components, here it is the 1.0 in front of the unit
# e.g., 1.0 * 1 / ms. This step is not mandatory for correctness, but makes reporting easier
if isinstance(unit_type, units.Quantity) and unit_type.value == 1.0:
to_process = unit_type.unit
else:
to_process = unit_type
if str(to_process) not in PredefinedUnits.get_units().keys():
unit_type_t = UnitType(name=str(to_process), unit=to_process)
PredefinedUnits.register_unit(unit_type_t)
# now create the corresponding type symbol if it does not exists
if PredefinedTypes.get_type(str(to_process)) is None:
type_symbol = UnitTypeSymbol(unit=PredefinedUnits.get_unit(str(to_process)))
PredefinedTypes.register_type(type_symbol)
return PredefinedTypes.get_type(name=str(to_process))
13 changes: 8 additions & 5 deletions pynestml/visitors/ast_numeric_literal_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,15 @@ def visit_simple_expression(self, node):
variable_symbol_resolve = scope.resolve_to_symbol(var_name, SymbolKind.VARIABLE)
if variable_symbol_resolve is not None:
node.type = variable_symbol_resolve.get_type_symbol()
node.type.referenced_object = node
return

if node.get_unitType() is not None:
type_symbol = node.get_unitType().get_type_symbol()
if type_symbol is not None:
node.type = type_symbol
else:
type_symbol_resolve = scope.resolve_to_symbol(var_name, SymbolKind.TYPE)
if type_symbol_resolve is not None:
node.type = type_symbol_resolve
else:
node.type = ErrorTypeSymbol()
node.type = ErrorTypeSymbol()
node.type.referenced_object = node
return

Expand Down
Loading
Loading