Skip to content

Commit 3ddb7fb

Browse files
committed
added support for the count field to be omitted if 1
1 parent f8d5235 commit 3ddb7fb

File tree

6 files changed

+75
-4
lines changed

6 files changed

+75
-4
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,11 @@ You can iterate over the document-term counts at the leaves of the HDTM with the
9898

9999
You can **graft** one HTDM under another by using the `graft(prefix, subtree)` method, specifying as `prefix` the document address you want to add the subtree under. This is useful if you have an HTDM for, say, a single work by an author, with chapters as documents and you want to incorporate that into a higher-level HTDM of multiple works by the author, or a collection of works by different authors.
100100

101+
The third (count) field in a loaded file can be omitted if the count is 1 and a document ID + term may be repeated with the counts accumulating.
102+
101103
### Duplicates Policy
102104

103-
You can optionall pass in a `duplicates` setting to the constructorr indicating the policy you want to follow if a term-document count is updated more than once.
105+
You can optionally pass in a `duplicates` setting to the constructorr indicating the policy you want to follow if a term-document count is updated more than once.
104106

105107
```python
106108
>>> c = termdoc.HTDM()

termdoc/htdm.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,16 @@ def increment_count(self, address, term, count):
6464
def load(self, filename, field_sep="\t", address_sep="."):
6565
with open(filename) as f:
6666
for line in f:
67-
address_string, term, count_string = line.strip().split(field_sep)
67+
fields = line.strip().split(field_sep)
68+
if len(fields) == 3:
69+
address_string, term, count_string = fields
70+
count = int(count_string)
71+
elif len(fields) == 2:
72+
address_string, term = fields
73+
count = 1
74+
else:
75+
raise ValueError(f"{fields} should have 2 or 3 fields")
6876
address = tuple(address_string.split(address_sep))
69-
count = int(count_string)
7077
self.increment_count(address, term, count)
7178

7279
def get_counts(self, prefix=()):

test_data/test2.tsv

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
1.1 foo
2+
1.1 foo
3+
1.1 foo
4+
1.1 foo
5+
1.1 foo
6+
1.1 foo
7+
1.1 foo
8+
1.1 bar
9+
1.1 bar
10+
1.1 bar
11+
1.1 bar
12+
1.2 foo
13+
1.2 foo
14+
1.3 bar
15+
2.1 baz
16+
2.1 baz
17+
2.1 baz
18+
2.1 baz
19+
2.1 baz
20+
2.1 foo

test_data/test3e.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1

test_data/test4e.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
1 2 3 4

tests.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_multi_level(self):
3131
self.assertEqual(c.get_counts()["bar"], 6)
3232
self.assertEqual(c.get_counts((2,))["foo"], 4)
3333

34-
def test_load(self):
34+
def test_load1(self):
3535
import termdoc
3636

3737
c = termdoc.HTDM()
@@ -59,6 +59,46 @@ def test_load(self):
5959
self.assertEqual(c.get_counts(("2", "1"))["bar"], 0)
6060
self.assertEqual(c.get_counts(("2", "1"))["baz"], 5)
6161

62+
def test_load2(self):
63+
import termdoc
64+
65+
c = termdoc.HTDM()
66+
c.load("test_data/test2.tsv")
67+
68+
self.assertEqual(c.get_counts()["foo"], 10)
69+
self.assertEqual(c.get_counts()["bar"], 5)
70+
self.assertEqual(c.get_counts()["baz"], 5)
71+
self.assertEqual(c.get_counts(("1",))["foo"], 9)
72+
self.assertEqual(c.get_counts(("1",))["bar"], 5)
73+
self.assertEqual(c.get_counts(("1",))["baz"], 0)
74+
self.assertEqual(c.get_counts(("2",))["foo"], 1)
75+
self.assertEqual(c.get_counts(("2",))["bar"], 0)
76+
self.assertEqual(c.get_counts(("2",))["baz"], 5)
77+
self.assertEqual(c.get_counts(("1", "1"))["foo"], 7)
78+
self.assertEqual(c.get_counts(("1", "1"))["bar"], 4)
79+
self.assertEqual(c.get_counts(("1", "1"))["baz"], 0)
80+
self.assertEqual(c.get_counts(("1", "2"))["foo"], 2)
81+
self.assertEqual(c.get_counts(("1", "2"))["bar"], 0)
82+
self.assertEqual(c.get_counts(("1", "2"))["baz"], 0)
83+
self.assertEqual(c.get_counts(("1", "3"))["foo"], 0)
84+
self.assertEqual(c.get_counts(("1", "3"))["bar"], 1)
85+
self.assertEqual(c.get_counts(("1", "3"))["baz"], 0)
86+
self.assertEqual(c.get_counts(("2", "1"))["foo"], 1)
87+
self.assertEqual(c.get_counts(("2", "1"))["bar"], 0)
88+
self.assertEqual(c.get_counts(("2", "1"))["baz"], 5)
89+
90+
def test_load3(self):
91+
import termdoc
92+
93+
c = termdoc.HTDM()
94+
self.assertRaises(ValueError, c.load, "test_data/test3e.tsv")
95+
96+
def test_load4(self):
97+
import termdoc
98+
99+
c = termdoc.HTDM()
100+
self.assertRaises(ValueError, c.load, "test_data/test4e.tsv")
101+
62102
def test_prune(self):
63103
import termdoc
64104

0 commit comments

Comments
 (0)