Skip to content

Feature/89 rload fn #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmdstanpy/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.1'
__version__ = '0.4.1'
2 changes: 1 addition & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def exe_file(self) -> str:

def compile(
self,
opt_lvl: int = 2,
opt_lvl: int = 3,
overwrite: bool = False,
include_paths: List[str] = None,
) -> None:
Expand Down
89 changes: 75 additions & 14 deletions cmdstanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,72 @@ def rdump(path: str, data: Dict) -> None:
fd.write('\n')


def rload(fname: str) -> dict:
"""Parse data and parameter variable values from an R dump format file.
This parser only supports the subset of R dump data as described
in the "Dump Data Format" section of the CmdStan manual, i.e.,
scalar, vector, matrix, and array data types.
"""
data_dict = {}
with open(fname, 'r') as fp:
lines = fp.readlines()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, larger than RAM files are not really that common.

# Variable data may span multiple lines, parse accordingly
idx = 0
while idx < len(lines) and '<-' not in lines[idx]:
idx += 1
if idx == len(lines):
return None
start_idx = idx
idx += 1
while True:
while idx < len(lines) and '<-' not in lines[idx]:
idx += 1
next_var = idx
var_data = ''.join(lines[start_idx:next_var]).replace('\n', '')
lhs, rhs = [item.strip() for item in var_data.split('<-')]
lhs = lhs.replace('"', '') # strip optional Jags double quotes
rhs = rhs.replace('L', '') # strip R long int qualifier
data_dict[lhs] = parse_rdump_value(rhs)
if idx == len(lines):
break
start_idx = next_var
idx += 1
return data_dict


def parse_rdump_value(rhs: str) -> Union[int, float, np.array]:
"""Process right hand side of Rdump variable assignment statement.
Value is either scalar, vector, or multi-dim structure.
Use regex to capture structure values, dimensions.
"""
pat = re.compile(
r'structure\(\s*c\((?P<vals>[^)]*)\)'
r'(,\s*\.Dim\s*=\s*c\s*\((?P<dims>[^)]*)\s*\))?\)'
)
val = None
try:
if rhs.startswith('structure'):
parse = pat.match(rhs)
if parse is None or parse.group('vals') is None:
raise ValueError(rhs)
vals = [float(v) for v in parse.group('vals').split(',')]
val = np.array(vals, order='F')
if parse.group('dims') is not None:
dims = [int(v) for v in parse.group('dims').split(',')]
val = np.array(vals).reshape(dims, order='F')
elif rhs.startswith('c(') and rhs.endswith(')'):
val = np.array([float(item) for item in rhs[2:-1].split(',')])
elif '.' in rhs or 'e' in rhs:
val = float(rhs)
else:
val = int(rhs)
except TypeError:
raise ValueError(
'bad value in Rdump file: {}'.format(rhs)
)
return val


def check_csv(path: str, is_optimizing: bool = False) -> Dict:
"""Capture essential config, shape from stan_csv file."""
meta = scan_stan_csv(path, is_optimizing=is_optimizing)
Expand Down Expand Up @@ -409,7 +475,7 @@ def read_metric(path: str) -> List[int]:
' entry "inv_metric"'.format(path)
)
else:
dims = read_rdump_metric(path)
dims = list(read_rdump_metric(path))
if dims is None:
raise ValueError(
'metric file {}, bad or missing'
Expand All @@ -420,20 +486,15 @@ def read_metric(path: str) -> List[int]:

def read_rdump_metric(path: str) -> List[int]:
"""
Find dimensions of variable named 'inv_metric' using regex search.
Find dimensions of variable named 'inv_metric' in Rdump data file.
"""
with open(path, 'r') as fp:
data = fp.read().replace('\n', '')
m1 = re.search(r'inv_metric\s*<-\s*structure\(\s*c\(', data)
if not m1:
return_value = None
else:
m2 = re.search(r'\.Dim\s*=\s*c\(([^)]+)\)', data, m1.end())
if not m2:
return_value = None
dims = m2.group(1).split(',')
return_value = [int(d) for d in dims]
return return_value
metric_dict = rload(path)
if not ('inv_metric' in metric_dict and
isinstance(metric_dict['inv_metric'], np.ndarray)):
raise ValueError(
'metric file {}, bad or missing entry "inv_metric"'.format(path)
)
return list(metric_dict['inv_metric'].shape)


def do_command(cmd: str, cwd: str = None, logger: logging.Logger = None) -> str:
Expand Down
2 changes: 1 addition & 1 deletion test/data/metric_bad_2.data.R
Original file line number Diff line number Diff line change
@@ -1 +1 @@
inv_metric <- (0.787405, 0.884987, 1.19869)
inv_metric <- 0.787405
29 changes: 29 additions & 0 deletions test/data/rdump_array.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
d <- 4
d_v1 <- structure(c(1, 2, 3, 4), .Dim = c(4))
d_v2 <- structure(c(15, 16, 17, 18), .Dim = c(4))
d_rv1 <- structure(c(10, 20, 30, 40), .Dim = c(4))
d_rv2 <- structure(c(101, 201, 301, 401), .Dim = c(4))
d_v_ar <-
structure(c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1100, 1200,
1300, 1400, 1500, 1600),
.Dim = c(4, 4))
d_rv_ar <-
structure(c(1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000,
11000, 12000, 13000, 14000, 15000, 16000),
.Dim = c(4, 4))
d_m <-
structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),
.Dim = c(4, 4))
d_m_ar <-
structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,
58, 59, 60, 61, 62, 63, 64),
.Dim = c(4, 4, 4))


