Skip to content

Commit d5731a4

Browse files
committed
Store the depth used for parsing in the db, so it no longer needs to be specified when generating
1 parent 585355a commit d5731a4

File tree

9 files changed

+159
-94
lines changed

9 files changed

+159
-94
lines changed

db.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,55 @@
11

22
class Db:
3-
def __init__(self, depth, conn, sql):
4-
self.depth = depth
3+
DEPTH_PARAM_NAME = 'depth'
4+
5+
def __init__(self, conn, sql):
56
self.conn = conn
67
self.cursor = conn.cursor()
78
self.sql = sql
9+
self.depth = None
810

9-
self.cursor.execute(self.sql.create_table_sql())
10-
self.cursor.execute(self.sql.create_index_sql())
11+
def setup(self, depth):
12+
self.depth = depth
13+
self.cursor.execute(self.sql.create_word_table_sql(depth))
14+
self.cursor.execute(self.sql.create_index_sql(depth))
15+
self.cursor.execute(self.sql.create_param_table_sql())
16+
self.cursor.execute(self.sql.set_param_sql(), (self.DEPTH_PARAM_NAME, depth))
1117

1218
def _get_word_list_count(self, word_list):
13-
if len(word_list) != self.depth:
14-
raise ValueError('Expected %s words in list but found %s' % (self.depth, len(word_list)))
19+
if len(word_list) != self.get_depth():
20+
raise ValueError('Expected %s words in list but found %s' % (self.get_depth(), len(word_list)))
1521

16-
self.cursor.execute(self.sql.select_count_for_words_sql(), word_list)
22+
self.cursor.execute(self.sql.select_count_for_words_sql(self.get_depth()), word_list)
1723
r = self.cursor.fetchone()
1824
if r:
1925
return r[0]
2026
else:
2127
return 0
2228

29+
def get_depth(self):
30+
if self.depth == None:
31+
self.cursor.execute(self.sql.get_param_sql(), (self.DEPTH_PARAM_NAME,))
32+
r = self.cursor.fetchone()
33+
if r:
34+
self.depth = int(r[0])
35+
else:
36+
raise ValueError('No depth value found in database, db does not seem to have been created by this utility')
37+
38+
return self.depth
39+
2340
def add_word(self, word_list):
2441
count = self._get_word_list_count(word_list)
2542
if count:
26-
self.cursor.execute(self.sql.update_count_for_words_sql(), [count + 1] + word_list)
43+
self.cursor.execute(self.sql.update_count_for_words_sql(self.get_depth()), [count + 1] + word_list)
2744
else:
28-
self.cursor.execute(self.sql.insert_row_for_words_sql(), word_list + [1])
45+
self.cursor.execute(self.sql.insert_row_for_words_sql(self.get_depth()), word_list + [1])
2946

3047
def commit(self):
3148
self.conn.commit()
3249

3350
def get_word_count(self, word_list):
3451
counts = {}
35-
sql = self.sql.select_words_and_counts_sql()
52+
sql = self.sql.select_words_and_counts_sql(self.get_depth())
3653
for row in self.cursor.execute(sql, word_list):
3754
counts[row[0]] = row[1]
3855

gen.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def _get_next_word(self, word_list):
1717
return w
1818
assert False
1919

20-
def generate(self, depth, word_separator):
20+
def generate(self, word_separator):
21+
depth = self.db.get_depth()
2122
sentence = [Parser.SENTENCE_START_SYMBOL] * (depth - 1)
2223
end_symbol = [Parser.SENTENCE_END_SYMBOL] * (depth - 1)
2324

markov.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,38 @@
77
import sqlite3
88
import codecs
99

10-
SENTENCE_SEPARATOR = '\n'
11-
WORD_SEPARATOR = ''
10+
SENTENCE_SEPARATOR = '.'
11+
WORD_SEPARATOR = ' '
1212

