Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pouya Rostam committed Nov 16, 2023
1 parent 6a3b65b commit c54df0b
Show file tree
Hide file tree
Showing 11 changed files with 424 additions and 302 deletions.
16 changes: 5 additions & 11 deletions binding/python/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from ._util import default_library_path, default_model_path


def create(
access_key: str,
model_path: Optional[str] = None,
library_path: Optional[str] = None) -> Falcon:
def create(access_key: str, model_path: Optional[str] = None, library_path: Optional[str] = None) -> Falcon:
"""
Factory method for Falcon speaker diarization engine.
Expand All @@ -30,17 +27,14 @@ def create(
"""

if model_path is None:
model_path = default_model_path('')
model_path = default_model_path("")

if library_path is None:
library_path = default_library_path('')
library_path = default_library_path("")

return Falcon(
access_key=access_key,
model_path=model_path,
library_path=library_path)
return Falcon(access_key=access_key, model_path=model_path, library_path=library_path)


__all__ = [
'create',
"create",
]
143 changes: 89 additions & 54 deletions binding/python/_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,27 @@


class FalconError(Exception):
pass
def __init__(self, message: str = "", message_stack: Sequence[str] = None):
super().__init__(message)

self._message = message
self._message_stack = list() if message_stack is None else message_stack

def __str__(self):
message = self._message
if len(self._message_stack) > 0:
message += ":"
for i in range(len(self._message_stack)):
message += "\n [%d] %s" % (i, self._message_stack[i])
return message

@property
def message(self) -> str:
return self._message

@property
def message_stack(self) -> Sequence[str]:
return self._message_stack


class FalconMemoryError(FalconError):
Expand Down Expand Up @@ -84,7 +104,7 @@ class PicovoiceStatuses(Enum):
PicovoiceStatuses.ACTIVATION_ERROR: FalconActivationError,
PicovoiceStatuses.ACTIVATION_LIMIT_REACHED: FalconActivationLimitError,
PicovoiceStatuses.ACTIVATION_THROTTLED: FalconActivationThrottledError,
PicovoiceStatuses.ACTIVATION_REFUSED: FalconActivationRefusedError
PicovoiceStatuses.ACTIVATION_REFUSED: FalconActivationRefusedError,
}

_VALID_EXTENSIONS = (
Expand All @@ -104,16 +124,12 @@ class CFalcon(Structure):
pass

class CSegment(Structure):
_fields_ = [
("start_sec", c_float),
("end_sec", c_float),
("speaker_tag", c_int32)]

def __init__(
self,
access_key: str,
model_path: str,
library_path: str) -> None:
"""
Represents a segment with its start, end, and associated speaker tag.
"""
_fields_ = [("start_sec", c_float), ("end_sec", c_float), ("speaker_tag", c_int32)]

def __init__(self, access_key: str, model_path: str, library_path: str) -> None:
"""
Constructor.
Expand All @@ -133,18 +149,31 @@ def __init__(

library = cdll.LoadLibrary(library_path)

set_sdk_func = library.pv_set_sdk
set_sdk_func.argtypes = [c_char_p]
set_sdk_func.restype = None

set_sdk_func("python".encode("utf-8"))

self._get_error_stack_func = library.pv_get_error_stack
self._get_error_stack_func.argtypes = [POINTER(POINTER(c_char_p)), POINTER(c_int)]
self._get_error_stack_func.restype = self.PicovoiceStatuses

self._free_error_stack_func = library.pv_free_error_stack
self._free_error_stack_func.argtypes = [POINTER(c_char_p)]
self._free_error_stack_func.restype = None

init_func = library.pv_falcon_init
init_func.argtypes = [c_char_p, c_char_p, POINTER(POINTER(self.CFalcon))]
init_func.restype = self.PicovoiceStatuses

self._handle = POINTER(self.CFalcon)()

status = init_func(
access_key.encode(),
model_path.encode(),
byref(self._handle))
status = init_func(access_key.encode(), model_path.encode(), byref(self._handle))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message="Initialization failed", message_stack=self._get_error_stack()
)

self._delete_func = library.pv_falcon_delete
self._delete_func.argtypes = [POINTER(self.CFalcon)]
Expand All @@ -156,7 +185,7 @@ def __init__(
POINTER(c_short),
c_int32,
POINTER(c_int32),
POINTER(POINTER(self.CSegment))
POINTER(POINTER(self.CSegment)),
]
self._process_func.restype = self.PicovoiceStatuses

Expand All @@ -165,24 +194,22 @@ def __init__(
POINTER(self.CFalcon),
c_char_p,
POINTER(c_int32),
POINTER(POINTER(self.CSegment))
POINTER(POINTER(self.CSegment)),
]
self._process_file_func.restype = self.PicovoiceStatuses

version_func = library.pv_falcon_version
version_func.argtypes = []
version_func.restype = c_char_p
self._version = version_func().decode('utf-8')
self._version = version_func().decode("utf-8")

self._sample_rate = library.pv_sample_rate()

self._segments_delete_func = library.pv_falcon_segments_delete
self._segments_delete_func.argtypes = [
POINTER(self.CSegment)
]
self._segments_delete_func.argtypes = [POINTER(self.CSegment)]
self._segments_delete_func.restype = None

Segment = namedtuple('Segment', ['start_sec', 'end_sec', 'speaker_tag'])
Segment = namedtuple("Segment", ["start_sec", "end_sec", "speaker_tag"])

def process(self, pcm: Sequence[int]) -> Sequence[Segment]:
"""
Expand All @@ -201,20 +228,18 @@ def process(self, pcm: Sequence[int]) -> Sequence[Segment]:
num_segments = c_int32()
c_segments = POINTER(self.CSegment)()
status = self._process_func(
self._handle,
(c_short * len(pcm))(*pcm),
len(pcm),
byref(num_segments),
byref(c_segments))
self._handle, (c_short * len(pcm))(*pcm), len(pcm), byref(num_segments), byref(c_segments)
)
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status]()
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](
message="Initialization failed", message_stack=self._get_error_stack()
)