d_m2 <- structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim=c(2,8))

d_m3<-structure(c(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16),.Dim=c(1,16))

d_m4 <- structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim = c(8, 2))
1 change: 1 addition & 0 deletions test/data/rdump_bad_1.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
d <- c('C1', 'C2', 'C3')
7 changes: 7 additions & 0 deletions test/data/rdump_bad_2.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
d <- 4
d_m_ar <-
structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,


1 change: 1 addition & 0 deletions test/data/rdump_bad_3.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
d_m_ar <-
12 changes: 12 additions & 0 deletions test/data/rdump_jags.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"N" <-
128
"M" <-
2
"y" <-
c(0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
1, 1, 0)
56 changes: 56 additions & 0 deletions test/data/rdump_test.data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
N <-
128
M <-
2
y <-
c(0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0,
1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0,
1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
1, 1, 0)
x <-
structure(c(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1.3297992629225, 1.2724293214294, 0.414641434456408,
-1.53995004190371, -0.928567034713538, -0.29472044679056, -0.00576717274753696,
2.40465338885795, 0.76359346114046, -0.799009248989368, -1.14765700923635,
-0.289461573688223, -0.299215117897316, -0.411510832795067, 0.252223448156132,
-0.891921127284569, 0.435683299355719, -1.23753842192996, -0.224267885278309,
0.377395645981701, 0.133336360814841, 0.804189509744908, -0.0571067743838088,
0.503607972233726, 1.08576936214569, -0.69095383969683, -1.28459935387219,
0.046726172188352, -0.235706556439501, -0.542888255010254, -0.433310317456782,
-0.649471646796233, 0.726750747385451, 1.1519117540872, 0.992160365445798,
-0.429513109491881, 1.23830410085338, -0.279346281854269, 1.75790308981071,
0.560746090888056, -0.452783972553158, -0.832043296117832, -1.16657054708471,
-1.0655905803883, -1.563782051071, 1.15653699715018, 0.83204712857239,
-0.227328691424755, 0.266137361672105, -0.376702718583628, 2.44136462889459,
-0.795339117255372, -0.0548774737115786, 0.250141322854153, 0.618243293566247,
-0.172623502645857, -2.22390027400994, -1.26361438497058, 0.358728895971352,
-0.0110454784656636, -0.940649162618608, -0.115825322156954,
-0.814968708869917, 0.242263480859686, -1.4250983947325, 0.36594112304922,
0.248412648872596, 0.0652881816716207, 0.0191563916602738, 0.257338377155533,
-0.649010077708898, -0.119168762418038, 0.66413569989411, 1.10096910219409,
0.14377148075807, -0.117753598165951, -0.912068366948338, -1.43758624082998,
-0.797089525071965, 1.25408310644997, 0.77214218580453, -0.21951562675344,
-0.424810283377287, -0.418980099421959, 0.996986860909106, -0.275778029088027,
1.2560188173061, 0.646674390495345, 1.29931230256343, -0.873262111744435,
0.00837095999603331, -0.880871723252545, 0.59625901661066, 0.119717641289537,
-0.282173877322451, 1.45598840106634, 0.229019590694692, 0.996543928544126,
0.781859184600258, -0.776776621764597, -0.615989907707918, 0.0465803028049967,
-1.13038577760069, 0.576718781896486, -1.28074943178832, 1.62544730346494,
-0.500696596002705, 1.67829720781629, -0.412519887482398, -0.97228683550556,
0.0253828675878054, 0.0274753367451927, -1.68018272239593, 1.05375086302862,
-1.11959910457218, 0.335617209968815, 0.494795767113158, 0.138052708711737,
-0.118792025778828, 0.197684262345795, -1.06869271125479, -0.803213217364741,
-1.11376513631953, 1.58009168370384, 1.49781876103841, 0.262645458662762,
-1.23290119957126, -0.00372353379218051), .Dim = c(128L, 2L))