1313
if __name__ == '__main__':
1414
args = sys.argv
15-
usage = 'Usage: %s (parse <name> <depth> <path to txt file>|gen <name> <depth> <count>)' % (args[0], )
15+
usage = 'Usage: %s (parse <name> <depth> <path to txt file>|gen <name> <count>)' % (args[0], )
1616

17-
if (len(args) != 5):
17+
if (len(args) < 3):
1818
raise ValueError(usage)
1919

2020
mode = args[1]
2121
name = args[2]
22-
depth = int(args[3])
23-
sql = Sql(depth)
24-
db = Db(depth, sqlite3.connect(name + '.db'), sql)
2522

2623
if mode == 'parse':
24+
if (len(args) != 5):
25+
raise ValueError(usage)
26+
27+
depth = int(args[3])
2728
file_name = args[4]
29+
30+
db = Db(sqlite3.connect(name + '.db'), Sql())
31+
db.setup(depth)
32+
2833
txt = codecs.open(file_name, 'r', 'utf-8').read()
29-
Parser(name, db, SENTENCE_SEPARATOR, WORD_SEPARATOR).parse(depth, txt)
34+
Parser(name, db, SENTENCE_SEPARATOR, WORD_SEPARATOR).parse(txt)
3035

3136
elif mode == 'gen':
32-
count = int(args[4])
37+
count = int(args[3])
38+
db = Db(sqlite3.connect(name + '.db'), Sql())
3339
generator = Generator(name, db, Rnd())
3440
for i in range(0, count):
35-
print generator.generate(depth, WORD_SEPARATOR)
41+
print generator.generate(WORD_SEPARATOR)
3642

3743
else:
3844
raise ValueError(usage)

parse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import re
23

34
class Parser:
45
SENTENCE_START_SYMBOL = '^'
@@ -9,12 +10,16 @@ def __init__(self, name, db, sentence_split_char = '\n', word_split_char = ''):
910
self.db = db
1011
self.sentence_split_char = sentence_split_char
1112
self.word_split_char = word_split_char
13+
self.whitespace_regex = re.compile('\s+')
1214

13-
def parse(self, depth, txt):
15+
def parse(self, txt):
16+
depth = self.db.get_depth()
1417
sentences = txt.split(self.sentence_split_char)
1518
i = 0
1619

1720
for sentence in sentences:
21+
sentence = self.whitespace_regex.sub(" ", sentence).strip()
22+
1823
list_of_words = None
1924
if self.word_split_char:
2025
list_of_words = sentence.split(self.word_split_char)

sql.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,54 @@
11
class Sql:
22
WORD_COL_NAME_PREFIX = 'word'
33
COUNT_COL_NAME = 'count'
4-
TABLE_NAME = 'word'
4+
WORD_TABLE_NAME = 'word'
55
INDEX_NAME = 'i_word'
6+
PARAM_TABLE_NAME = 'param'
7+
KEY_COL_NAME = 'name'
8+
VAL_COL_NAME = 'value'
69

7-
def __init__(self, column_count):
8-
self.column_count = int(column_count)
9-
if self.column_count < 2:
10+
def _check_column_count(self, count):
11+
if count < 2:
1012
raise ValueError('Invalid column_count value, must be >= 2')
1113

12-
def _make_column_name_list(self):
13-
return ', '.join([self.WORD_COL_NAME_PREFIX + str(n) for n in range(1, self.column_count + 1)])
14+
def _make_column_name_list(self, column_count):
15+
return ', '.join([self.WORD_COL_NAME_PREFIX + str(n) for n in range(1, column_count + 1)])
1416

15-
def _make_column_names_and_placeholders(self, col_count):
16-
return ' AND '.join(['%s%s=?' % (self.WORD_COL_NAME_PREFIX, n) for n in range(1, col_count + 1)])
17+
def _make_column_names_and_placeholders(self, column_count):
18+
return ' AND '.join(['%s%s=?' % (self.WORD_COL_NAME_PREFIX, n) for n in range(1, column_count + 1)])
1719

