Skip to content
This repository has been archived by the owner on Sep 15, 2022. It is now read-only.

Commit

Permalink
[app] adding engine routines
Browse files Browse the repository at this point in the history
  • Loading branch information
JosephGeoBenjamin committed Oct 29, 2020
1 parent 892889c commit 3419db2
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 79 deletions.
77 changes: 16 additions & 61 deletions apps/api_expose.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import csv

class XlitError(enum.Enum):
    """User-facing error messages returned by the transliteration API."""
    # FIX: the pasted diff defined lang_err/string_err twice, which
    # enum.Enum rejects with TypeError; also corrected message spelling
    # ("langauge" -> "language", "incompatable" -> "incompatible").
    lang_err = "Unsupported language ID requested ;( Please check available languages."
    string_err = "String passed is incompatible ;("
    internal_err = "Internal crash ;("
    unknown_err = "Unknown Failure"
    loading_err = "Loading failed ;( Check if metadata/paths are correctly configured."


app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
Expand Down Expand Up @@ -88,7 +90,7 @@ def xlit_api(lang_code, eng_word):
return jsonify(response)

try:
xlit_result = engine.transliterate(lang_code, eng_word)
xlit_result = engine.translit_word(eng_word, lang_code)
except Exception as e:
xlit_result = XlitError.internal_err

Expand Down Expand Up @@ -126,64 +128,6 @@ def learn_from_user():
write_userdata(data)
return jsonify({'status': 'Success'})

@app.route('/learn_context', methods=['POST'])
def learn_from_context():
    """Persist user-submitted context feedback, stamped with request metadata."""
    # Parse the POSTed JSON body even if the Content-Type header is missing.
    payload = request.get_json(force=True)
    # Attach provenance before writing: caller address and a UTC timestamp.
    payload['user_ip'] = request.remote_addr
    payload['timestamp'] = str(datetime.utcnow()) + ' +0000 UTC'
    write_userdata(payload)
    return jsonify({'status': 'Success'})


## ----------------------------- Xlit Engine -------------------------------- ##

BASEPATH = os.path.dirname(os.path.realpath(__file__))
sys.path.append(BASEPATH)

class XlitEngine():
    """Loads per-language inference models and routes transliteration requests.

    Each language model is loaded independently; a language whose model fails
    to load is dropped from ``self.langs`` so later requests for it return
    ``XlitError.lang_err`` instead of crashing.
    Global Variables: BASEPATH (models are imported relative to it via sys.path)
    """

    def __init__(self):
        # Supported language codes -> display names; entries are removed
        # below when the corresponding model cannot be loaded.
        self.langs = {"hi": "Hindi", "gom": "Konkani (Goan)", "mai": "Maithili"}

        try:
            from models.hindi.hi_program110 import inference_engine as hindi_engine
            self.hindi_engine = hindi_engine
        except Exception as error:
            print("Failure in loading Hindi \n", error)
            del self.langs['hi']

        try:
            from models.konkani.gom_program116 import inference_engine as konkani_engine
            self.konkani_engine = konkani_engine
        except Exception as error:
            print("Failure in loading Konkani \n", error)
            del self.langs['gom']

        try:
            from models.maithili.mai_program120 import inference_engine as maithili_engine
            self.maithili_engine = maithili_engine
        except Exception as error:
            print("Failure in loading Maithili \n", error)
            del self.langs['mai']

    def transliterate(self, lang_code, eng_word):
        """Transliterate ``eng_word`` into the script of ``lang_code``.

        Returns the model's candidate output, ``[]`` for an empty word, or
        an ``XlitError`` member when the language is unknown or the model
        fails.
        """
        if eng_word == "":
            return []

        if lang_code not in self.langs:
            print("Unknown Language requested", lang_code)
            return XlitError.lang_err

        try:
            if lang_code == "hi":
                return self.hindi_engine(eng_word)
            elif lang_code == "gom":
                return self.konkani_engine(eng_word)
            elif lang_code == "mai":
                return self.maithili_engine(eng_word)

        # BUG FIX: the original read `except error as Exception:`, which is
        # backwards — it would raise NameError instead of catching anything.
        except Exception as error:
            print("Error:", error)
            return XlitError.unknown_err


## -------------------------- Server Setup ---------------------------------- ##
Expand All @@ -195,6 +139,17 @@ def host_https():
https_server.serve_forever()
return



## ----------------------------- Xlit Engine -------------------------------- ##

BASEPATH = os.path.dirname(os.path.realpath(__file__))
sys.path.append(BASEPATH)
from xlit_src import XlitEngine


## -------------------------------------------------------------------------- ##

if __name__ == '__main__':
engine = XlitEngine()
if not DEBUG: # Production Server
Expand Down
160 changes: 145 additions & 15 deletions apps/xlit_src.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,17 @@
import sys
import os
import json
from annoy import AnnoyIndex
import enum

F_DIR = os.path.dirname(os.path.realpath(__file__))

