-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathBiasGroup.py
156 lines (127 loc) · 4.72 KB
/
BiasGroup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from typing import Any, Dict, Iterable, List, Union
from novelai_api.Preset import Model
from novelai_api.utils import tokenize_if_not
class BiasGroup:
    """
    A group of sequences (raw strings or tokenized lists of ints) sharing the same bias settings
    """

    # Stored sequences: either raw strings or already-tokenized lists of ints
    _sequences: List[Union[List[int], str]]
    bias: float
    ensure_sequence_finish: bool
    generate_once: bool
    enabled: bool

    def __init__(
        self,
        bias: float,
        ensure_sequence_finish: bool = False,
        generate_once: bool = False,
        enabled: bool = True,
    ):
        """
        Create a bias group

        :param bias: Bias value of the bias group. Negative is a downbias, positive is an upbias
        :param ensure_sequence_finish: Ensures the bias completes
        :param generate_once: Only biases for the first occurrence
        :param enabled: Is the bias group enabled
        """

        self._sequences = []
        self.bias = bias
        self.ensure_sequence_finish = ensure_sequence_finish
        self.generate_once = generate_once
        self.enabled = enabled

    @classmethod
    def from_data(cls, data: Dict[str, Any]) -> "BiasGroup":
        """
        Create a bias group from bias group data

        Both camelCase ("ensureSequenceFinish", "generateOnce") and snake_case
        ("ensure_sequence_finish", "generate_once") keys are accepted; camelCase takes precedence.
        Optional keys that are missing fall back to the same defaults as the constructor.

        :param data: Bias group data. Must contain "bias"; may contain "enabled" and "phrases"
        :return: The new bias group, with every entry of "phrases" added to it
        :raises KeyError: If "bias" is missing
        :raises ValueError: If an entry of "phrases" is not a valid sequence (see `add`)
        """

        # FIXME: the meaning of the "whenInactive" key in bias data is unknown; it is currently ignored
        ensure_sequence_finish = data.get("ensureSequenceFinish", data.get("ensure_sequence_finish", False))
        generate_once = data.get("generateOnce", data.get("generate_once", False))
        # "enabled" used to be mandatory (KeyError when absent); default it like __init__ does
        enabled = data.get("enabled", True)

        b = cls(data["bias"], ensure_sequence_finish, generate_once, enabled)
        if "phrases" in data:
            b.add(*data["phrases"])

        return b

    def add(
        self,
        *sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str],
    ) -> "BiasGroup":
        """
        Add elements to the bias group. Elements can be string or tokenized strings
        Using tokenized strings is not recommended, for flexibility between tokenizers

        Dict elements are unwrapped first: {"sequence": x} adds x, and {"sequences": [x, ...]}
        adds only its first item x.
        Either every element is added or, on error, none is.

        :param sequences: Elements to add
        :return: self, to allow chaining
        :raises ValueError: If an element is neither a string nor a list of ints,
            or if a "sequences" dict holds an empty list
        """

        validated: List[Union[List[int], str]] = []

        for i, sequence in enumerate(sequences):
            if isinstance(sequence, dict):
                if "sequence" in sequence:
                    sequence = sequence["sequence"]
                elif "sequences" in sequence:
                    if not sequence["sequences"]:
                        raise ValueError(f"Expected a non-empty 'sequences' for sequence #{i} of 'sequences'")
                    sequence = sequence["sequences"][0]

            if not isinstance(sequence, str):
                if not isinstance(sequence, list):
                    raise ValueError(
                        f"Expected type 'List[int]' for sequence #{i} of 'sequences', " f"but got '{type(sequence)}'"
                    )

                for j, s in enumerate(sequence):
                    if not isinstance(s, int):
                        raise ValueError(
                            f"Expected type 'int' for item #{j} of sequence #{i} of 'sequences', "
                            f"but got '{type(s)}': {sequence}"
                        )

            validated.append(sequence)

        # mutate only once every element is valid, so a failed call leaves the group unchanged
        self._sequences.extend(validated)

        return self

    def __iadd__(
        self, sequences: Union[Dict[str, List[List[int]]], Dict[str, List[int]], List[int], str]
    ) -> "BiasGroup":
        """
        Add elements to the bias group. Elements can be string or tokenized strings
        Using tokenized strings is not recommended, for flexibility between tokenizers

        Note: the right-hand side is added as a single element (see `add`)
        """

        self.add(sequences)

        return self

    def __iter__(self):
        """
        Return an iterator on the stored sequences, each wrapped in a dict with the group settings
        """

        return (
            {
                "bias": self.bias,
                "ensure_sequence_finish": self.ensure_sequence_finish,
                "generate_once": self.generate_once,
                "enabled": self.enabled,
                "sequence": s,
            }
            for s in self._sequences
        )

    def get_tokenized_entries(self, model: Model) -> Iterable[Dict[str, Any]]:
        """
        Return the tokenized sequences for the bias group, if it is enabled

        The result is a lazy generator: tokenization (and the `enabled` check) happen on iteration.

        :param model: Model to use for tokenization
        """

        return (
            {
                "bias": self.bias,
                "ensure_sequence_finish": self.ensure_sequence_finish,
                "generate_once": self.generate_once,
                "sequence": tokenize_if_not(model, s),
            }
            for s in self._sequences
            if self.enabled
        )

    def __str__(self) -> str:
        return (
            "{ "
            f"bias: {self.bias}, "
            f"ensure_sequence_finish: {self.ensure_sequence_finish}, "
            f"generate_once: {self.generate_once}, "
            f"enabled: {self.enabled}, "
            f"sequences: {self._sequences}"
            " }"
        )