Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast arrays and hstore columns using Python instead of Postgres db #72

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import setup

setup(name='tap-postgres',
version='0.0.65',
version='0.0.66',
description='Singer.io tap for extracting data from PostgreSQL',
author='Stitch',
url='https://singer.io',
Expand Down
136 changes: 45 additions & 91 deletions tap_postgres/sync_strategies/logical_replication.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
#!/usr/bin/env python3
# pylint: disable=missing-docstring,not-an-iterable,too-many-locals,too-many-arguments,invalid-name,too-many-return-statements,too-many-branches,len-as-condition,too-many-nested-blocks,wrong-import-order,duplicate-code, anomalous-backslash-in-string, too-many-statements, singleton-comparison, consider-using-in

import singer
from functools import reduce
from select import select
import copy
import csv
import datetime
import decimal
import json
import re

from dateutil.parser import parse
import psycopg2
import singer
from singer import utils, get_bookmark
import singer.metadata as metadata
import tap_postgres.db as post_db
import tap_postgres.sync_strategies.common as sync_common
from dateutil.parser import parse
import psycopg2
from psycopg2 import sql
import copy
from select import select
from functools import reduce
import json
import re


LOGGER = singer.get_logger()

Expand Down Expand Up @@ -65,81 +67,29 @@ def get_stream_version(tap_stream_id, state):

return stream_version

def tuples_to_map(accum, t):
accum[t[0]] = t[1]
return accum

def create_hstore_elem_query(elem):
return sql.SQL("SELECT hstore_to_array({})").format(sql.Literal(elem))

def create_hstore_elem(conn_info, elem):
with post_db.open_connection(conn_info) as conn:
with conn.cursor() as cur:
query = create_hstore_elem_query(elem)
cur.execute(query)
res = cur.fetchone()[0]
hstore_elem = reduce(tuples_to_map, [res[i:i + 2] for i in range(0, len(res), 2)], {})
return hstore_elem

def create_array_elem(elem, sql_datatype, conn_info):
def create_hstore_elem(elem):
array = [(item.replace('"', '').split('=>')) for item in elem]
hstore = {}
for item in array:
if len(item) == 2:
key, value = item
if key in hstore:
raise KeyError('Duplicate key {} found when creating hstore'.format(key))
if value.lower() == 'null':
value = None
d[key] = value

return hstore

def create_array_elem(elem):
if elem is None:
return None

with post_db.open_connection(conn_info) as conn:
with conn.cursor() as cur:
if sql_datatype == 'bit[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'boolean[]':
cast_datatype = 'boolean[]'
elif sql_datatype == 'character varying[]':
cast_datatype = 'character varying[]'
elif sql_datatype == 'cidr[]':
cast_datatype = 'cidr[]'
elif sql_datatype == 'citext[]':
cast_datatype = 'text[]'
elif sql_datatype == 'date[]':
cast_datatype = 'text[]'
elif sql_datatype == 'double precision[]':
cast_datatype = 'double precision[]'
elif sql_datatype == 'hstore[]':
cast_datatype = 'text[]'
elif sql_datatype == 'integer[]':
cast_datatype = 'integer[]'
elif sql_datatype == 'bigint[]':
cast_datatype = 'bigint[]'
elif sql_datatype == 'inet[]':
cast_datatype = 'inet[]'
elif sql_datatype == 'json[]':
cast_datatype = 'text[]'
elif sql_datatype == 'jsonb[]':
cast_datatype = 'text[]'
elif sql_datatype == 'macaddr[]':
cast_datatype = 'macaddr[]'
elif sql_datatype == 'money[]':
cast_datatype = 'text[]'
elif sql_datatype == 'numeric[]':
cast_datatype = 'text[]'
elif sql_datatype == 'real[]':
cast_datatype = 'real[]'
elif sql_datatype == 'smallint[]':
cast_datatype = 'smallint[]'
elif sql_datatype == 'text[]':
cast_datatype = 'text[]'
elif sql_datatype in ('time without time zone[]', 'time with time zone[]'):
cast_datatype = 'text[]'
elif sql_datatype in ('timestamp with time zone[]', 'timestamp without time zone[]'):
cast_datatype = 'text[]'
elif sql_datatype == 'uuid[]':
cast_datatype = 'text[]'

else:
#custom datatypes like enums
cast_datatype = 'text[]'

sql_stmt = """SELECT $stitch_quote${}$stitch_quote$::{}""".format(elem, cast_datatype)
cur.execute(sql_stmt)
res = cur.fetchone()[0]
return res
elem = [elem[1:-1]]
reader = csv.reader(elem, delimiter=',', escapechar='\\' , quotechar='"')
array = next(reader)
array = [None if element.lower() == 'null' else element for element in array]
return array

#pylint: disable=too-many-branches,too-many-nested-blocks
def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
Expand All @@ -166,17 +116,21 @@ def selected_value_to_singer_value_impl(elem, og_sql_datatype, conn_info):
#for ordinary bits, elem will == '1'
return elem == '1' or elem == True
if sql_datatype == 'boolean':
return elem
return bool(elem)
if sql_datatype == 'hstore':
return create_hstore_elem(conn_info, elem)
return create_hstore_elem(elem)
if 'numeric' in sql_datatype:
return decimal.Decimal(str(elem))
if isinstance(elem, int):
return elem
if isinstance(elem, float):
return elem
if isinstance(elem, str):
return elem
return decimal.Decimal(elem)
if sql_datatype == 'money':
return decimal.Decimal(elem[1:])
if sql_datatype in ('integer', 'smallint', 'bigint'):
return int(elem)
if sql_datatype in ('double precision', 'real', 'float'):
return float(elem)
if sql_datatype in ('text', 'character varying'):
return elem # return as string
if sql_datatype in ('cidr', 'citext', 'json', 'jsonb', 'inet', 'macaddr', 'uuid'):
return elem # return as string

raise Exception("do not know how to marshall value of type {}".format(elem.__class__))

Expand All @@ -189,7 +143,7 @@ def selected_array_to_singer_value(elem, sql_datatype, conn_info):
def selected_value_to_singer_value(elem, sql_datatype, conn_info):
#are we dealing with an array?
if sql_datatype.find('[]') > 0:
cleaned_elem = create_array_elem(elem, sql_datatype, conn_info)
cleaned_elem = create_array_elem(elem)
return list(map(lambda elem: selected_array_to_singer_value(elem, sql_datatype, conn_info), (cleaned_elem or [])))

return selected_value_to_singer_value_impl(elem, sql_datatype, conn_info)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_logical_replication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from decimal import Decimal
import unittest
from unittest.mock import patch

from utils import get_test_connection_config
from tap_postgres.sync_strategies import logical_replication


class TestHandlingArrays(unittest.TestCase):
def setUp(self):
self.env = patch.dict(
'os.environ', {
'TAP_POSTGRES_HOST':'test',
'TAP_POSTGRES_USER':'test',
'TAP_POSTGRES_PASSWORD':'test',
'TAP_POSTGRES_PORT':'5432'
},
)

self.arrays = [
'{10,01,NULL}',
'{t,f,NULL}',
'{127.0.0.1/32,10.0.0.0/32,NULL}',
'{CASE_INSENSITIVE,case_insensitive,NULL,"CASE,,INSENSITIVE"}',
'{2000-12-31,2001-01-01,NULL}',
'{3.14159265359,3.1415926,NULL}',
'{"\\"foo\\"=>\\"bar\\"","\\"baz\\"=>NULL",NULL}',
'{1,2,NULL}',
'{9223372036854775807,NULL}',
'{198.24.10.0/24,NULL}',
'{"{\\"foo\\":\\"bar\\"}",NULL}',
'{"{\\"foo\\": \\"bar\\"}",NULL}',
'{08:00:2b:01:02:03,NULL}',
'{$19.99,NULL}',
'{19.9999999,NULL}',
'{3.14159,NULL}',
'{0,1,NULL}',
'{foo,bar,NULL,"foo,bar","diederik\'s motel "}',
'{16:38:47,NULL}',
'{"2019-11-19 11:38:47-05",NULL}',
'{123e4567-e89b-12d3-a456-426655440000,NULL}'
]

self.sql_datatypes = {
'bit[]': bool,
'boolean[]': bool,
'cidr[]': str,
'citext[]': str,
'date[]': str,
'double precision[]': float,
'hstore[]': dict,
'integer[]': int,
'bigint[]': int,
'inet[]': str,
'json[]': str,
'jsonb[]': str,
'macaddr[]': str,
'money[]': Decimal,
'numeric[]': Decimal,
'real[]': float,
'smallint[]': int,
'text[]': str,
'time with time zone[]': str,
'timestamp with time zone[]': str,
'uuid[]': str,
}

def test_create_array_elem(self):
expected_arrays = [
['10', '01' ,None],
['t', 'f', None],
['127.0.0.1/32', '10.0.0.0/32', None],
['CASE_INSENSITIVE', 'case_insensitive', None,"CASE,,INSENSITIVE"],
['2000-12-31', '2001-01-01', None],
['3.14159265359','3.1415926', None],
['"foo"=>"bar"', '"baz"=>NULL', None],
['1','2',None],
['9223372036854775807', None],
['198.24.10.0/24', None],
["{\"foo\":\"bar\"}", None],
["{\"foo\": \"bar\"}", None],
['08:00:2b:01:02:03', None],
['$19.99', None],
['19.9999999', None],
['3.14159', None],
['0','1', None],
['foo','bar',None,"foo,bar","diederik\'s motel "],
['16:38:47',None],
["2019-11-19 11:38:47-05",None],
['123e4567-e89b-12d3-a456-426655440000', None],
]
for elem, expected_array in zip(self.arrays, expected_arrays):
array = logical_replication.create_array_elem(elem)
self.assertEqual(array, expected_array)

def test_selected_value_to_singer_value_impl(self):
with self.env:
conn_info = get_test_connection_config()
for elem, sql_datatype in zip(self.arrays, self.sql_datatypes.keys()):
array = logical_replication.selected_value_to_singer_value(elem, sql_datatype, conn_info)

for element in array:
python_datatype = self.sql_datatypes[sql_datatype]
if element:
self.assertIsInstance(element, python_datatype)

if __name__== "__main__":
test1 = TestHandlingArrays()
test1.setUp()
test1.test_selected_value_to_singer_value_impl()