class XlitError(enum.Enum):
    """User-facing error messages returned by the transliteration engine."""
    # FIX: corrected spelling in the user-facing messages
    # ("langauge" -> "language", "incompatable" -> "incompatible").
    lang_err = "Unsupported language ID requested ;( Please check available languages."
    string_err = "String passed is incompatible ;("
    internal_err = "Internal crash ;("
    unknown_err = "Unknown Failure"
    loading_err = "Loading failed ;( Check if metadata/paths are correctly configured."

##=================== Network ==================================================

class Encoder(nn.Module):
Expand Down Expand Up @@ -437,19 +444,22 @@ def reposition(self, word_list):

##=============== INSTANTIATION ================================================

class XlitCrankshaft():
class XlitPiston():
"""
For handling prediction & post-processing of transliteration for a single language
Class dependency: Seq2Seq, GlyphStrawboss, VocabSanitizer
Global Variables: F_DIR
"""
def __init__(self):
def __init__(self, weight_path, vocab_file, tglyph_cfg_file,
iglyph_cfg_file = "en", device = "cpu" ):

self.device = "cpu"
self.in_glyph_obj = GlyphStrawboss("en")
self.tgt_glyph_obj = GlyphStrawboss(glyphs = F_DIR+"/models/tamil/ta_scripts.json")
self.voc_sanity = VocabSanitizer(F_DIR+"/models/tamil/ta_words_a4b.json")
self.device = device
self.in_glyph_obj = GlyphStrawboss(iglyph_cfg_file)
self.tgt_glyph_obj = GlyphStrawboss(glyphs = tglyph_cfg_file)
self.voc_sanity = VocabSanitizer(vocab_file)

self._numsym_set = set("1234567890.")
self._numsym_set = set(json.load(open(tglyph_cfg_file))["numsym_map"].keys() )
self._inchar_set = set("abcdefghijklmnopqrstuvwxyz")
self._natscr_set = set().union(self.tgt_glyph_obj.glyphs,
sum(self.tgt_glyph_obj.numsym_map.values(),[]) )
Expand Down Expand Up @@ -485,15 +495,15 @@ def __init__(self):
device = self.device,)
self.model = Seq2Seq(enc, dec, pass_enc2dec_hid=enc2dec_hid, device=self.device)
self.model = self.model.to(self.device)
weights = torch.load( F_DIR+"/models/tamil/ta_101_model.pth", map_location=torch.device(self.device))
weights = torch.load( weight_path, map_location=torch.device(self.device))

self.model.load_state_dict(weights)
self.model.eval()

def character_model(self, word, topk = 10):
def character_model(self, word, beam_width = 1):
in_vec = torch.from_numpy(self.in_glyph_obj.word2xlitvec(word)).to(self.device)
## change to active or passive beam
p_out_list = self.model.active_beam_inference(in_vec, beam_width = topk)
p_out_list = self.model.active_beam_inference(in_vec, beam_width = beam_width)
p_result = [ self.tgt_glyph_obj.xlitvec2word(out.cpu().numpy()) for out in p_out_list]

result = self.voc_sanity.reposition(p_result)
Expand Down Expand Up @@ -544,7 +554,7 @@ def _word_segementer(self, sequence):

return segment

def inferencer(self, sequence):
def inferencer(self, sequence, beam_width = 10):

seg = self._word_segementer(sequence[:120])
lit_seg = []
Expand All @@ -559,7 +569,7 @@ def inferencer(self, sequence):

if model_flag:
if seg[p][0] in self._inchar_set:
lit_seg.append(self.character_model(seg[p]))
lit_seg.append(self.character_model(seg[p], beam_width=beam_width))
p+=1; model_flag = False
else: # num & punc
lit_seg.append(self.numsym_model(seg[p]))
Expand All @@ -577,7 +587,127 @@ def inferencer(self, sequence):

return final_result