segments = list()
for i in range(num_segments.value):
word = self.Segment(
start_sec=c_segments[i].start_sec,
end_sec=c_segments[i].end_sec,
speaker_tag=c_segments[i].speaker_tag)
start_sec=c_segments[i].start_sec, end_sec=c_segments[i].end_sec, speaker_tag=c_segments[i].speaker_tag
)
segments.append(word)

self._segments_delete_func(c_segments)
Expand All @@ -236,11 +261,7 @@ def process_file(self, audio_path: str) -> Sequence[Segment]:

num_segments = c_int32()
c_segments = POINTER(self.CSegment)()
status = self._process_file_func(
self._handle,
audio_path.encode(),
byref(num_segments),
byref(c_segments))
status = self._process_file_func(self._handle, audio_path.encode(), byref(num_segments), byref(c_segments))
if status is not self.PicovoiceStatuses.SUCCESS:
if status is self.PicovoiceStatuses.INVALID_ARGUMENT:
if not audio_path.lower().endswith(self._VALID_EXTENSIONS):
Expand All @@ -252,9 +273,8 @@ def process_file(self, audio_path: str) -> Sequence[Segment]:
segments = list()
for i in range(num_segments.value):
word = self.Segment(
start_sec=c_segments[i].start_sec,
end_sec=c_segments[i].end_sec,
speaker_tag=c_segments[i].speaker_tag)
start_sec=c_segments[i].start_sec, end_sec=c_segments[i].end_sec, speaker_tag=c_segments[i].speaker_tag
)
segments.append(word)

self._segments_delete_func(c_segments)
Expand All @@ -278,19 +298,34 @@ def sample_rate(self) -> int:

return self._sample_rate

def _get_error_stack(self) -> Sequence[str]:
message_stack_ref = POINTER(c_char_p)()
message_stack_depth = c_int()
status = self._get_error_stack_func(byref(message_stack_ref), byref(message_stack_depth))
if status is not self.PicovoiceStatuses.SUCCESS:
raise self._PICOVOICE_STATUS_TO_EXCEPTION[status](message="Unable to get Porcupine error state")

message_stack = list()
for i in range(message_stack_depth.value):
message_stack.append(message_stack_ref[i].decode("utf-8"))

self._free_error_stack_func(message_stack_ref)

return message_stack


__all__ = [
'Falcon',
'FalconActivationError',
'FalconActivationLimitError',
'FalconActivationRefusedError',
'FalconActivationThrottledError',
'FalconError',
'FalconIOError',
'FalconInvalidArgumentError',
'FalconInvalidStateError',
'FalconKeyError',
'FalconMemoryError',
'FalconRuntimeError',
'FalconStopIterationError',
"Falcon",
"FalconActivationError",
"FalconActivationLimitError",
"FalconActivationRefusedError",
"FalconActivationThrottledError",
"FalconError",
"FalconIOError",
"FalconInvalidArgumentError",
"FalconInvalidStateError",
"FalconKeyError",
"FalconMemoryError",
"FalconRuntimeError",
"FalconStopIterationError",
]
67 changes: 33 additions & 34 deletions binding/python/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,65 +15,64 @@


def _is_64bit():
return '64bit' in platform.architecture()[0]
return "64bit" in platform.architecture()[0]


