diff --git a/datasets/create_chitchat_dataset.py b/datasets/create_chitchat_dataset.py deleted file mode 100644 index d5ed426f..00000000 --- a/datasets/create_chitchat_dataset.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import json -import random -import re - -from tqdm import tqdm - - -_path = os.path.dirname(__file__) -_samsum_train_path = os.path.join(_path, "data/samsum-train.json") -_squad_train_path = os.path.join(_path, "data/squad2-train.json") -_squad_filter_path = os.path.join(_path, "data/squad_items_about_people.json") -_candidate_answers = [ - "unknown", - "I don't know", - "I do not know", - "I have no information about this", -] -_unknown_fraction = 0.1 -_context_fraction = 0.2 - - -def get_speakers(dialogue): - speakers = set() - for line in dialogue.split("\n"): - name = line[: line.find(":")] - speakers.add(name) - - return list(speakers) - - -def select_random_pair_of_speakers(candidates): - random.shuffle(candidates) - return candidates[:2] - - -def create_inset_from_unanswerable_question(squad_set, first_speaker, second_speaker): - data = squad_set["data"] - item = random.choice(data) - paragraph = random.choice(item["paragraphs"]) - qas = random.choice(paragraph["qas"]) - question = qas["question"] - answer = random.choice(_candidate_answers) - return f"{first_speaker}: {question}\n{second_speaker}: {answer}\n" - - -def from_name_to_2nd_person(text, name): - text = re.sub(f"{name} doesn't", "you don't", text, flags=re.IGNORECASE) - text = re.sub(f"{name} does not", "you do not", text, flags=re.IGNORECASE) - text = re.sub(f"{name} does", "you do", text, flags=re.IGNORECASE) - text = re.sub(f"{name}'s", "your", text, flags=re.IGNORECASE) - text = re.sub(f"does {name}", "do you", text, flags=re.IGNORECASE) - text = re.sub(f"is {name}", "are you", text, flags=re.IGNORECASE) - text = re.sub(f"was {name}", "were you", text, flags=re.IGNORECASE) - text = re.sub(f"{name} is", "you are", text, flags=re.IGNORECASE) - text = re.sub(f"{name}", "you", text, flags=re.IGNORECASE) - return text - - -def from_name_to_1st_person(text, name): - text = re.sub(f"{name} doesn't", "I don't", text, flags=re.IGNORECASE) - text = re.sub(f"{name} does not", "I do not", text, flags=re.IGNORECASE) - text = re.sub(f"{name} does", "I do", text, flags=re.IGNORECASE) - text = re.sub(f"{name}'s", "my", text, flags=re.IGNORECASE) - text = re.sub(f"does {name}", "do I", text, flags=re.IGNORECASE) - text = re.sub(f"is {name}", "am I", text, flags=re.IGNORECASE) - text = re.sub(f"was {name}", "was I", text, flags=re.IGNORECASE) - text = re.sub(f"{name} is", "I am", text, flags=re.IGNORECASE) - text = re.sub(f"to {name}", "to me", text, flags=re.IGNORECASE) - text = re.sub(f"{name}", "I", text, flags=re.IGNORECASE) - return text - - -def from_2nd_person_to_name(text, name): - text = re.sub("you don't", f"{name} doesn't", text, flags=re.IGNORECASE) - text = re.sub("you do not", f"{name} does not", text, flags=re.IGNORECASE) - text = re.sub("you do", f"{name} does", text, flags=re.IGNORECASE) - text = re.sub("your", f"{name}'s", text, flags=re.IGNORECASE) - text = re.sub("do you", f"does {name}", text, flags=re.IGNORECASE) - text = re.sub("are you", f"is {name}", text, flags=re.IGNORECASE) - text = re.sub("were you", f"was {name}", text, flags=re.IGNORECASE) - text = re.sub("you are", f"{name} is", text, flags=re.IGNORECASE) - text = re.sub("you will", f"{name} will", text, flags=re.IGNORECASE) - text = re.sub("you'll", f"{name} will", text, flags=re.IGNORECASE) - text = re.sub(" you ", f" {name} ", text, 
flags=re.IGNORECASE) - text = re.sub(" you\.", f" {name}\.", text, flags=re.IGNORECASE) - text = re.sub(" you!", f" {name}!", text, flags=re.IGNORECASE) - text = re.sub(" you\?", f" {name}\?", text, flags=re.IGNORECASE) - return text - - -def from_1st_person_to_name(text, name): - text = re.sub("I don't", f"{name} doesn't", text, flags=re.IGNORECASE) - text = re.sub("I do not", f"{name} does not", text, flags=re.IGNORECASE) - text = re.sub("I do", f"{name} does", text, flags=re.IGNORECASE) - text = re.sub("my ", f"{name}'s ", text, flags=re.IGNORECASE) - text = re.sub("do I", f"does {name}", text, flags=re.IGNORECASE) - text = re.sub("am I", f"is {name}", text, flags=re.IGNORECASE) - text = re.sub("was I", f"was {name}", text, flags=re.IGNORECASE) - text = re.sub("I am", f"{name} is", text, flags=re.IGNORECASE) - text = re.sub("to me", f"to {name}", text, flags=re.IGNORECASE) - text = re.sub("I will", f"{name} will", text, flags=re.IGNORECASE) - text = re.sub("I'll", f"{name} will", text, flags=re.IGNORECASE) - text = re.sub("I'm", f"{name} is", text) - text = re.sub("I ", f"{name} ", text) - text = re.sub(" I\?", f" {name}\?", text) - text = re.sub(" me ", f" {name} ", text, flags=re.IGNORECASE) - text = re.sub(" me\.", f" {name}\.", text, flags=re.IGNORECASE) - text = re.sub(" me!", f" {name}!", text, flags=re.IGNORECASE) - text = re.sub(" me\?", f" {name}\?", text, flags=re.IGNORECASE) - return text - - -def replace_names(text, names, replace_function): - names = sorted(names, key=lambda x: -len(x)) - for name in names: - if name in text: - return replace_function(text, name) - - return text - - -def create_inset_with_first_person_answer( - squad_set, squad_people_filter, first_speaker, second_speaker -): - squad_item_number, names = random.sample(squad_people_filter.items(), 1)[0] - squad_item_number = int(squad_item_number) - names = names["names"] - question, answer = "", "" - while ( - "you" not in question.lower() - and "your" not in question.lower() - and "I " not in answer - and "my" not in answer.lower() - ): - paragraph = random.choice(squad_set["data"][squad_item_number]["paragraphs"]) - qas = random.choice(paragraph["qas"]) - if not qas["answers"]: - continue - - question = replace_names(qas["question"], names, from_name_to_2nd_person) - answer = replace_names( - random.choice(qas["answers"])["text"], names, from_name_to_1st_person - ) - - context = replace_names( - paragraph["context"], names, lambda x, y: x.replace(y, second_speaker) - ) - return f"{first_speaker}: {question}\n{second_speaker}: {answer}\n", context - - -def create_inset_with_first_person_query( - squad_set, squad_people_filter, first_speaker, second_speaker -): - squad_item_number, names = random.sample(squad_people_filter.items(), 1)[0] - squad_item_number = int(squad_item_number) - names = names["names"] - question, answer = "", "" - while ( - "I" not in question.lower() - and "my" not in question.lower() - and "you " not in answer.lower() - and "your " not in answer.lower() - ): - paragraph = random.choice(squad_set["data"][squad_item_number]["paragraphs"]) - qas = random.choice(paragraph["qas"]) - if not qas["answers"]: - continue - - question = replace_names(qas["question"], names, from_name_to_1st_person) - answer = replace_names( - random.choice(qas["answers"])["text"], names, from_name_to_2nd_person - ) - - context = replace_names( - paragraph["context"], names, lambda x, y: x.replace(y, second_speaker) - ) - return f"{first_speaker}: {question}\n{second_speaker}: {answer}\n", context - - -def 
is_question(line): - return "?" in line - - -def get_sequence_of_speakers(dialogue_lines): - return [line.split(":")[0] for line in dialogue_lines if ":" in line] - - -def find_next_speaker(speaker_sequence, index, curr_speaker): - for speaker in speaker_sequence[index + 1 :]: - if curr_speaker != speaker: - return speaker - - raise RuntimeWarning("No next speaker in conversation.") - - -def find_prior_speaker(speaker_sequence, index, curr_speaker): - for speaker in speaker_sequence[:index][::-1]: - if curr_speaker != speaker: - return speaker - - raise RuntimeError("No prior speaker in conversation.") - - -def substitute_pronouns_with_speaker_names(dialogue_text): - dialogue_lines = [line for line in dialogue_text.split("\n") if line] - speaker_sequence = get_sequence_of_speakers(dialogue_lines) - new_lines = [] - for index in range(len(dialogue_lines) - 1): - line = dialogue_lines[index] - curr_speaker = speaker_sequence[index] - if "remembers" in curr_speaker: - new_lines.append(line) - continue - - new_line = from_1st_person_to_name(line, curr_speaker) - try: - next_speaker = find_next_speaker(speaker_sequence, index, curr_speaker) - - except RuntimeWarning: - new_lines.append(new_line) - break - - new_line = from_2nd_person_to_name(new_line, next_speaker) - new_lines.append(new_line) - - new_line = from_1st_person_to_name(dialogue_lines[-1], speaker_sequence[-1]) - try: - prior_speaker = find_prior_speaker(speaker_sequence, -1, speaker_sequence[-1]) - - except RuntimeWarning: - new_lines.append(new_line) - return "\n".join(new_lines) - - new_line = from_2nd_person_to_name(new_line, prior_speaker) - new_lines.append(new_line) - - return "\n".join(new_lines) - - -if __name__ == "__main__": - samsum_train = json.load(open(_samsum_train_path)) - squad_train = json.load(open(_squad_train_path)) - squad_people_filter = json.load(open(_squad_filter_path)) - - new_train_set = [] - for item in tqdm(samsum_train[:1000]): - new_item = {} - dialogue = item["dialogue"].replace("\r", "") - if not dialogue: - continue - - speakers = get_speakers(dialogue) - first, second = select_random_pair_of_speakers(speakers) - inset = create_inset_from_unanswerable_question(squad_train, first, second) - first_person_answer, sp_context = create_inset_with_first_person_answer( - squad_train, squad_people_filter, first, second - ) - first_person_query, fp_context = create_inset_with_first_person_query( - squad_train, squad_people_filter, first, second - ) - - new_dialogue = "" - num_lines = len(dialogue.split("\n")) - unknown_inserted_before = False - first_person_answer_inserted_before = False - first_person_query_inserted_before = False - - for line in dialogue.split("\n"): - new_dialogue += line + "\n" - if line and is_question(line): - continue - - threshold = _unknown_fraction / num_lines - context_threshold = _context_fraction / num_lines - if random.uniform(0, 1) < threshold and not unknown_inserted_before: - new_dialogue += inset - unknown_inserted_before = True - - elif ( - random.uniform(0, 1) < context_threshold - and not first_person_answer_inserted_before - ): - if random.choice([1, 0]): - new_dialogue += f"{second} remembers: " + sp_context + "\n" - first_person_answer = first_person_answer.replace( - f"{second}:", f"{second}: [factual]" - ) - - else: - new_dialogue += f"{second}: " + sp_context + "\n" - first_person_answer = first_person_answer.replace( - f"{second}:", f"{second}: [answer in conversation]" - ) - - new_dialogue += first_person_answer - first_person_answer_inserted_before = True - 
continue - - elif ( - random.uniform(0, 1) < context_threshold - and not first_person_query_inserted_before - ): - if random.choice([1, 0]): - new_dialogue += f"{second} remembers: " + fp_context + "\n" - first_person_query = first_person_query.replace( - f"{second}:", f"{second}: [factual]" - ) - - else: - new_dialogue += f"{first}: " + fp_context + "\n" - first_person_query = first_person_query.replace( - f"{second}:", f"{second}: [answer in conversation]" - ) - - new_dialogue += first_person_query - first_person_answer_inserted_before = True - - new_item["dialogue"] = ( - "In the dialogue below some people are talking:\n" - + substitute_pronouns_with_speaker_names(new_dialogue) - ) - new_train_set.append(new_item) - - json.dump(new_train_set, open(os.path.join(_path, "data/dialogues.json"), "w")) diff --git a/datasets/create_rules_dataset.py b/datasets/create_rules_dataset.py deleted file mode 100644 index c711745a..00000000 --- a/datasets/create_rules_dataset.py +++ /dev/null @@ -1,45 +0,0 @@ -import asyncio -import pandas as pd - -from wafl.config import Configuration -from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector - - -def get_prompt(df, theme): - prompt = "" - for _, row in df.sample(9).iterrows(): - prompt += ( - f""" - -Create a plausible dialogue about the theme \"{row["Theme"]}\" based on the following summary and rules. - -The rules are as follows: -{row["Rules"]} - -The conversation goes as follows: -{row["Conversation"]} - - """.strip() - + "\n\n" - ) - - return ( - prompt - + f'\nCreate plausible dialogue about the theme "{theme}" based on the following summary and rules.\n\nThe rules are as follows:\n' - ) - - -if __name__ == "__main__": - config = Configuration.load_local_config() - remote_llm_connector = RemoteLLMConnector( - config.get_value("llm_model"), last_strings=[""] - ) - - df = pd.read_csv("data/complex_instructions.csv") - theme = "playing a song that the user likes" - prompt = get_prompt(df, theme) - print( - asyncio.run( - remote_llm_connector.predict(prompt, temperature=0.5, num_tokens=1500) - ) - ) diff --git a/datasets/train_llm_on_rules_dataset.py b/datasets/train_llm_on_rules_dataset.py deleted file mode 100644 index 251bbe9d..00000000 --- a/datasets/train_llm_on_rules_dataset.py +++ /dev/null @@ -1,122 +0,0 @@ -import random - -import pandas as pd -from datasets import Dataset -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - TrainingArguments, - Trainer, - DataCollatorForLanguageModeling, -) - -model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.1" -max_length = 1024 + 512 - - -def get_prompts(df): - prompts = [] - for _, row in df.sample(frac=1).iterrows(): - memory = "" - if memory == "": - memory = "The user has no memory." - - current_rule = row["Rules"] - rules = df.sample(random.choice([1, 2]))["Rules"].tolist() + [current_rule] - random.shuffle(rules) - rules = "\n".join(rules) - prompt = ( - f""" -The user is talking with a chatbot about the theme \"{row["Theme"]}\" based on the following summary. 
- -{memory} - - -The rules are as follows: - -{rules} - - -The conversation goes as follows: -{row["Conversation"]} - """.strip() - + "\n\n" - ) - prompts.append(prompt) - - return prompts - - -def preprocess_function(sample): - model_inputs = tokenizer( - sample["prompt"], - return_tensors="pt", - max_length=max_length, - padding="max_length", - ) - labels = tokenizer( - sample["prompt"], - return_tensors="pt", - max_length=max_length, - padding="max_length", - ) - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - -def model_init(): - model = AutoModelForCausalLM.from_pretrained(model_name_or_path) - parameters = model.parameters() - for parameter in parameters: - parameter.requires_grad = False - - model.model.enable_input_require_grads() - model.lm_head.training = True - for index in range(len(model.model.layers)): - model.model.layers[index].self_attn.k_proj.training = True - - return model - - -def create_dataset_from_file(filepath): - df = pd.read_csv(filepath) - prompts = get_prompts(df) - return Dataset.from_dict({"prompt": prompts}) - - -if __name__ == "__main__": - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) - tokenizer.pad_token = tokenizer.eos_token - dataset = create_dataset_from_file("data/complex_instructions.csv") - train_dataset = dataset.map( - preprocess_function, batched=True, batch_size=1, num_proc=4 - ) - data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) - learning_rate = 1e-6 - output_dir_name = f"checkpoint_lr{learning_rate}" - training_args = TrainingArguments( - output_dir=output_dir_name, - per_device_train_batch_size=1, - per_device_eval_batch_size=1, - evaluation_strategy="steps", - use_cpu=True, - learning_rate=learning_rate, - num_train_epochs=2, - logging_steps=200, - eval_steps=200, - save_total_limit=1, - ) - model = model_init() - trainer = Trainer( - model=model, - args=training_args, - tokenizer=tokenizer, - data_collator=data_collator, - train_dataset=train_dataset, - ) - trainer.train() - trainer.save_model("wafl-mistral") - model = trainer.model - model.push_to_hub("fractalego/wafl-mistral") - tokenizer.push_to_hub("fractalego/wafl-mistral") diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst index b8b0f9e2..b2d4a730 100644 --- a/documentation/source/configuration.rst +++ b/documentation/source/configuration.rst @@ -13,9 +13,11 @@ A typical configuration file looks like this: "deactivate_sound": true, "rules": "rules.yaml", "functions": "functions.py", + "frontend_port": 8081, "llm_model": { "model_host": "localhost", - "model_port": 8080 + "model_port": 8080, + "temperature": 0.4 }, "listener_model": { "model_host": "localhost", @@ -45,7 +47,10 @@ These settings regulate the following: * "functions" is the file containing the functions that can be used in the rules. The default is "functions.py". + * "frontend_port" is the port where the web frontend is running. The default is 8090. + * "llm_model" is the configuration to connect to the LLM model in the backend. The default is "localhost:8080". + The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4. * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080". 
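For reference (not part of the patch): a minimal sketch of reading the new configuration keys through the Configuration API that this changeset already uses elsewhere; the key names and defaults are the ones documented above.

    from wafl.config import Configuration

    config = Configuration.load_local_config()  # loads the local config.json
    frontend_port = config.get_value("frontend_port")  # documented default: 8090
    llm_settings = config.get_value("llm_model")  # nested dict: model_host, model_port, temperature
    temperature = llm_settings["temperature"]  # documented default: 0.4
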
diff --git a/tests/config.json b/tests/config.json index 4bb4fa90..c78c40b3 100644 --- a/tests/config.json +++ b/tests/config.json @@ -4,6 +4,7 @@ "deactivate_sound": true, "rules": "rules.yaml", "functions": "functions.py", + "max_recursion": 2, "llm_model": { "model_host": "localhost", "model_port": 8080 diff --git a/tests/test_rules.py b/tests/test_rules.py index 619e3cd9..7be70da6 100644 --- a/tests/test_rules.py +++ b/tests/test_rules.py @@ -5,6 +5,7 @@ from wafl.config import Configuration from wafl.events.conversation_events import ConversationEvents from wafl.interface.dummy_interface import DummyInterface +from wafl.parsing.rules_parser import get_facts_and_rules_from_text wafl_example = """ rules: @@ -13,6 +14,14 @@ - the user says their name: - reply casually to the conversation" + + - the user wants to buy coffee: + - the bot says the coffee prices: + - decaf is 1.50 + - regular is 1.00 + - espresso is 2.00 + - ask for which price range + - tell them the right coffee """ @@ -48,3 +57,39 @@ def test__rules_are_not_always_triggered(self): asyncio.run(conversation_events.process_next()) unexpected = "bot: the horse is tall" self.assertNotEqual(unexpected, interface.get_utterances_list()[-1]) + + def test__rules_can_nest(self): + interface = DummyInterface( + to_utter=[ + "I want to buy coffee", + ] + ) + config = Configuration.load_local_config() + config.set_value("rules", wafl_example) + conversation_events = ConversationEvents( + config=config, + interface=interface, + ) + asyncio.run(conversation_events.process_next()) + self.assertIn("decaf", interface.get_facts_and_timestamp()[0][1]) + self.assertIn("regular", interface.get_facts_and_timestamp()[0][1]) + self.assertIn("espresso", interface.get_facts_and_timestamp()[0][1]) + + def test__nested_rules_are_printed_correctly(self): + rule_text = """ +rules: + - the user wants to know the time: + - output "The time is get_time()": + - if the time is before 12:00 say "Good morning" + - if the time is after 12:00 say "Good afternoon" + """.strip() + + facts_and_rules = get_facts_and_rules_from_text(rule_text) + rule = facts_and_rules["rules"][0] + expected = """ +the user wants to know the time + - output "The time is get_time()" + - if the time is before 12:00 say "Good morning" + - if the time is after 12:00 say "Good afternoon" + """.strip() + self.assertEqual(expected, str(rule).strip()) diff --git a/todo.txt b/todo.txt index d0d39bbb..3d726a32 100644 --- a/todo.txt +++ b/todo.txt @@ -1,8 +1,76 @@ +* add config file for model names + - llm model name + - whisper model name + +/* add version name +/* let user decide port for frontend +/* update docs about port +/* push new version +* update pypi with wafl and wafl-llm +* clean code for llm eval and make it public +* update huggingface readme +* read overleaf paper + + +* on wafl_llm make it so only some LLMs are supported +* change speaker model with newer one + + +1) train on more steps + a) try 3 epochs, save each + b) use lr=1e-6 + c) use batch_size=4 + d) do not use 4 bit original model, use 16 bit (on the GPU) +2) evaluate result +3) Upload to hf +4) create a test set of 50 elements for paper. Find a way to test it. 
repeat from 1) +5) refactor code +6) maybe change voice model +6) write paper + + + ### TODO +* script to add wrong when none are needed + + +On the to_modify set: +* sometimes the user answers yes (after "do you confirm?") and the dialogue does not have "user: yes" + + +On the accepted set: +* CHANGE <|USER|>\n into user: (some of the elements are in the wrong format) +* Perhaps change function() into function() (the memory should store the results of the function) +* Create a first paragraphs with the summary of the conversation: The conversation must always be grounded on the summary (USE LLM TO CREATE THE SUMMARY) +* The LLM wrote text after hallucinating the result of the execution. Think about how to deal with that. +* all the rules that says "two level of retrieval" should have the trigger rewritten to something more specific +* change "bot" into "assistant" some of times +* some sentences are between [] and should be removed +* put the items in so far in the conversation summary. If it is a function then you need to simulaten the relevant output using the LLM +* sometimes at the end of the conversation the bot says "Process finished with exit code 0". Erase this +* add ability to index files and files in entire folders +* if the bot uses a function to retrieve information, you should add . This is symmetrical to with a function call when necessary. +* some tags like should end the training item text +* todo User -> user, or at least be internally consistent + +* find a way to use HuggingFaceH4/ultrachat_200k as a starting point for each item + - each item should be easy to copy into a csv. + - Separate the items with special tokens/lines +* Create a dataset with about 500 elements + - use hugginface chat dataset as a starting point for + - themes + - conversation guide in prompt + - use LLM to create corresponding python code +* retriever in create_prompt +* change num_replicas back to 10 in remote_llm_connector + + /* create actions from command line /* add condition of when to stop to the actions + +Actions: #### Find way to delete cache in remote llm connector #### Put colors in action output (and dummy interface) #### Add green for when an expectation is matched diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index 7fe8c68c..e4b1a160 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -10,7 +10,7 @@ get_last_user_utterance, ) from wafl.answerer.base_answerer import BaseAnswerer -from wafl.answerer.rule_creator import RuleCreator +from wafl.answerer.rule_maker import RuleMaker from wafl.connectors.bridges.llm_chitchat_answer_bridge import LLMChitChatAnswerBridge from wafl.exceptions import CloseConversation from wafl.extractors.dataclasses import Query, Answer @@ -31,7 +31,7 @@ def __init__(self, config, knowledge, interface, code_path, logger): self._init_python_module(code_path.replace(".py", "")) self._prior_rule_with_timestamp = None self._max_predictions = 3 - self._rule_creator = RuleCreator( + self._rule_creator = RuleMaker( knowledge, config, interface, @@ -126,7 +126,7 @@ async def _get_relevant_facts( > conversational_timestamp - self._max_num_past_utterances_for_facts ] facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold( - query, is_from_user=True, threshold=0.8 + query, is_from_user=True, threshold=0.85 ) if facts_and_thresholds: facts = [ @@ -160,9 +160,19 @@ def _init_python_module(self, module_name): self._functions = [item[0] for item in getmembers(self._module, 
isfunction)] async def _substitute_results_in_answer(self, answer_text): - matches = re.finditer(r"(.*?)", answer_text, re.DOTALL) + matches = re.finditer(r"(.*?)|(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) for match in matches: to_execute = match.group(1) + if not to_execute: + continue + result = await self._run_code(to_execute) + answer_text = answer_text.replace(match.group(0), result) + + matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) + for match in matches: + to_execute = match.group(1) + if not to_execute: + continue result = await self._run_code(to_execute) answer_text = answer_text.replace(match.group(0), result) @@ -178,6 +188,13 @@ async def _substitute_memory_in_answer_and_get_memories_if_present( answer_text = answer_text.replace(match.group(0), "") memories.append(to_execute) + matches = re.finditer(r"(.*?\))$", answer_text, re.DOTALL|re.MULTILINE) + memories = [] + for match in matches: + to_execute = match.group(1) + answer_text = answer_text.replace(match.group(0), "") + memories.append(to_execute) + return answer_text, memories async def _run_code(self, to_execute): @@ -206,13 +223,13 @@ async def _run_code(self, to_execute): except Exception as e: result = ( - f'Error while executing\n\n"""python\n{to_execute}\n"""\n\n{str(e)}' + f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}" ) traceback.print_exc() break if not result: - result = f'\n"""python\n{to_execute}\n"""' + result = f"\n```python\n{to_execute}\n```" return result diff --git a/wafl/answerer/rule_creator.py b/wafl/answerer/rule_creator.py deleted file mode 100644 index a1a49ba0..00000000 --- a/wafl/answerer/rule_creator.py +++ /dev/null @@ -1,49 +0,0 @@ -class RuleCreator: - def __init__( - self, - knowledge, - config, - interface, - max_num_rules, - delete_current_rule, - max_recursion=1, - ): - self._knowledge = knowledge - self._config = config - self._interface = interface - self._max_num_rules = max_num_rules - self._delete_current_rule = delete_current_rule - self._max_indentation = max_recursion - self._indent_str = " " - - async def create_from_query(self, query): - rules = await self._knowledge.ask_for_rule_backward(query) - rules = rules[: self._max_num_rules] - rules_texts = [] - for rule in rules: - rules_text = f"- If {rule.effect.text} go through the following points:\n" - for cause_index, cause in enumerate(rule.causes): - rules_text += f"{self._indent_str}{cause_index + 1}) {cause.text}\n" - rules_text += await self.recursively_add_rules(cause) - - rules_text += f'{self._indent_str}{len(rule.causes) + 1}) After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n' - - rules_texts.append(rules_text) - await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}") - - return "\n".join(rules_texts) - - async def recursively_add_rules(self, query, depth=2): - rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.95) - rules = rules[: self._max_num_rules] - rules_texts = [] - for rule in rules: - rules_text = f"- If {rule.effect.text} go through the following points:\n" - for cause_index, causes in enumerate(rule.causes): - indentation = self._indent_str * depth - rules_text += f"{indentation}{cause_index + 1}) {causes.text}\n" - rules_text += await self.recursively_add_rules(causes.text, depth + 1) - - rules_texts.append(rules_text) - - return "\n".join(rules_texts) diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py new file mode 100644 index 
00000000..f9775d21 --- /dev/null +++ b/wafl/answerer/rule_maker.py @@ -0,0 +1,34 @@ +class RuleMaker: + def __init__( + self, + knowledge: "Knowledge", + config: "BaseConfig", + interface: "BaseInterface", + max_num_rules: int, + delete_current_rule: str, + max_recursion: int = 3, + ): + self._knowledge = knowledge + self._config = config + self._interface = interface + self._max_num_rules = max_num_rules + self._delete_current_rule = delete_current_rule + if not config.get_value("max_recursion"): + self._max_indentation = max_recursion + else: + self._max_indentation = config.get_value("max_recursion") + + async def create_from_query(self, query): + rules = await self._knowledge.ask_for_rule_backward(query, threshold=0.92) + rules = rules[: self._max_num_rules] + rules_texts = [] + for rule in rules: + rules_text = rule.get_string_using_template( + "- If {effect} go through the following points:" + ) + rules_text += f'{rule.indent_str}- After you completed all the steps output "{self._delete_current_rule}" and continue the conversation.\n' + + rules_texts.append(rules_text) + await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}") + + return "\n".join(rules_texts) diff --git a/wafl/command_line.py b/wafl/command_line.py index 14fbb6cd..fc68db06 100644 --- a/wafl/command_line.py +++ b/wafl/command_line.py @@ -7,7 +7,7 @@ run_from_command_line, run_testcases, print_incipit, - download_models + download_models, ) from wafl.runners.run_from_actions import run_action from wafl.runners.run_from_audio import run_from_audio diff --git a/wafl/connectors/base_llm_connector.py b/wafl/connectors/base_llm_connector.py index 806164e4..1bd57a63 100644 --- a/wafl/connectors/base_llm_connector.py +++ b/wafl/connectors/base_llm_connector.py @@ -1,7 +1,6 @@ import logging import re - from wafl.connectors.utils import select_best_answer _system_logger = logging.getLogger(__file__) @@ -39,11 +38,7 @@ async def generate(self, prompt: str) -> str: text = prompt start = len(text) - while ( - all(item not in text[start:] for item in self._last_strings) - and len(text) < start + self._max_reply_length - ): - text += select_best_answer(await self.predict(text), self._last_strings) + text += select_best_answer(await self.predict(text), self._last_strings) end_set = set() for item in self._last_strings: @@ -59,7 +54,7 @@ async def generate(self, prompt: str) -> str: if end_set: end = min(end_set) - candidate_answer = text[start:end].split("bot: ")[-1].strip() + candidate_answer = text[start:end].strip() candidate_answer = re.sub(r"(.*)<\|.*\|>", r"\1", candidate_answer).strip() if prompt not in self._cache: diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py index d232df7d..e9e3c36e 100644 --- a/wafl/connectors/remote/remote_llm_connector.py +++ b/wafl/connectors/remote/remote_llm_connector.py @@ -1,7 +1,10 @@ +import json + import aiohttp import asyncio from wafl.connectors.base_llm_connector import BaseLLMConnector +from wafl.variables import is_supported class RemoteLLMConnector(BaseLLMConnector): @@ -9,13 +12,14 @@ class RemoteLLMConnector(BaseLLMConnector): _max_reply_length = 1024 _num_prediction_tokens = 200 _cache = {} - _num_replicas = 10 - def __init__(self, config, last_strings=None): + def __init__(self, config, last_strings=None, num_replicas=3): super().__init__(last_strings) host = config["model_host"] port = config["model_port"] + self._default_temperature = config["temperature"] self._server_url = 
f"https://{host}:{port}/predictions/bot" + self._num_replicas = num_replicas try: loop = asyncio.get_running_loop() @@ -28,28 +32,37 @@ def __init__(self, config, last_strings=None): ): raise RuntimeError("Cannot connect a running LLM.") - async def predict(self, prompt: str, temperature=None, num_tokens=None) -> [str]: + async def predict( + self, prompt: str, temperature=None, num_tokens=None, num_replicas=None + ) -> [str]: if not temperature: - temperature = 0.5 + temperature = self._default_temperature if not num_tokens: num_tokens = self._num_prediction_tokens + if not num_replicas: + num_replicas = self._num_replicas + payload = { "data": prompt, "temperature": temperature, "num_tokens": num_tokens, "last_strings": self._last_strings, - "num_replicas": self._num_replicas, + "num_replicas": num_replicas, } for _ in range(self._max_tries): async with aiohttp.ClientSession( - connector=aiohttp.TCPConnector(ssl=False) + conn_timeout=6000, + connector=aiohttp.TCPConnector(ssl=False), ) as session: async with session.post(self._server_url, json=payload) as response: - answer = await response.text() - return answer.split("<||>") + answer = json.loads(await response.text()) + status = answer["status"] + if status != "success": + raise RuntimeError(f"Error in prediction: {answer}") + return answer["prediction"].split("<||>") return [""] @@ -66,7 +79,14 @@ async def check_connection(self): conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False) ) as session: async with session.post(self._server_url, json=payload) as response: - await response.text() + answer = json.loads(await response.text()) + wafl_llm_version = answer["version"] + print(f"Connected to wafl-llm v{wafl_llm_version}.") + if not is_supported(wafl_llm_version): + print("This version of wafl-llm is not supported.") + print("Please update wafl-llm.") + raise aiohttp.client.InvalidURL + return True except aiohttp.client.InvalidURL: diff --git a/wafl/connectors/remote/remote_whisper_connector.py b/wafl/connectors/remote/remote_whisper_connector.py index 3b7a8cba..d9498a83 100644 --- a/wafl/connectors/remote/remote_whisper_connector.py +++ b/wafl/connectors/remote/remote_whisper_connector.py @@ -38,6 +38,11 @@ async def predict(self, waveform, hotword=None) -> Dict[str, float]: async with session.post(self._server_url, json=payload) as response: data = await response.text() prediction = json.loads(data) + if "transcription" not in prediction: + raise RuntimeError( + "No transcription found in prediction. Is your microphone working?" 
+ ) + transcription = prediction["transcription"] score = prediction["score"] logp = prediction["logp"] diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py index cf38856b..e00572f1 100644 --- a/wafl/events/conversation_events.py +++ b/wafl/events/conversation_events.py @@ -1,5 +1,6 @@ import os import re +import traceback from wafl.events.answerer_creator import create_answerer from wafl.simple_text_processing.normalize import normalized @@ -59,6 +60,7 @@ async def _process_query(self, text: str): if ( not text_is_question + and self._interface.get_utterances_list() and self._interface.get_utterances_list()[-1].find("user:") == 0 ): await self._interface.output("I don't know what to reply") @@ -108,7 +110,7 @@ def reload_knowledge(self): def reset_discourse_memory(self): self._answerer = create_answerer( - self._config, self._knowledge, self._interface, logger + self._config, self._knowledge, self._interface, self._logger ) def _activation_word_in_text(self, activation_word, text): diff --git a/wafl/facts.py b/wafl/facts.py index 23db2067..c2a3eec2 100644 --- a/wafl/facts.py +++ b/wafl/facts.py @@ -1,9 +1,10 @@ from dataclasses import dataclass +from typing import Union @dataclass class Fact: - text: str + text: Union[str, dict] is_question: bool = False variable: str = None is_interruption: bool = False diff --git a/wafl/frontend/index.html b/wafl/frontend/index.html index bc7283f1..8137ca99 100644 --- a/wafl/frontend/index.html +++ b/wafl/frontend/index.html @@ -39,6 +39,37 @@ +
+ <markup elided>
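The markup added to index.html here was lost in extraction; judging from the rest of this diff it presumably includes a control for the new logs toggle. As orientation only, a hedged sketch of the server-side pieces such a control would drive — MessagesCreator.toggle_logs and the per-conversation /toggle_logs POST route are both added later in this patch:

    from wafl.scheduler.messages_creator import MessagesCreator

    async def render_messages(interface):
        creator = MessagesCreator(interface)
        creator.toggle_logs()  # adds "logs" to the toggled windows
        html = await creator.get_messages_window()  # dialogue plus the logs pane
        creator.toggle_logs()  # a second call toggles the logs pane off again
        return html
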
  • diff --git a/wafl/frontend/selector.html b/wafl/frontend/selector.html index 15ad1ce6..0752552e 100644 --- a/wafl/frontend/selector.html +++ b/wafl/frontend/selector.html @@ -2,7 +2,6 @@ WAFL frontend - diff --git a/wafl/interface/base_interface.py b/wafl/interface/base_interface.py index 18089c28..cb54f27c 100644 --- a/wafl/interface/base_interface.py +++ b/wafl/interface/base_interface.py @@ -20,6 +20,9 @@ async def input(self) -> str: def bot_has_spoken(self, to_set: bool = None): raise NotImplementedError + async def insert_input(self, text: str): + pass + def is_listening(self): return self._is_listening @@ -32,9 +35,6 @@ def deactivate(self): self._facts = [] self._utterances = [] - def add_hotwords(self, hotwords: List[str]): - raise NotImplementedError - async def add_choice(self, text): self._choices.append((time.time(), text)) await self.output(f"Making the choice: {text}", silent=True) @@ -60,8 +60,17 @@ def reset_history(self): self._choices = [] self._facts = [] + def add_hotwords(self, hotwords): + pass + def _decorate_reply(self, text: str) -> str: if not self._decorator: return text return self._decorator.extract(text, self._utterances) + + def _insert_utterance(self, speaker, text: str): + if self._utterances == [] or text != self._utterances[-1][1].replace( + f"{speaker}: ", "" + ): + self._utterances.append((time.time(), f"{speaker}: {text}")) diff --git a/wafl/interface/list_interface.py b/wafl/interface/list_interface.py new file mode 100644 index 00000000..c7416d2e --- /dev/null +++ b/wafl/interface/list_interface.py @@ -0,0 +1,53 @@ +import asyncio +from typing import List + +from wafl.interface.base_interface import BaseInterface + + +class ListInterface(BaseInterface): + def __init__(self, interfaces_list: List[BaseInterface]): + super().__init__() + self._interfaces_list = interfaces_list + self._synchronize_interfaces() + + async def output(self, text: str, silent: bool = False): + await asyncio.wait( + [interface.output(text, silent) for interface in self._interfaces_list], + return_when=asyncio.ALL_COMPLETED, + ) + + async def input(self) -> str: + done, pending = await asyncio.wait( + [interface.input() for interface in self._interfaces_list], + return_when=asyncio.FIRST_COMPLETED, + ) + return done.pop().result() + + async def insert_input(self, text: str): + await asyncio.wait( + [interface.insert_input(text) for interface in self._interfaces_list], + return_when=asyncio.ALL_COMPLETED, + ) + + def bot_has_spoken(self, to_set: bool = None): + for interface in self._interfaces_list: + interface.bot_has_spoken(to_set) + + def activate(self): + for interface in self._interfaces_list: + interface.activate() + super().activate() + + def deactivate(self): + for interface in self._interfaces_list: + interface.deactivate() + super().deactivate() + self._synchronize_interfaces() + + def add_hotwords(self, hotwords): + for interface in self._interfaces_list: + interface.add_hotwords(hotwords) + + def _synchronize_interfaces(self): + for interface in self._interfaces_list: + interface._utterances = self._utterances diff --git a/wafl/interface/queue_interface.py b/wafl/interface/queue_interface.py index 08fdc247..cf14c9ca 100644 --- a/wafl/interface/queue_interface.py +++ b/wafl/interface/queue_interface.py @@ -1,5 +1,4 @@ import asyncio -import time from wafl.interface.base_interface import BaseInterface @@ -16,9 +15,8 @@ async def output(self, text: str, silent: bool = False): self.output_queue.append({"text": text, "silent": True}) return - utterance = text - 
self.output_queue.append({"text": utterance, "silent": False}) - self._utterances.append((time.time(), f"bot: {text}")) + self.output_queue.append({"text": text, "silent": False}) + self._insert_utterance("bot", text) self.bot_has_spoken(True) async def input(self) -> str: @@ -26,9 +24,12 @@ async def input(self) -> str: await asyncio.sleep(0.1) text = self.input_queue.pop(0) - self._utterances.append((time.time(), f"user: {text}")) + self._insert_utterance("user", text) return text + async def insert_input(self, text: str): + self.input_queue.append(text) + def bot_has_spoken(self, to_set: bool = None): if to_set != None: self._bot_has_spoken = to_set diff --git a/wafl/interface/voice_interface.py b/wafl/interface/voice_interface.py index 46e53999..b3a1c2f1 100644 --- a/wafl/interface/voice_interface.py +++ b/wafl/interface/voice_interface.py @@ -1,12 +1,10 @@ import os import random import re -import time from wafl.events.utils import remove_text_between_brackets -from wafl.simple_text_processing.deixis import from_bot_to_user from wafl.interface.base_interface import BaseInterface -from wafl.interface.utils import get_most_common_words, not_good_enough +from wafl.interface.utils import not_good_enough from wafl.listener.whisper_listener import WhisperListener from wafl.speaker.fairseq_speaker import FairSeqSpeaker from wafl.speaker.soundfile_speaker import SoundFileSpeaker @@ -42,17 +40,6 @@ def __init__(self, config): self._bot_has_spoken = False self._utterances = [] - async def add_hotwords_from_knowledge( - self, knowledge: "Knowledge", max_num_words: int = 100, count_threshold: int = 5 - ): - hotwords = get_most_common_words( - knowledge.get_facts_and_rule_as_text(), - max_num_words=max_num_words, - count_threshold=count_threshold, - ) - hotwords = [word.lower() for word in hotwords] - self._listener.add_hotwords(hotwords) - def add_hotwords(self, hotwords): self._listener.add_hotwords(hotwords) @@ -65,8 +52,8 @@ async def output(self, text: str, silent: bool = False): return self._listener.activate() - text = from_bot_to_user(text) - self._utterances.append((time.time(), f"bot: {text}")) + text = text + self._insert_utterance("bot", text) print(COLOR_START + "bot> " + text + COLOR_END) await self._speaker.speak(text) self.bot_has_spoken(True) @@ -89,7 +76,7 @@ async def input(self) -> str: print(COLOR_START + "user> " + text + COLOR_END) utterance = remove_text_between_brackets(text) if utterance.strip(): - self._utterances.append((time.time(), f"user: {text}")) + self._insert_utterance("user", text) return text diff --git a/wafl/listener/whisper_listener.py b/wafl/listener/whisper_listener.py index 4fbe75ff..f2b7770f 100644 --- a/wafl/listener/whisper_listener.py +++ b/wafl/listener/whisper_listener.py @@ -91,7 +91,12 @@ async def input(self): while True: await asyncio.sleep(0) - inp = self.stream.read(self._chunk) + try: + inp = self.stream.read(self._chunk) + except IOError: + self.activate() + inp = self.stream.read(self._chunk) + rms_val = _rms(inp) if rms_val > self._volume_threshold: waveform = self.record(start_with=inp) diff --git a/wafl/rules.py b/wafl/rules.py index 3e3d32f6..fc4837cf 100644 --- a/wafl/rules.py +++ b/wafl/rules.py @@ -6,15 +6,24 @@ class Rule: effect: "Fact" causes: List["Fact"] + _max_indentation = 3 + indent_str = " " def toJSON(self): return str(self) + def get_string_using_template(self, effect_template: str) -> str: + rule_str = effect_template.replace("{effect}", self.effect.text) + "\n" + return self._add_clauses(rule_str) + def __str__(self): - 
rule_str = self.effect.text + rule_str = self.effect.text + "\n" + return self._add_clauses(rule_str) + + def _add_clauses(self, rule_str: str) -> str: for cause in self.causes: try: - rule_str += "\n " + cause.text + rule_str += self._recursively_add_clauses(cause) except TypeError as e: print(f"Error in rule:'''\n{rule_str}'''") @@ -23,3 +32,22 @@ def __str__(self): raise e return rule_str + + def _recursively_add_clauses(self, query: str, depth: int = 1) -> str: + indentation = self.indent_str * depth + if type(query) == str: + return f"{indentation}- {query}\n" + + if type(query.text) == str: + return f"{indentation}- {query.text}\n" + + if depth > self._max_indentation: + return "" + + clause = list(query.text.keys())[0] + rules_text = f"{indentation}- {clause}\n" + for clauses in query.text.values(): + for cause_index, clause in enumerate(clauses): + rules_text += self._recursively_add_clauses(clause, depth + 1) + + return rules_text diff --git a/wafl/run.py b/wafl/run.py index 4f664ecb..b0397e84 100644 --- a/wafl/run.py +++ b/wafl/run.py @@ -52,4 +52,3 @@ def download_models(): import nltk nltk.download("averaged_perceptron_tagger") - diff --git a/wafl/runners/routes.py b/wafl/runners/routes.py index 169bfabc..78b182c0 100644 --- a/wafl/runners/routes.py +++ b/wafl/runners/routes.py @@ -1,22 +1,9 @@ -import asyncio import os -import random -import sys -import threading -from flask import Flask, render_template, redirect, url_for +from flask import Flask from flask_cors import CORS -from wafl.config import Configuration -from wafl.events.conversation_events import ConversationEvents -from wafl.interface.queue_interface import QueueInterface -from wafl.knowledge.single_file_knowledge import SingleFileKnowledge -from wafl.logger.local_file_logger import LocalFileLogger -from wafl.scheduler.conversation_loop import ConversationLoop -from wafl.scheduler.scheduler import Scheduler -from wafl.scheduler.web_loop import WebLoop _path = os.path.dirname(__file__) -_logger = LocalFileLogger() app = Flask( __name__, static_url_path="", @@ -26,51 +13,11 @@ CORS(app) -@app.route("/create_new_instance", methods=["POST"]) -def create_new_instance(): - conversation_id = random.randint(0, sys.maxsize) - result = create_scheduler_and_webserver_loop(conversation_id) - add_new_rules(app, conversation_id, result["web_server_loop"]) - thread = threading.Thread(target=result["scheduler"].run) - thread.start() - return redirect(url_for(f"index_{conversation_id}")) - - -@app.route("/") -async def index(): - return render_template("selector.html") - - def get_app(): return app -def create_scheduler_and_webserver_loop(conversation_id): - config = Configuration.load_local_config() - interface = QueueInterface() - interface.activate() - conversation_events = ConversationEvents( - config=config, - interface=interface, - logger=_logger, - ) - conversation_loop = ConversationLoop( - interface, - conversation_events, - _logger, - activation_word="", - max_misses=-1, - deactivate_on_closed_conversation=False, - ) - asyncio.run(interface.output("Hello. 
How may I help you?")) - web_loop = WebLoop(interface, conversation_id, conversation_events) - return { - "scheduler": Scheduler([conversation_loop, web_loop]), - "web_server_loop": web_loop, - } - - -def add_new_rules(app, conversation_id, web_server_loop): +def add_new_rules(app: Flask, conversation_id: int, web_server_loop: "WebLoop"): app.add_url_rule( f"/{conversation_id}/", f"index_{conversation_id}", @@ -125,3 +72,9 @@ def add_new_rules(app, conversation_id, web_server_loop): web_server_loop.thumbs_down, methods=["POST"], ) + app.add_url_rule( + f"/{conversation_id}/toggle_logs", + f"toggle_logs_{conversation_id}", + web_server_loop.toggle_logs, + methods=["POST"], + ) diff --git a/wafl/runners/run_from_actions.py b/wafl/runners/run_from_actions.py index c4868874..4e005690 100644 --- a/wafl/runners/run_from_actions.py +++ b/wafl/runners/run_from_actions.py @@ -41,11 +41,11 @@ def predict_action(config, actions_list, expected_list): raise ValueError("The agent did not say anything.") if expected and not asyncio.run( - entailer.left_entails_right( - last_utterance, - expected, - "\n".join(interface.get_utterances_list()[:-1]), - ) + entailer.left_entails_right( + last_utterance, + expected, + "\n".join(interface.get_utterances_list()[:-1]), + ) ): del entailer, conversation_events, interface raise ValueError( diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py new file mode 100644 index 00000000..7f705db9 --- /dev/null +++ b/wafl/runners/run_web_and_audio_interface.py @@ -0,0 +1,62 @@ +import random +import sys +import threading + +from flask import render_template, redirect, url_for + +from wafl.interface.list_interface import ListInterface +from wafl.interface.voice_interface import VoiceInterface +from wafl.scheduler.scheduler import Scheduler +from wafl.scheduler.web_loop import WebLoop +from wafl.scheduler.conversation_loop import ConversationLoop +from wafl.logger.local_file_logger import LocalFileLogger +from wafl.events.conversation_events import ConversationEvents +from wafl.interface.queue_interface import QueueInterface +from wafl.config import Configuration +from wafl.runners.routes import get_app, add_new_rules + + +app = get_app() +_logger = LocalFileLogger() + + +def run_app(): + @app.route("/create_new_instance", methods=["POST"]) + def create_new_instance(): + conversation_id = random.randint(0, sys.maxsize) + result = create_scheduler_and_webserver_loop(conversation_id) + add_new_rules(app, conversation_id, result["web_server_loop"]) + thread = threading.Thread(target=result["scheduler"].run) + thread.start() + return redirect(url_for(f"index_{conversation_id}")) + + @app.route("/") + async def index(): + return render_template("selector.html") + + def create_scheduler_and_webserver_loop(conversation_id): + config = Configuration.load_local_config() + interface = ListInterface([VoiceInterface(config), QueueInterface()]) + interface.activate() + conversation_events = ConversationEvents( + config=config, + interface=interface, + logger=_logger, + ) + conversation_loop = ConversationLoop( + interface, + conversation_events, + _logger, + activation_word=config.get_value("waking_up_word"), + ) + web_loop = WebLoop(interface, conversation_id, conversation_events) + return { + "scheduler": Scheduler([conversation_loop, web_loop]), + "web_server_loop": web_loop, + } + + app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port")) + + +if __name__ == "__main__": + run_app() diff --git 
a/wafl/runners/run_web_interface.py b/wafl/runners/run_web_interface.py index 35814571..ac285396 100644 --- a/wafl/runners/run_web_interface.py +++ b/wafl/runners/run_web_interface.py @@ -1,10 +1,63 @@ -from wafl.runners.routes import get_app +import asyncio +import random +import sys +import threading + +from flask import render_template, redirect, url_for + +from wafl.scheduler.scheduler import Scheduler +from wafl.scheduler.web_loop import WebLoop +from wafl.scheduler.conversation_loop import ConversationLoop +from wafl.logger.local_file_logger import LocalFileLogger +from wafl.events.conversation_events import ConversationEvents +from wafl.interface.queue_interface import QueueInterface +from wafl.config import Configuration +from wafl.runners.routes import get_app, add_new_rules + app = get_app() +_logger = LocalFileLogger() def run_app(): - app.run(host="0.0.0.0", port=8889) + @app.route("/create_new_instance", methods=["POST"]) + def create_new_instance(): + conversation_id = random.randint(0, sys.maxsize) + result = create_scheduler_and_webserver_loop(conversation_id) + add_new_rules(app, conversation_id, result["web_server_loop"]) + thread = threading.Thread(target=result["scheduler"].run) + thread.start() + return redirect(url_for(f"index_{conversation_id}")) + + @app.route("/") + async def index(): + return render_template("selector.html") + + def create_scheduler_and_webserver_loop(conversation_id): + config = Configuration.load_local_config() + interface = QueueInterface() + interface.activate() + conversation_events = ConversationEvents( + config=config, + interface=interface, + logger=_logger, + ) + conversation_loop = ConversationLoop( + interface, + conversation_events, + _logger, + activation_word="", + max_misses=-1, + deactivate_on_closed_conversation=False, + ) + asyncio.run(interface.output("Hello. How may I help you?")) + web_loop = WebLoop(interface, conversation_id, conversation_events) + return { + "scheduler": Scheduler([conversation_loop, web_loop]), + "web_server_loop": web_loop, + } + + app.run(host="0.0.0.0", port=Configuration.load_local_config().get_value("frontend_port")) if __name__ == "__main__": diff --git a/wafl/scheduler/messages_creator.py b/wafl/scheduler/messages_creator.py new file mode 100644 index 00000000..ec0d85cb --- /dev/null +++ b/wafl/scheduler/messages_creator.py @@ -0,0 +1,70 @@ +from wafl.scheduler.web_interface_implementation import get_html_from_dialogue_item + + +class MessagesCreator: + def __init__(self, interface): + self._interface = interface + self._toggled_windows = [] + + def toggle_logs(self): + if "logs" in self._toggled_windows: + self._toggled_windows.remove("logs") + else: + self._toggled_windows.append("logs") + + async def get_messages_window(self): + conversation = "" + conversation += await self._get_dialogue() + if "logs" in self._toggled_windows: + conversation += await self._get_logs() + + return conversation + + async def _get_dialogue(self): + dialogue_items = self._interface.get_utterances_list_with_timestamp() + dialogue = [] + for index, item in enumerate(dialogue_items): + dialogue.append( + ( + item[0], + get_html_from_dialogue_item( + item[1], + ), + ) + ) + + dialogue_items = dialogue + dialogue_items = sorted(dialogue_items, key=lambda x: x[0])[::-1] + dialogue_items = [item[1] for item in dialogue_items] + conversation = ( + "
    " + ) + conversation += "".join(dialogue_items) + return conversation + + async def _get_logs(self): + choices = self._interface.get_choices_and_timestamp() + choices = [ + ( + item[0], + "
    " + item[1] + "
    ", + ) + for item in choices + ] + facts = self._interface.get_facts_and_timestamp() + facts = [ + ( + item[0], + "
    " + item[1] + "
    ", + ) + for item in facts + ] + + choices_and_facts = choices + facts + choices_and_facts = sorted(choices_and_facts, key=lambda x: x[0])[::-1] + choices_and_facts = [item[1] for item in choices_and_facts] + conversation = "
    " + conversation += "
    " + conversation += "".join(choices_and_facts) + conversation += "
    " + return conversation diff --git a/wafl/scheduler/web_loop.py b/wafl/scheduler/web_loop.py index f43a5a9e..2c16a416 100644 --- a/wafl/scheduler/web_loop.py +++ b/wafl/scheduler/web_loop.py @@ -2,11 +2,9 @@ import os from flask import render_template, request, jsonify -from wafl.interface.queue_interface import QueueInterface +from wafl.interface.base_interface import BaseInterface from wafl.logger.history_logger import HistoryLogger -from wafl.scheduler.web_interface_implementation import ( - get_html_from_dialogue_item, -) +from wafl.scheduler.messages_creator import MessagesCreator _path = os.path.dirname(__file__) @@ -14,7 +12,7 @@ class WebLoop: def __init__( self, - interface: QueueInterface, + interface: BaseInterface, conversation_id: int, conversation_events: "ConversationEvents", ): @@ -23,13 +21,14 @@ def __init__( self._conversation_id = conversation_id self._conversation_events = conversation_events self._prior_dialogue_items = "" + self._messages_creator = MessagesCreator(self._interface) async def index(self): return render_template("index.html", conversation_id=self._conversation_id) async def handle_input(self): query = request.form["query"] - self._interface.input_queue.append(query) + await self._interface.insert_input(query) return f"""