a <- 1.008
b <- -1.008
c <- -1.008e32
84 changes: 80 additions & 4 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
read_metric,
TemporaryCopiedFile,
windows_short_path,
rdump, rload, parse_rdump_value
)

here = os.path.dirname(os.path.abspath(__file__))
datafiles_path = os.path.join(here, 'data')

rdump = '''N <- 10
y <- c(0, 1, 0, 0, 0, 0, 0, 0, 0, 1)
'''


class CmdStanPathTest(unittest.TestCase):
def test_default_path(self):
Expand Down Expand Up @@ -248,5 +245,84 @@ def test_windows_short_path_file_with_space(self):
assert '.csv' == os.path.splitext(short_path)[1]


class RloadTest(unittest.TestCase):
def test_rload_metric(self):
dfile = os.path.join(datafiles_path, 'metric_diag.data.R')
data_dict = rload(dfile)
self.assertEqual(data_dict['inv_metric'].shape,(3,))

dfile = os.path.join(datafiles_path, 'metric_dense.data.R')
data_dict = rload(dfile)
self.assertEqual(data_dict['inv_metric'].shape,(3,3))

def test_rload_data(self):
dfile = os.path.join(datafiles_path, 'rdump_test.data.R')
data_dict = rload(dfile)
self.assertEqual(data_dict['N'],128)
self.assertEqual(data_dict['M'],2)
self.assertEqual(data_dict['x'].shape,(128,2))

def test_rload_jags_data(self):
dfile = os.path.join(datafiles_path, 'rdump_jags.data.R')
data_dict = rload(dfile)
self.assertEqual(data_dict['N'],128)
self.assertEqual(data_dict['M'],2)
self.assertEqual(data_dict['y'].shape,(128,))

def test_rload_wrong_data(self):
dfile = os.path.join(datafiles_path, 'metric_diag.data.json')
data_dict = rload(dfile)
self.assertEqual(data_dict,None)

def test_rload_bad_data_1(self):
dfile = os.path.join(datafiles_path, 'rdump_bad_1.data.R')
with self.assertRaises(ValueError):
data_dict = rload(dfile)

def test_rload_bad_data_2(self):
dfile = os.path.join(datafiles_path, 'rdump_bad_2.data.R')
with self.assertRaises(ValueError):
data_dict = rload(dfile)

def test_rload_bad_data_3(self):
dfile = os.path.join(datafiles_path, 'rdump_bad_3.data.R')
with self.assertRaises(ValueError):
data_dict = rload(dfile)

def test_roundtrip_metric(self):
dfile = os.path.join(datafiles_path, 'metric_diag.data.R')
data_dict_1 = rload(dfile)
self.assertEqual(data_dict_1['inv_metric'].shape,(3,))

dfile_tmp = os.path.join(datafiles_path, 'tmp.data.R')
rdump(dfile_tmp, data_dict_1)
data_dict_2 = rload(dfile_tmp)

self.assertTrue('inv_metric' in data_dict_2)
for i,x in enumerate(data_dict_2['inv_metric']):
self.assertEqual(x, data_dict_2['inv_metric'][i])

os.remove(dfile_tmp)

def test_parse_rdump_value(self):
s1 = 'structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim=c(2,8))'
v_s1 = parse_rdump_value(s1)
self.assertEqual(v_s1.shape,(2,8))
self.assertEqual(v_s1[1,0], 2)
self.assertEqual(v_s1[0,7], 15)

s2 = 'structure(c(1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16),.Dim=c(1,16))'
v_s2 = parse_rdump_value(s2)
self.assertEqual(v_s2.shape,(1,16))

s3 = 'structure(c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16),.Dim = c(8, 2))'
v_s3 = parse_rdump_value(s3)
self.assertEqual(v_s3.shape,(8,2))
self.assertEqual(v_s3[1,0], 2)
self.assertEqual(v_s3[7,0], 8)
self.assertEqual(v_s3[0,1], 9)
self.assertEqual(v_s3[6,1], 15)


if __name__ == '__main__':
unittest.main()