18-
def create_table_sql(self):
19-
return 'CREATE TABLE IF NOT EXISTS %s (%s, %s)' % (self.TABLE_NAME, self._make_column_name_list(), self.COUNT_COL_NAME)
20+
def create_word_table_sql(self, column_count):
21+
return 'CREATE TABLE IF NOT EXISTS %s (%s, %s)' % (self.WORD_TABLE_NAME, self._make_column_name_list(column_count), self.COUNT_COL_NAME)
2022

21-
def create_index_sql(self):
22-
return 'CREATE INDEX IF NOT EXISTS %s ON %s (%s)' % (self.INDEX_NAME, self.TABLE_NAME, self._make_column_name_list())
23+
def create_param_table_sql(self):
24+
return 'CREATE TABLE IF NOT EXISTS %s (%s, %s)' % (self.PARAM_TABLE_NAME, self.KEY_COL_NAME, self.VAL_COL_NAME)
2325

24-
def select_count_for_words_sql(self):
25-
return 'SELECT %s FROM %s WHERE %s' % (self.COUNT_COL_NAME, self.TABLE_NAME, self._make_column_names_and_placeholders(self.column_count))
26+
def set_param_sql(self):
27+
return 'INSERT INTO %s (%s, %s) VALUES (?, ?)' % (self.PARAM_TABLE_NAME, self.KEY_COL_NAME, self.VAL_COL_NAME)
2628

27-
def update_count_for_words_sql(self):
28-
return 'UPDATE %s SET %s=? WHERE %s' % (self.TABLE_NAME, self.COUNT_COL_NAME, self._make_column_names_and_placeholders(self.column_count))
29+
def get_param_sql(self):
30+
return 'SELECT %s FROM %s WHERE %s=?' % (self.VAL_COL_NAME, self.PARAM_TABLE_NAME, self.KEY_COL_NAME)
31+
32+
def create_index_sql(self, column_count):
33+
return 'CREATE INDEX IF NOT EXISTS %s ON %s (%s)' % (self.INDEX_NAME, self.WORD_TABLE_NAME, self._make_column_name_list(column_count))
34+
35+
def select_count_for_words_sql(self, column_count):
36+
return 'SELECT %s FROM %s WHERE %s' % (self.COUNT_COL_NAME, self.WORD_TABLE_NAME, self._make_column_names_and_placeholders(column_count))
37+
38+
def update_count_for_words_sql(self, column_count):
39+
return 'UPDATE %s SET %s=? WHERE %s' % (self.WORD_TABLE_NAME, self.COUNT_COL_NAME, self._make_column_names_and_placeholders(column_count))
2940

30-
def insert_row_for_words_sql(self):
31-
columns = self._make_column_name_list() + ', ' + self.COUNT_COL_NAME
32-
values = ', '.join(['?'] * (self.column_count + 1))
41+
def insert_row_for_words_sql(self, column_count):
42+
columns = self._make_column_name_list(column_count) + ', ' + self.COUNT_COL_NAME
43+
values = ', '.join(['?'] * (column_count + 1))
3344

34-
return 'INSERT INTO %s (%s) VALUES (%s)' % (self.TABLE_NAME, columns, values)
45+
return 'INSERT INTO %s (%s) VALUES (%s)' % (self.WORD_TABLE_NAME, columns, values)
3546

36-
def select_words_and_counts_sql(self):
37-
last_word_col_name = self.WORD_COL_NAME_PREFIX + str(self.column_count)
47+
def select_words_and_counts_sql(self, column_count):
48+
last_word_col_name = self.WORD_COL_NAME_PREFIX + str(column_count)
3849