from collections.abc import Iterable
class XlitEngine():
    """
    Manages the top-level tasks and applications of transliteration.

    Loads one XlitPiston per configured language; a language whose model
    fails to load is skipped with a printed warning rather than aborting.
    Class dependency: XlitPiston, XlitError
    Global Variables: F_DIR
    """
    def __init__(self, lang2use = "all", config_path = "models/lineup.json"):
        """
        lang2use: "all", a single language code, or an iterable of codes.
        config_path: path (relative to F_DIR) of the JSON lineup mapping
            each language code to its weight/vocab/script file names.
        Raises ValueError for an unknown single code or an unsupported
        lang2use type.
        """
        lineup = json.load( open(os.path.join(F_DIR, config_path)) )
        # BUG FIX: lang_config must exist before the single-language and
        # iterable branches assign into it (the original crashed with
        # AttributeError on `self.lang_config[lang2use] = ...`).
        self.lang_config = {}
        if isinstance(lang2use, str):
            if lang2use == "all":
                self.lang_config = lineup
            elif lang2use in lineup:
                self.lang_config[lang2use] = lineup[lang2use]
            else:
                # BUG FIX: `raise "<str>"` is a TypeError in Python 3;
                # raise a real exception type instead.
                raise ValueError(
                    "The entered Langauge code not found. Available are {}".format(lineup.keys()))

        elif isinstance(lang2use, Iterable):
            for l in lang2use:
                try:
                    self.lang_config[l] = lineup[l]
                except KeyError:  # narrowed from a bare except
                    print("Language code {} not found, Skipping...".format(l))
        else:
            raise ValueError(
                "lang2use must be a list of language codes (or) string of single language code")

        self.langs = {}       # successfully loaded: code -> display name
        self.lang_model = {}  # successfully loaded: code -> XlitPiston
        for la in self.lang_config:
            try:
                print("Loading {}...".format(la) )
                self.lang_model[la] = XlitPiston(
                    weight_path = os.path.join(F_DIR, "models",
                        self.lang_config[la]["weight"]) ,
                    vocab_file = os.path.join(F_DIR, "models",
                        self.lang_config[la]["vocab"]),
                    tglyph_cfg_file = os.path.join(F_DIR, "models",
                        self.lang_config[la]["script"]),
                    iglyph_cfg_file = "en",
                )
                self.langs[la] = self.lang_config[la]["name"]
            except Exception as error:
                print("Failure in loading {} \n".format(la), error)
                print(XlitError.loading_err.value)

    def translit_word(self, eng_word, lang_code = "default", topk = 7, beam_width = 10):
        """Transliterate a single word.

        Returns the top-`topk` candidate list for a known `lang_code`,
        a {lang: candidates} dict for "default", [] for an empty word,
        or an XlitError member on failure.
        """
        if eng_word == "":
            return []

        if (lang_code in self.langs):
            try:
                # BUG FIX: honor the caller's beam_width (was hard-coded 10)
                # and slice the top-k candidates — the original
                # `res_list[topk]` indexed a single element instead.
                res_list = self.lang_model[lang_code].inferencer(eng_word, beam_width = beam_width)
                return res_list[:topk]

            except Exception as error:
                print("Error:", error)
                print(XlitError.internal_err.value)
                return XlitError.internal_err

        elif lang_code == "default":
            try:
                res_dict = {}
                for la in self.lang_model:
                    res = self.lang_model[la].inferencer(eng_word, beam_width = beam_width)
                    res_dict[la] = res[:topk]
                return res_dict

            except Exception as error:
                print("Error:", error)
                print(XlitError.internal_err.value)
                return XlitError.internal_err

        else:
            print("Unknown Langauge requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err

    def translit_sentence(self, eng_sentence, lang_code = "default", beam_width = 10):
        """Transliterate each whitespace-separated word, keeping the best
        candidate per word.

        Returns a space-joined string for a known `lang_code`, a
        {lang: string} dict for "default", [] for empty input, or an
        XlitError member on failure.
        """
        if eng_sentence == "":
            return []

        if (lang_code in self.langs):
            try:
                out_str = ""
                for word in eng_sentence.split():
                    # BUG FIX: pass beam_width through (was hard-coded 10).
                    res_ = self.lang_model[lang_code].inferencer(word, beam_width = beam_width)
                    out_str = out_str + res_[0] + " "
                return out_str[:-1]

            except Exception as error:
                print("Error:", error)
                print(XlitError.internal_err.value)
                return XlitError.internal_err

        elif lang_code == "default":
            try:
                res_dict = {}
                for la in self.lang_model:
                    out_str = ""
                    for word in eng_sentence.split():
                        res_ = self.lang_model[la].inferencer(word, beam_width = beam_width)
                        out_str = out_str + res_[0] + " "
                    res_dict[la] = out_str[:-1]
                return res_dict

            except Exception as error:
                print("Error:", error)
                print(XlitError.internal_err.value)
                return XlitError.internal_err

        else:
            print("Unknown Langauge requested", lang_code)
            print(XlitError.lang_err.value)
            return XlitError.lang_err


if __name__ == "__main__":
    # Manual smoke test: load all configured languages and transliterate a
    # sample sentence in each. (FIX: the pasted diff also retained stale
    # lines calling the removed XlitCrankshaft class, which would NameError.)
    engine = XlitEngine()
    y = engine.translit_sentence("Hello World !")
    print(y)
7 changes: 4 additions & 3 deletions utilities/lang_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ def __init__(self, glyphs = 'en'):
glyphs: json file with script information
"""
if glyphs == 'en':
self.glyphs = english_smallcase
self.glyphs = [chr(alpha) for alpha in range(97, 122+1)]
else:
glyph_data = json.load(open(glyphs))
self.glyphs = glyph_data["glyphs"]
self.dossier = json.load(open(glyphs))
self.glyphs = self.dossier["glyphs"]
self.numsym_map = self.dossier["numsym_map"]

self.char2idx = {}
self.idx2char = {}
Expand Down

0 comments on commit 3419db2

Please sign in to comment.