-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: markov.py
executable file
·144 lines (130 loc) · 4.85 KB
/
markov.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/python3
import random, re, sys
from collections import Counter
try:
import json, lmdb
except ImportError:
json = lmdb = None
# Global intern table: maps each token to its canonical instance so that
# identical tokens loaded from different inputs share a single object.
toks = {}
def dedup(tok):
    """Return the canonical instance of tok, interning it on first sight.

    Replaces the original O(n) linear scan of a list with an O(1) dict
    lookup (training was quadratic in the number of distinct tokens).
    Tokens must be hashable — they already are in practice, since
    MarkovGenerator.train uses tuples of them as dict keys.
    """
    # setdefault inserts tok as its own canonical form if unseen,
    # otherwise returns the previously stored equal object.
    return toks.setdefault(tok, tok)
class MarkovGenerator:
    def __init__(self, clen=3, db=None):
        """Initialize the generator for a given chain length. The longer the chain
        length, the more similar generated output will be to training input.
        Reasonable values vary; for word-based text generation on relatively
        short inputs, 3-5 is good, though 5 may end up generating mostly direct
        quotes. Letter-based generation will need larger values.

        This object can use an LMDB database, if passed it. If desired, 'db'
        should be a tuple of (env, db), where these are the result of
        lmdb.open() and env.open_db() respectively. (MarkovGenerator does not
        want or need to care about filenames, named databases, &c.; it uses the
        handles passed. Actually doing LMDB handling is left to the caller.) The
        passed-in DB will be arbitrarily scribbled-on by the writing of keys; it
        really does need a dedicated one per generator.
        """
        # Window length: states are (clen-1)-tuples, plus one predicted token.
        self.clen = clen
        if db is None:
            # In-memory mode: dict of state tuple -> list of successor tokens
            # (duplicate list entries encode observation frequency).
            self.lmdb = False
            self.db = {}
        else:
            # LMDB mode: keys are JSON-encoded state tuples (bytes), values
            # are JSON objects mapping successor token -> count.
            self.lmdb = True
            self.env = db[0]
            self.db = db[1]

    def ntuples(self, ilist):
        """Yield every clen-length sliding window of ilist as a tuple of
        interned (dedup'd) tokens.

        Raises Exception if ilist is shorter than the chain length.
        """
        if len(ilist) < self.clen:
            raise Exception("Not enough data")
        for start in range(len(ilist) - self.clen + 1):
            yield tuple(dedup(tok) for tok in ilist[start:start + self.clen])

    def train_token(self, k, v, txn=None):
        """Record one observation: state k (a (clen-1)-tuple) was followed by
        token v. In LMDB mode an open write transaction must be passed as txn.
        """
        if not self.lmdb:
            # Frequency is encoded by repetition in the successor list.
            self.db.setdefault(k, []).append(v)
        else:
            ks = json.dumps(k, separators=(',', ':')).encode()
            # Read-modify-write the JSON counter object for this state.
            counts = json.loads(txn.get(ks, default=b"{}").decode())
            counts[v] = counts.get(v, 0) + 1
            txn.put(ks, json.dumps(counts, separators=(',', ':')).encode())

    def train(self, ilist, txn=None):
        """Feed the generator input to train on. ilist is an arbitrary list of data,
        probably strings or integers. Common use cases are lists of the words or
        letters of a text file, in order to generate text with a resemblance to
        that file.

        In LMDB mode a write transaction is opened (and committed) here unless
        the caller supplies one via txn; on error a self-opened transaction is
        aborted rather than leaked (the original left it dangling).
        """
        if self.lmdb:
            owns_txn = txn is None
            if owns_txn:
                txn = self.env.begin(write=True, db=self.db)
            try:
                for window in self.ntuples(ilist):
                    self.train_token(window[:-1], window[-1], txn)
            except BaseException:
                if owns_txn:
                    txn.abort()
                raise
            if owns_txn:
                txn.commit()
        else:
            for window in self.ntuples(ilist):
                self.train_token(window[:-1], window[-1])

    def get_statelist(self):
        """Return every known state as a tuple.

        BUG FIX: the LMDB branch used to return the raw JSON key *bytes*,
        which generate() then could not use as a state (json.dumps of bytes
        fails, and bytes cannot be extended with `state[1:] + (w,)`). Keys
        are now decoded back into tuples, matching the in-memory branch.
        """
        if not self.lmdb:
            return list(self.db.keys())
        with self.env.begin(db=self.db) as txn:
            cursor = txn.cursor()
            return [tuple(json.loads(key.decode()))
                    for key in cursor.iternext(values=False)]

    def get_state_toks(self, state, txn=None):
        """Return the successor tokens observed after 'state', with duplicates
        expressing frequency. Raises KeyError if the state is unknown.
        In LMDB mode an open read transaction must be passed as txn.
        """
        if not self.lmdb:
            return self.db[state]
        # BUG FIX: the key must be encoded to bytes, exactly as train_token
        # writes it; the original passed a str to txn.get.
        ks = json.dumps(state, separators=(',', ':')).encode()
        counts = json.loads(txn.get(ks, default=b"{}").decode())
        if not counts:
            raise KeyError("Not in database")
        successors = []
        for tok, cnt in counts.items():
            successors.extend([tok] * cnt)
        return successors

    def generate(self, n, istate=None):
        """Generate output. The result is a generator that will yield at most n
        tokens. (Fewer may be generated, if the generator achieves a terminal
        state. Each individual training creates a new terminal state, making
        this more likely.) If istate is provided, then that will be the initial
        state; otherwise, it will be chosen randomly from the database keys.
        """
        if n < self.clen:
            raise Exception("Not long enough")
        state = random.choice(self.get_statelist()) if istate is None else istate
        for tok in state:
            yield tok
        n -= len(state)
        if self.lmdb:
            # Hold one read transaction for the whole walk; the 'with' block
            # guarantees it is released even if the caller abandons the
            # generator early (the original leaked the transaction).
            with self.env.begin(db=self.db) as txn:
                yield from self._walk(state, n, txn)
        else:
            yield from self._walk(state, n, None)

    def _walk(self, state, n, txn):
        """Random-walk up to n further tokens from 'state', stopping early
        if a terminal (untrained) state is reached."""
        for _ in range(n):
            try:
                nxt = random.choice(self.get_state_toks(state, txn))
            except KeyError:
                # Terminal state: nothing was ever observed after it.
                return
            state = state[1:] + (nxt,)
            yield nxt