39-
return 'SELECT %s, %s FROM %s WHERE %s' % (last_word_col_name, self.COUNT_COL_NAME, self.TABLE_NAME, self._make_column_names_and_placeholders(self.column_count - 1))
50+
return 'SELECT %s, %s FROM %s WHERE %s' % (last_word_col_name, self.COUNT_COL_NAME, self.WORD_TABLE_NAME, self._make_column_names_and_placeholders(column_count - 1))
4051

4152
def delete_words_sql(self):
42-
return 'DELETE FROM ' + self.TABLE_NAME
53+
return 'DELETE FROM ' + self.WORD_TABLE_NAME
4354

test/db_test.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,64 @@ def setUp(self):
66
self.conn = StubConn()
77
self.sql = StubSql()
88

9-
def test_correct_sql_run_when_object_created(self):
10-
Db(3, self.conn, self.sql)
9+
def test_correct_sql_run_when_setup_called(self):
10+
Db(self.conn, self.sql).setup(3)
1111
execute_args = self.conn.stub_cursor.execute_args
12-
self.assertEqual(len(execute_args), 2)
13-
self.assertEqual(execute_args[0], ('create_table_sql',))
14-
self.assertEqual(execute_args[1], ('create_index_sql',))
12+
self.assertEqual(len(execute_args), 4)
13+
self.assertEqual(execute_args[0], ('create_word_table_sql 3',))
14+
self.assertEqual(execute_args[1], ('create_index_sql 3',))
15+
self.assertEqual(execute_args[2], ('create_param_table_sql',))
16+
self.assertEqual(execute_args[3], ('set_param_sql', ('depth', 3)))
1517

1618
def test_error_when_add_word_count_wrong(self):
17-
db = Db(3, self.conn, self.sql)
19+
db = Db(self.conn, self.sql)
20+
db.setup(3)
1821
self.assertRaises(ValueError, db.add_word, ['one','two'])
1922

2023
def test_insert_row_when_add_new_word_list(self):
21-
db = Db(3, self.conn, self.sql)
24+
db = Db(self.conn, self.sql)
25+
db.setup(3)
2226
word_list = ['one', 'two', 'three']
2327
db.add_word(word_list)
2428

2529
execute_args = self.conn.stub_cursor.execute_args
26-
self.assertEqual(len(execute_args), 4)
27-
self.assertEqual(execute_args[2], ('select_count_for_words_sql', word_list))
28-
self.assertEqual(execute_args[3], ('insert_row_for_words_sql', word_list + [1]))
30+
self.assertEqual(len(execute_args), 6)
31+
self.assertEqual(execute_args[4], ('select_count_for_words_sql 3', word_list))
32+
self.assertEqual(execute_args[5], ('insert_row_for_words_sql 3', word_list + [1]))
2933

3034
def test_update_row_when_add_repeated_word_list(self):
31-
db = Db(3, self.conn, self.sql)
35+
db = Db(self.conn, self.sql)
36+
db.setup(3)
3237
row_count = 10
3338
word_list = ['one', 'two', 'three']
3439
self.conn.stub_cursor.fetchone_results.append([row_count])
3540

3641
db.add_word(word_list)
3742

3843
execute_args = self.conn.stub_cursor.execute_args
39-
self.assertEqual(len(execute_args), 4)
40-
self.assertEqual(execute_args[2], ('select_count_for_words_sql', word_list))
41-
self.assertEqual(execute_args[3], ('update_count_for_words_sql', [row_count + 1] + word_list))
44+
self.assertEqual(len(execute_args), 6)
45+
self.assertEqual(execute_args[4], ('select_count_for_words_sql 3', word_list))
46+
self.assertEqual(execute_args[5], ('update_count_for_words_sql 3', [row_count + 1] + word_list))
4247

4348
def test_db_commit_performed_correctly(self):
44-
db = Db(3, self.conn, self.sql)
49+
db = Db(self.conn, self.sql)
50+
db.setup(3)
4551
self.assertEqual(self.conn.commit_count, 0)
4652
db.commit()
4753
self.assertEqual(self.conn.commit_count, 1)
4854

