Skip to content

Commit

Permalink
Merge pull request cython#437 from robertwb/cpdef-enums
Browse files Browse the repository at this point in the history
More complete cpdef enums.
  • Loading branch information
robertwb committed Sep 16, 2015
2 parents 0f33914 + 9a80574 commit b9f3159
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 51 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Features added
* The embedded C code comments that show the original source code
can be discarded with the new directive ``emit_code_comments=False``.

* Cpdef enums are now first-class iterable, callable types in Python.

* Posix declarations for DLL loading and stdio extensions were added.
Patch by Lars Buitinck.

Expand Down
1 change: 1 addition & 0 deletions Cython/Compiler/Code.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ class LazyUtilityCode(UtilityCodeBase):
Utility code that calls a callback with the root code writer when
available. Useful when you only have 'env' but not 'code'.
"""
__name__ = '<lazy>'

def __init__(self, callback):
self.callback = callback
Expand Down
79 changes: 45 additions & 34 deletions Cython/Compiler/ExprNodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,7 +1927,12 @@ def check_identifier_kind(self):
entry = self.entry
if entry.is_type and entry.type.is_extension_type:
self.type_entry = entry
if not (entry.is_const or entry.is_variable
if entry.is_type and entry.type.is_enum:
py_entry = Symtab.Entry(self.name, None, py_object_type)
py_entry.is_pyglobal = True
py_entry.scope = self.entry.scope
self.entry = py_entry
elif not (entry.is_const or entry.is_variable
or entry.is_builtin or entry.is_cfunction
or entry.is_cpp_class):
if self.entry.as_variable:
Expand Down Expand Up @@ -6060,7 +6065,7 @@ def infer_type(self, env):
node = self.analyse_as_cimported_attribute_node(env, target=False)
if node is not None:
return node.entry.type
node = self.analyse_as_unbound_cmethod_node(env)
node = self.analyse_as_type_attribute(env)
if node is not None:
return node.entry.type
obj_type = self.obj.infer_type(env)
Expand All @@ -6087,7 +6092,7 @@ def analyse_types(self, env, target = 0):
self.initialized_check = env.directives['initializedcheck']
node = self.analyse_as_cimported_attribute_node(env, target)
if node is None and not target:
node = self.analyse_as_unbound_cmethod_node(env)
node = self.analyse_as_type_attribute(env)
if node is None:
node = self.analyse_as_ordinary_attribute_node(env, target)
assert node is not None
Expand All @@ -6111,45 +6116,51 @@ def analyse_as_cimported_attribute_node(self, env, target):
return self.as_name_node(env, entry, target)
return None

def analyse_as_unbound_cmethod_node(self, env):
def analyse_as_type_attribute(self, env):
# Try to interpret this as a reference to an unbound
# C method of an extension type or builtin type. If successful,
# creates a corresponding NameNode and returns it, otherwise
# returns None.
if self.obj.is_string_literal:
return
type = self.obj.analyse_as_type(env)
if type and (type.is_extension_type or type.is_builtin_type or type.is_cpp_class):
entry = type.scope.lookup_here(self.attribute)
if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if type.is_builtin_type:
if not self.is_called:
# must handle this as Python object
return None
ubcm_entry = entry
else:
# Create a temporary entry describing the C method
# as an ordinary function.
if entry.func_cname and not hasattr(entry.type, 'op_arg_struct'):
cname = entry.func_cname
if entry.type.is_static_method:
ctype = entry.type
elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else:
# Fix self type.
ctype = copy.copy(entry.type)
ctype.args = ctype.args[:]
ctype.args[0] = PyrexTypes.CFuncTypeArg('self', type, 'self', None)
if type:
if type.is_extension_type or type.is_builtin_type or type.is_cpp_class:
entry = type.scope.lookup_here(self.attribute)
if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if type.is_builtin_type:
if not self.is_called:
# must handle this as Python object
return None
ubcm_entry = entry
else:
cname = "%s->%s" % (type.vtabptr_cname, entry.cname)
ctype = entry.type
ubcm_entry = Symtab.Entry(entry.name, cname, ctype)
ubcm_entry.is_cfunction = 1
ubcm_entry.func_cname = entry.func_cname
ubcm_entry.is_unbound_cmethod = 1
return self.as_name_node(env, ubcm_entry, target=False)
# Create a temporary entry describing the C method
# as an ordinary function.
if entry.func_cname and not hasattr(entry.type, 'op_arg_struct'):
cname = entry.func_cname
if entry.type.is_static_method:
ctype = entry.type
elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else:
# Fix self type.
ctype = copy.copy(entry.type)
ctype.args = ctype.args[:]
ctype.args[0] = PyrexTypes.CFuncTypeArg('self', type, 'self', None)
else:
cname = "%s->%s" % (type.vtabptr_cname, entry.cname)
ctype = entry.type
ubcm_entry = Symtab.Entry(entry.name, cname, ctype)
ubcm_entry.is_cfunction = 1
ubcm_entry.func_cname = entry.func_cname
ubcm_entry.is_unbound_cmethod = 1
return self.as_name_node(env, ubcm_entry, target=False)
elif type.is_enum:
if self.attribute in type.values:
return self.as_name_node(env, env.lookup(self.attribute), target=False)
else:
error(self.pos, "%s not a known value of %s" % (self.attribute, type))
return None

def analyse_as_type(self, env):
Expand Down
14 changes: 12 additions & 2 deletions Cython/Compiler/Nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,14 +1488,24 @@ def declare(self, env):
self.entry = env.declare_enum(self.name, self.pos,
cname = self.cname, typedef_flag = self.typedef_flag,
visibility = self.visibility, api = self.api,
create_wrapper = self.create_wrapper)
create_wrapper = self.create_wrapper and self.name is None)

def analyse_declarations(self, env):
if self.items is not None:
if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1
for item in self.items:
item.analyse_declarations(env, self.entry)
if self.name is not None:
self.entry.type.values = set(
(item.name) for item in self.items)
if self.create_wrapper and self.name is not None:
from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name,
"items": tuple(item.name for item in self.items)},
outer_module_scope=env.global_scope()))

def analyse_expressions(self, env):
return self
Expand Down Expand Up @@ -1535,7 +1545,7 @@ def analyse_declarations(self, env, enum_entry):
entry = env.declare_const(self.name, enum_entry.type,
self.value, self.pos, cname = self.cname,
visibility = enum_entry.visibility, api = enum_entry.api,
create_wrapper = enum_entry.create_wrapper)
create_wrapper = enum_entry.create_wrapper and enum_entry.name is None)
enum_entry.enum_values.append(entry)
if enum_entry.name:
enum_entry.type.values.append(entry.cname)
Expand Down
2 changes: 1 addition & 1 deletion Cython/Compiler/Optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2097,7 +2097,7 @@ def _handle_method(self, node, type_name, attr_name, function,
entry=type_entry,
type=type_entry.type),
attribute=attr_name,
is_called=True).analyse_as_unbound_cmethod_node(self.current_env())
is_called=True).analyse_as_type_attribute(self.current_env())
if method is None:
return node
args = node.args
Expand Down
26 changes: 26 additions & 0 deletions Cython/Compiler/Pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,37 @@ def use_utility_code_definitions(scope, target, seen=None):
elif entry.as_module:
use_utility_code_definitions(entry.as_module, target, seen)

def sort_utility_codes(utilcodes):
ranks = {}
def get_rank(utilcode):
if utilcode not in ranks:
ranks[utilcode] = 0 # prevent infinite recursion on circular dependencies
original_order = len(ranks)
ranks[utilcode] = 1 + min([get_rank(dep) for dep in utilcode.requires or ()] or [-1]) + original_order * 1e-8
return ranks[utilcode]
for utilcode in utilcodes:
get_rank(utilcode)
return [utilcode for utilcode, _ in sorted(ranks.items(), key=lambda kv: kv[1])]

def normalize_deps(utilcodes):
deps = {}
for utilcode in utilcodes:
deps[utilcode] = utilcode
def unify_dep(dep):
if dep in deps:
return deps[dep]
else:
deps[dep] = dep
return dep
for utilcode in utilcodes:
utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]

def inject_utility_code_stage_factory(context):
def inject_utility_code_stage(module_node):
module_node.prepare_utility_code()
use_utility_code_definitions(context.cython_scope, module_node.scope)
module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
normalize_deps(module_node.scope.utility_code_list)
added = []
# Note: the list might be extended inside the loop (if some utility code
# pulls in other utility code, explicitly or implicitly)
Expand Down
13 changes: 13 additions & 0 deletions Cython/Compiler/UtilityCode.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ def scope_transform(module_node):
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform)

for dep in self.requires:
if (isinstance(dep, CythonUtilityCode)
and hasattr(dep, 'tree')
and not cython_scope):
def scope_transform(module_node):
module_node.scope.merge_in(dep.tree.scope)
return module_node

transform = ParseTreeTransforms.AnalyseDeclarationsTransform
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform)

if self.outer_module_scope:
# inject outer module between utility code module and builtin module
def scope_transform(module_node):
Expand All @@ -158,6 +170,7 @@ def scope_transform(module_node):

(err, tree) = Pipeline.run_pipeline(pipeline, tree, printtree=False)
assert not err, err
self.tree = tree
return tree

def put_code(self, output):
Expand Down
63 changes: 63 additions & 0 deletions Cython/Utility/CpdefEnums.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#################### EnumBase ####################

cimport cython

cdef extern from *:
int PY_VERSION_HEX

cdef object __Pyx_OrderedDict
if PY_VERSION_HEX >= 0x02070000:
from collections import OrderedDict as __Pyx_OrderedDict
else:
__Pyx_OrderedDict = dict

@cython.internal
cdef class __Pyx_EnumMeta(type):
def __init__(cls, name, parents, dct):
type.__init__(cls, name, parents, dct)
cls.__members__ = __Pyx_OrderedDict()
def __iter__(cls):
return iter(cls.__members__.values())
def __getitem__(cls, name):
return cls.__members__[name]

# @cython.internal
cdef type __Pyx_EnumBase
class __Pyx_EnumBase(int):
__metaclass__ = __Pyx_EnumMeta
def __new__(cls, value, name=None):
for v in cls:
if v == value:
return v
if name is None:
raise ValueError("Unknown enum value: '%s'" % value)
res = int.__new__(cls, value)
res.name = name
setattr(cls, name, res)
cls.__members__[name] = res
return res
def __repr__(self):
return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self)
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)

#################### EnumType ####################
#@requires: EnumBase

cdef dict __Pyx_globals = globals()
if PY_VERSION_HEX >= 0x03040000:
from enum import IntEnum
{{name}} = IntEnum('{{name}}', __Pyx_OrderedDict([
{{for item in items}}
('{{item}}', {{item}}),
{{endfor}}
]))
{{for item in items}}
__Pyx_globals['{{item}}'] = {{name}}.{{item}}
{{endfor}}
else:
class {{name}}(__Pyx_EnumBase):
pass
{{for item in items}}
__Pyx_globals['{{item}}'] = {{name}}({{item}}, '{{item}}')
{{endfor}}
17 changes: 3 additions & 14 deletions docs/src/reference/language_basics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ the same effect as the C directive ``#pragma pack(1)``.
cheddar, edam,
camembert

cdef enum CheeseState:
Declaring an enum as ```cpdef`` will create a PEP 435-style Python wrapper::

cpdef enum CheeseState:
hard = 1
soft = 2
runny = 3
Expand Down Expand Up @@ -803,16 +805,3 @@ Conditional Statements


.. [#] The conversion is to/from str for Python 2.x, and bytes for Python 3.x.
31 changes: 31 additions & 0 deletions tests/run/cpdef_enums.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ True
>>> RANK_3 # doctest: +ELLIPSIS
Traceback (most recent call last):
NameError: ...name 'RANK_3' is not defined
>>> set(PyxEnum) == set([TWO, THREE, FIVE])
True
>>> str(PyxEnum.TWO)
'PyxEnum.TWO'
>>> PyxEnum.TWO + PyxEnum.THREE == PyxEnum.FIVE
True
>>> PyxEnum(2) is PyxEnum["TWO"] is PyxEnum.TWO
True
"""


Expand All @@ -51,3 +60,25 @@ cpdef enum PyxEnum:

cdef enum SecretPyxEnum:
SEVEN = 7

def test_as_variable_from_cython():
"""
>>> test_as_variable_from_cython()
"""
import sys
if sys.version_info >= (2, 7):
assert list(PyxEnum) == [TWO, THREE, FIVE], list(PyxEnum)
assert list(PxdEnum) == [RANK_0, RANK_1, RANK_2], list(PxdEnum)
else:
# No OrderedDict.
assert set(PyxEnum) == {TWO, THREE, FIVE}, list(PyxEnum)
assert set(PxdEnum) == {RANK_0, RANK_1, RANK_2}, list(PxdEnum)

cdef int verify_pure_c() nogil:
cdef int x = TWO
cdef int y = PyxEnum.THREE
cdef int z = SecretPyxEnum.SEVEN
return x + y + z

# Use it to suppress warning.
verify_pure_c()

0 comments on commit b9f3159

Please sign in to comment.