import regex as re

# raw text: a space-separated list of Persian surnames, used as a toy training corpus
mystring = "هوشمند بیرانوند فروغی مظاهری زینلی معینی دولت رضاییان قادمه حیدرزاده کوچک سیاح حکیمی غلامزاده"
# g = [ord(x) for x in mystring]


# encode the raw text into a list of UTF-8 byte values (ints in range 0..255)
tokens = list(mystring.encode("utf-8"))
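
# quick sanity check (illustrative): these Persian letters live outside ASCII,
# so each one costs 2 bytes in UTF-8 and len(tokens) ends up > len(mystring)
# print(len(mystring), "chars ->", len(tokens), "bytes")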

# count how often each adjacent pair of ids appears in the sequence
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):  # zip(ids, ids[1:]) forms a sliding window over consecutive pairs
        counts[pair] = counts.get(pair, 0) + 1
    return counts
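
# tiny worked example (illustrative):
# get_stats([1, 2, 3, 1, 2]) -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}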

stats = get_stats(tokens)
# print(stats)
# print(sorted(((v, k) for k, v in stats.items()), reverse=True))
top_pair = max(stats, key=stats.get)
# print(top_pair)

# replace every occurrence of `pair` in ids with the new token id `idx`
def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
# print(merge([2, 4, 578, 5, 2, 4, 12, 2, 4, 1, 3, 2, 4, 63, 256, 2453, 24, 2, 4], (2, 4), 69))
"""tokens2 = merge(tokens, top_pair, 256)  # new ids must start at 256 so they can't collide with raw byte values
print(tokens2)
print("length: ", len(tokens2))"""

vocabsize = 276
num_merges = vocabsize - 256  # 20 merges on top of the 256 raw bytes
ids = list(tokens)  # work on a copy so we still have the original byte sequence
merges = {}  # (int, int) -> int, i.e. (child1, child2) fused into a new token
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)  # greedily pick the currently most frequent pair
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx


print("token length: ", len(tokens))
print("ids length: ", len(ids))
print(f"compression ratio: {len(tokens)/len(ids):.2f}X")
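
# end-to-end micro example (illustrative): on the bytes of "aaabdaaabac" the
# first merge would be (97, 97) -> 256, since the pair "aa" is most frequent
# (get_stats counts it 4 times, overlaps included) — the classic BPE walkthrough.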


# decoding

# precompute the vocab: token id -> the raw bytes that token stands for
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    # concatenation of two bytes objects; relies on dict insertion order
    # (guaranteed since Python 3.7) so both children already exist in vocab
    vocab[idx] = vocab[p0] + vocab[p1]
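
# illustrative: if the first learned merge were (32, 216) -> 256, then
# vocab[256] == bytes([32]) + bytes([216]) == b' \xd8'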
def decoding(ids):
    # given ids (a list of ints), return a Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text
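
# why errors="replace" (illustrative): not every id sequence decodes to valid
# UTF-8 — e.g. decoding([216]) is a dangling lead byte and yields '\ufffd'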


# encoding segment

def encoding(text):
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # replay merges in training order: of the pairs present, take the one
        # that was learned earliest (lowest merge index)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing left is mergeable
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens
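
# edge case (illustrative): characters never seen in training still encode fine,
# just as raw bytes — e.g. encoding("abc") returns [97, 98, 99] untouched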

f = encoding('حسینی زاده')
print(f)
print(decoding(f))
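
# round-trip sanity check (illustrative): for any valid Unicode string s,
# decoding(encoding(s)) == s, since merges are exactly undone via vocab
assert decoding(encoding('حسینی زاده')) == 'حسینی زاده'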


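# GPT-2 style pre-tokenization (sketch): split text into word-ish chunks before
# BPE so merges never cross word/number/punctuation boundaries. The pattern
# below is the split regex from openai/gpt-2's encoder.py, quoted from memory —
# verify against the source; it needs the `regex` module (imported as re above)
# for the \p{...} character classes.
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")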
print(re.findall(gpt2pat, "heyo 123 123 I've come to you with big MASSIvE news "))