4955
def test_get_word_counts_works_correctly(self):
50-
db = Db(3, self.conn, self.sql)
56+
db = Db(self.conn, self.sql)
57+
db.setup(3)
5158
word_list = ['i', 'like']
5259
self.conn.stub_cursor.execute_results = [[['dogs', 1], ['cats', 2], ['frogs', 3]]]
5360

5461
word_counts = db.get_word_count(word_list)
5562

5663
self.assertEqual(word_counts, {'dogs' : 1, 'cats' : 2, 'frogs' : 3})
5764
execute_args = self.conn.stub_cursor.execute_args
58-
self.assertEqual(len(execute_args), 3)
59-
self.assertEqual(execute_args[2], ('select_words_and_counts_sql', word_list))
65+
self.assertEqual(len(execute_args), 5)
66+
self.assertEqual(execute_args[4], ('select_words_and_counts_sql 3', word_list))
6067

6168
class StubCursor:
6269
def __init__(self):
@@ -96,23 +103,32 @@ def cursor(self):
96103
return self.stub_cursor
97104

98105
class StubSql:
99-
def create_table_sql(self):
100-
return 'create_table_sql'
106+
def create_word_table_sql(self, column_count):
107+
return 'create_word_table_sql' + ' ' + str(column_count)
108+
109+
def create_index_sql(self, column_count):
110+
return 'create_index_sql' + ' ' + str(column_count)
111+
112+
def create_param_table_sql(self):
113+
return 'create_param_table_sql'
114+
115+
def set_param_sql(self):
116+
return 'set_param_sql'
101117

102-
def create_index_sql(self):
103-
return 'create_index_sql'
118+
def get_param_sql(self):
119+
return 'get_param_sql'
104120

105-
def select_count_for_words_sql(self):
106-
return 'select_count_for_words_sql'
121+
def select_count_for_words_sql(self, column_count):
122+
return 'select_count_for_words_sql' + ' ' + str(column_count)
107123

108-
def update_count_for_words_sql(self):
109-
return 'update_count_for_words_sql'
124+
def update_count_for_words_sql(self, column_count):
125+
return 'update_count_for_words_sql' + ' ' + str(column_count)
110126

111-
def insert_row_for_words_sql(self):
112-
return 'insert_row_for_words_sql'
127+
def insert_row_for_words_sql(self, column_count):
128+
return 'insert_row_for_words_sql' + ' ' + str(column_count)
113129

114-
def select_words_and_counts_sql(self):
115-
return 'select_words_and_counts_sql'
130+
def select_words_and_counts_sql(self, column_count):
131+
return 'select_words_and_counts_sql' + ' ' + str(column_count)
116132

117133
def delete_words_sql(self):
118134
return 'delete_words_sql'

test/gen_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@ def test_generated_sequence_is_correct(self):
1919

2020
self.rnd.vals = [1, 2, 2, 1, 4, 1, 1]
2121

22-
self.assertEqual(Generator('name', self.db, self.rnd).generate(2, ' '), 'the cat sat on the mat')
22+
self.assertEqual(Generator('name', self.db, self.rnd).generate(' '), 'the cat sat on the mat')
2323
self.assertEqual(self.db.get_word_count_args, [['^'], ['the'], ['cat'], ['sat'], ['on'], ['the'], ['mat']])
2424
self.assertEqual(self.rnd.maxints, [3, 2, 2, 5, 4, 2, 1])
2525

2626
class StubDb:
2727
def __init__(self):
2828
self.count_values = []
2929
self.get_word_count_args = []
30+
self.depth = 2
3031

32+
def get_depth(self):
33+
return self.depth
34+
3135
def get_word_count(self, word_list):
3236
self.get_word_count_args.append(word_list)
3337
return self.count_values.pop(0)

0 commit comments

Comments
 (0)