def _linux_machine() -> str:
machine = platform.machine()
if machine == 'x86_64':
if machine == "x86_64":
return machine
elif machine in ['aarch64', 'armv7l']:
arch_info = ('-' + machine) if _is_64bit() else ''
elif machine in ["aarch64", "armv7l"]:
arch_info = ("-" + machine) if _is_64bit() else ""
else:
raise NotImplementedError("Unsupported CPU architecture: `%s`" % machine)

cpu_info = ''
cpu_info = ""
try:
cpu_info = subprocess.check_output(['cat', '/proc/cpuinfo']).decode('utf-8')
cpu_part_list = [x for x in cpu_info.split('\n') if 'CPU part' in x]
cpu_part = cpu_part_list[0].split(' ')[-1].lower()
cpu_info = subprocess.check_output(["cat", "/proc/cpuinfo"]).decode("utf-8")
cpu_part_list = [x for x in cpu_info.split("\n") if "CPU part" in x]
cpu_part = cpu_part_list[0].split(" ")[-1].lower()
except Exception as e:
raise RuntimeError("Failed to identify the CPU with `%s`\nCPU info: `%s`" % (e, cpu_info))

if '0xd03' == cpu_part:
return 'cortex-a53' + arch_info
elif '0xd07' == cpu_part:
return 'cortex-a57' + arch_info
elif '0xd08' == cpu_part:
return 'cortex-a72' + arch_info
if "0xd03" == cpu_part:
return "cortex-a53" + arch_info
elif "0xd07" == cpu_part:
return "cortex-a57" + arch_info
elif "0xd08" == cpu_part:
return "cortex-a72" + arch_info
else:
raise NotImplementedError("Unsupported CPU: `%s`." % cpu_part)


_RASPBERRY_PI_MACHINES = {'cortex-a53', 'cortex-a72', 'cortex-a53-aarch64', 'cortex-a72-aarch64'}
_JETSON_MACHINES = {'cortex-a57-aarch64'}
_RASPBERRY_PI_MACHINES = {"cortex-a53", "cortex-a72", "cortex-a53-aarch64", "cortex-a72-aarch64"}
_JETSON_MACHINES = {"cortex-a57-aarch64"}


def default_library_path(relative: str = '') -> str:
if platform.system() == 'Darwin':
if platform.machine() == 'x86_64':
return os.path.join(os.path.dirname(__file__), relative, 'lib/mac/x86_64/libpv_falcon.dylib')
def default_library_path(relative: str = "") -> str:
if platform.system() == "Darwin":
if platform.machine() == "x86_64":
return os.path.join(os.path.dirname(__file__), relative, "lib/mac/x86_64/libpv_falcon.dylib")
elif platform.machine() == "arm64":
return os.path.join(os.path.dirname(__file__), relative, 'lib/mac/arm64/libpv_falcon.dylib')
elif platform.system() == 'Linux':
return os.path.join(os.path.dirname(__file__), relative, "lib/mac/arm64/libpv_falcon.dylib")
elif platform.system() == "Linux":
linux_machine = _linux_machine()
if linux_machine == 'x86_64':
return os.path.join(os.path.dirname(__file__), relative, 'lib/linux/x86_64/libpv_falcon.so')
if linux_machine == "x86_64":
return os.path.join(os.path.dirname(__file__), relative, "lib/linux/x86_64/libpv_falcon.so")
elif linux_machine in _JETSON_MACHINES:
return os.path.join(os.path.dirname(__file__), relative, 'lib/jetson/%s/libpv_falcon.so' % linux_machine)
return os.path.join(os.path.dirname(__file__), relative, "lib/jetson/%s/libpv_falcon.so" % linux_machine)
elif linux_machine in _RASPBERRY_PI_MACHINES:
return os.path.join(
os.path.dirname(__file__),
relative,
'lib/raspberry-pi/%s/libpv_falcon.so' % linux_machine)
elif platform.system() == 'Windows':
return os.path.join(os.path.dirname(__file__), relative, 'lib', 'windows', 'amd64', 'libpv_falcon.dll')
os.path.dirname(__file__), relative, "lib/raspberry-pi/%s/libpv_falcon.so" % linux_machine
)
elif platform.system() == "Windows":
return os.path.join(os.path.dirname(__file__), relative, "lib", "windows", "amd64", "libpv_falcon.dll")

raise NotImplementedError('Unsupported platform.')
raise NotImplementedError("Unsupported platform.")


def default_model_path(relative: str = '') -> str:
return os.path.join(os.path.dirname(__file__), relative, 'lib', 'common', 'falcon_params.pv')
def default_model_path(relative: str = "") -> str:
return os.path.join(os.path.dirname(__file__), relative, "lib", "common", "falcon_params.pv")


__all__ = ['default_library_path', 'default_model_path']
__all__ = ["default_library_path", "default_model_path"]
Loading

0 comments on commit c54df0b

Please sign in to comment.