From 87c8c8120d8f68477eb8854e301d1d0d06b85b79 Mon Sep 17 00:00:00 2001 From: printer83mph Date: Wed, 30 Nov 2022 01:50:06 -0500 Subject: [PATCH] Add CLI interface with serialization --- Cargo.lock | 196 ++++++++++++++++++++++++++++++++++++---- Cargo.toml | 4 +- src/lib.rs | 2 + src/main.rs | 105 +++++++++++++++++---- src/markov.rs | 47 +++++++--- src/markov/serialize.rs | 115 +++++++++++++++++++++++ 6 files changed, 423 insertions(+), 46 deletions(-) create mode 100644 src/lib.rs create mode 100644 src/markov/serialize.rs diff --git a/Cargo.lock b/Cargo.lock index b7808b8..d898e26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,23 +2,18 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi", - "libc", - "winapi", -] - [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "cc" +version = "1.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f73505338f7d905b19d18738976aae232eb46b8efc15554ffc56deb5d9ebe4" + [[package]] name = "cfg-if" version = "1.0.0" @@ -27,14 +22,14 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.0.26" +version = "4.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2148adefda54e14492fb9bddcc600b4344c5d1a3123bd666dcb939c6f0e0e57e" +checksum = "f94eecf2f2d0e0220737d9e6d8530da1e903b97dde58ebaf749262e85ce133c9" dependencies = [ - "atty", "bitflags", "clap_derive", "clap_lex", + "is-terminal", "once_cell", "strsim", "termcolor", @@ -62,6 +57,27 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "errno" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +dependencies = [ + "errno-dragonfly", + "libc", + "winapi", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -81,25 +97,61 @@ checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" [[package]] name = "hermit-abi" -version = "0.1.19" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" dependencies = [ "libc", ] +[[package]] +name = "io-lifetimes" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46112a93252b123d31a119a8d1a1ac19deac4fac6e0e8b0df58f0d4e5870e63c" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "is-terminal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "927609f78c2913a6f6ac3c27a4fe87f43e2a35367c0c4b0f8265e8f49a104330" +dependencies = [ + "hermit-abi", + "io-lifetimes", + "rustix", + "windows-sys", +] + +[[package]] +name = "itoa" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4217ad341ebadf8d8e724e264f13e593e0648f5b3e94b3896a5df283be015ecc" + [[package]] name = "libc" version = "0.2.137" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" +[[package]] +name = "linux-raw-sys" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f9f08d8963a6c613f4b1a78f4f4a4dbfadf8e6545b2d72861731e4858b8b47f" + [[package]] name = "markov-rust" version = "0.1.0" dependencies = [ "clap", "rand", + "serde", + "serde_json", ] [[package]] @@ -192,6 +244,57 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rustix" +version = "0.36.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb93e85278e08bb5788653183213d3a60fc242b10cb9be96586f5a73dcb67c23" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "ryu" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4501abdff3ae82a1c1b477a17252eb69cee9e66eb915c1abaa4f44d873df9f09" + +[[package]] +name = "serde" +version = "1.0.148" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53f64bb4ba0191d6d0676e1b141ca55047d83b74f5607e6d8eb88126c52c2dc" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.148" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55492425aa53521babf6137309e7d34c20bbfbbfcfe2c7f3a047fd1f6b92c0c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020ff22c755c2ed3f8cf162dbb41a7268d934702f3ed3631656ea597e08fc3db" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "strsim" version = "0.10.0" @@ -200,9 +303,9 @@ checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" -version = "1.0.103" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d" +checksum = "4ae548ec36cf198c0ef7710d3c230987c2d6d7bd98ad6edc0274462724c585ce" dependencies = [ "proc-macro2", "quote", @@ -266,3 +369,60 @@ name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" diff --git a/Cargo.toml b/Cargo.toml index 968ccea..aea2cfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -clap = { version = "4.0.26", features = ["derive"] } rand = "0.8.5" +serde = { version = "1.0.148", features = ["derive"] } +clap = { version = "4.0.27", features = ["derive"] } +serde_json = { version = "1.0.89" } diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..effa287 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,2 @@ +mod markov; +pub use markov::{Config, Model}; diff --git a/src/main.rs b/src/main.rs index 6729d4b..d7e3281 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,94 @@ -mod markov; +use std::io::Write; + +use clap::Parser; +mod markov; use markov::Model; -fn main() { - let mut mk = Model::new_prose(); - mk.train_paragraph("French (français [fʁɑ̃sɛ] or langue française [lɑ̃ɡ fʁɑ̃sɛːz]) is a Romance language of the Indo-European family. It descended from the Vulgar Latin of the Roman Empire, as did all Romance languages. French evolved from Gallo-Romance, the Latin spoken in Gaul, and more specifically in Northern Gaul. Its closest relatives are the other langues d'oïl—languages historically spoken in northern France and in southern Belgium, which French (Francien) largely supplanted. French was also influenced by native Celtic languages of Northern Roman Gaul like Gallia Belgica and by the (Germanic) Frankish language of the post-Roman Frankish invaders. Today, owing to France's past overseas expansion, there are numerous French-based creole languages, most notably Haitian Creole. A French-speaking person or nation may be referred to as Francophone in both English and French."); - mk.train_paragraph("French is an official language in 29 countries across multiple continents, most of which are members of the Organisation internationale de la Francophonie (OIF), the community of 84 countries which share the official use or teaching of French. French is also one of six official languages used in the United Nations. It is spoken as a first language (in descending order of the number of speakers) in France; Canada (especially in the provinces of Quebec, Ontario, and New Brunswick, as well as other Francophone regions); Belgium (Wallonia and the Brussels-Capital Region); western Switzerland (specifically the cantons forming the Romandy region); parts of Luxembourg; parts of the United States (the states of Louisiana, Maine, New Hampshire and Vermont); Monaco; the Aosta Valley region of Italy; and various communities elsewhere."); - mk.train_paragraph("In 2015, approximately 40% of the francophone population (including L2 and partial speakers) lived in Europe, 36% in sub-Saharan Africa and the Indian Ocean, 15% in North Africa and the Middle East, 8% in the Americas, and 1% in Asia and Oceania. French is the second most widely spoken mother tongue in the European Union. Of Europeans who speak other languages natively, approximately one-fifth are able to speak French as a second language. French is the second most taught foreign language in the EU. All institutions of the EU use French as a working language along with English and German; in certain institutions, French is the sole working language (e.g. at the Court of Justice of the European Union). French is also the 18th most natively spoken language in the world, fifth most spoken language by total number of speakers and the second or third most studied language worldwide (with about 120 million learners as of 2017). As a result of French and Belgian colonialism from the 16th century onward, French was introduced to new territories in the Americas, Africa and Asia. Most second-language speakers reside in Francophone Africa, in particular Gabon, Algeria, Morocco, Tunisia, Mauritius, Senegal and Ivory Coast."); - mk.train_paragraph("French is estimated to have about 76 million native speakers; about 235 million daily, fluent speakers; and another 77–110 million secondary speakers who speak it as a second language to varying degrees of proficiency, mainly in Africa. According to the OIF, approximately 321 million people worldwide are \"able, to, speak the language\", without specifying the criteria for this estimation or whom it encompasses. According to a demographic projection led by the Université Laval and the Réseau Démographie de l'Agence universitaire de la Francophonie, the total number of French speakers will reach approximately 500 million in 2025 and 650 million by 2050. OIF estimates 700 million by 2050, 80% of whom will be in Africa."); - mk.train_paragraph("French has a long history as an international language of literature and scientific standards and is a primary or second language of many international organisations including the United Nations, the European Union, the North Atlantic Treaty Organization, the World Trade Organization, the International Olympic Committee, and the International Committee of the Red Cross. In 2011, Bloomberg Businessweek ranked French the third most useful language for business, after English and Standard Mandarin Chinese."); - mk.train_paragraph("French is a Romance language (meaning that it is descended primarily from Vulgar Latin) that evolved out of the Gallo-Romance dialects spoken in northern France. The language's early forms include Old French and Middle French."); - mk.train_paragraph("Due to Roman rule, Latin was gradually adopted by the inhabitants of Gaul, and as the language was learned by the common people it developed a distinct local character, with grammatical differences from Latin as spoken elsewhere, some of which being attested on graffiti. This local variety evolved into the Gallo-Romance tongues, which include French and its closest relatives, such as Arpitan."); - mk.train_paragraph("The evolution of Latin in Gaul was shaped by its coexistence for over half a millennium beside the native Celtic Gaulish language, which did not go extinct until the late sixth century, long after the fall of the Western Roman Empire. The population remained 90% indigenous in origin; the Romanizing class were the local native elite (not Roman settlers), whose children learned Latin in Roman schools. At the time of the collapse of the Empire, this local elite had been slowly abandoning Gaulish entirely, but the rural and lower class populations remained Gaulish speakers who could sometimes also speak Latin or Greek. The final language shift from Gaulish to Vulgar Latin among rural and lower class populations occurred later, when both they and the incoming Frankish ruler/military class adopted the Gallo-Roman Vulgar Latin speech of the urban intellectual elite."); - mk.train_paragraph("The Gaulish language likely survived into the sixth century in France despite considerable Romanization. Coexisting with Latin, Gaulish helped shape the Vulgar Latin dialects that developed into French contributing loanwords and calques (including oui, the word for \"yes\"), sound changes shaped by Gaulish influence, and influences in conjugation and word order. Recent computational studies suggest that early gender shifts may have been motivated by the gender of the corresponding word in Gaulish."); - mk.train_paragraph("The estimated number of French words that can be attributed to Gaulish is placed at 154 by the Petit Robert, which is often viewed as representing standardized French, while if non-standard dialects are included, the number increases to 240. Known Gaulish loans are skewed toward certain semantic fields, such as plant life (chêne, bille, etc.), animals (mouton, cheval, etc.), nature (boue, etc.), domestic activities (ex. berceau), farming and rural units of measure (arpent, lieue, borne, boisseau), weapons, and products traded regionally rather than further afield. This semantic distribution has been attributed to peasants being the last to hold onto Gaulish."); - - for _ in 0..10 { - println!("{}", mk.generate_paragraph()); +#[derive(clap::Subcommand, Debug)] +enum Action { + Train { + source: std::path::PathBuf, + model_file: std::path::PathBuf, + + #[arg(short, long, default_value_t = false)] + reset: bool, + }, + Generate { + model_file: std::path::PathBuf, + out_file: std::path::PathBuf, + }, +} + +#[derive(Parser, Debug)] +#[command(version, long_about = None)] +struct Args { + #[command(subcommand)] + action: Action, +} + +fn main() -> std::io::Result<()> { + let args = Args::parse(); + + match args.action { + Action::Train { + source, + model_file: model_file_path, + reset, + } => { + // load source file + let source_file = + std::fs::File::open(source).expect("Unexpected error opening source file"); + let mut source_reader = std::io::BufReader::new(source_file); + + // load model file + let mut model = { + let prev_model = std::fs::OpenOptions::new() + .read(true) + .open(&model_file_path); + + let model_exists = match prev_model { + Ok(_) => true, + Err(_) => false, + }; + + // make new model if file didn't exist or we're resetting + if !model_exists || reset { + Model::new_prose() + } else { + // TODO: should i borrow this?? + serde_json::from_reader(&prev_model.unwrap()).expect("Invalid model file") + } + }; + + // train model + println!("Training..."); + model.train_buf(&mut source_reader); + println!("Training complete!"); + + // serialize model + let serialized = serde_json::to_string(&model)?; + + // open up model file and clear it + let mut model_file = std::fs::OpenOptions::new() + .write(true) + .truncate(true) + .create(true) + .open(&model_file_path) + .expect("Could not open model file for writing"); + + // save serialized model + model_file.write_all(serialized.as_bytes())?; + println!("Saved model to {}.", model_file_path.to_str().unwrap()); + + Ok(()) + } + Action::Generate { + model_file, + out_file, + } => { + println!("generating bastard"); + Ok(()) + } } } diff --git a/src/markov.rs b/src/markov.rs index 5074e4d..faf1685 100644 --- a/src/markov.rs +++ b/src/markov.rs @@ -1,17 +1,12 @@ +use rand::prelude::random; use std::{ collections::{HashMap, HashSet}, io::BufRead, }; -use rand::prelude::random; - -// TODO: use this babey -enum Capitalization { - Ignore, - Match, - Capitalize, -} +mod serialize; +#[derive(Debug, Clone)] pub struct Config { paragraph_delimiter: char, word_delimiters: HashSet, @@ -20,12 +15,13 @@ pub struct Config { impl Config { pub fn prose() -> Config { Config { - paragraph_delimiter: '\n', - word_delimiters: HashSet::from([' ', '\t']), + paragraph_delimiter: '\n'.to_owned(), + word_delimiters: HashSet::from([' '.to_owned(), '\t'.to_owned()]), } } } +#[derive(Debug, Clone)] struct WordStats { word: String, occurrences: i32, @@ -78,6 +74,8 @@ fn pick_random(probabilities: &Vec) -> usize { } /// Markov chain model. Contains known words, and stats connecting them. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(from = "serialize::Model", into = "serialize::Model")] pub struct Model { config: Config, words: HashMap, @@ -109,7 +107,10 @@ impl Model { let mut bytes: Vec = Vec::new(); // iterate over paragraphs - while let Ok(_) = reader.read_until(self.config.paragraph_delimiter as u8, &mut bytes) { + while let Ok(bi) = reader.read_until(self.config.paragraph_delimiter as u8, &mut bytes) { + if bi == 0 { + break; + } if bytes.is_empty() { continue; } @@ -122,6 +123,30 @@ impl Model { } } + pub fn train_string(&mut self, content: &str) -> () { + let mut buffer: String = String::new(); + + for ch in content.chars() { + if ch == self.config.paragraph_delimiter { + // train paragraph if anything in buffer + if buffer.len() > 0 { + self.train_paragraph(&buffer); + } + + // clear buffer no matter what + buffer.clear(); + } else { + // add current char to buffer + buffer.push(ch); + } + } + + // train final paragraph + if buffer.len() > 0 { + self.train_paragraph(&buffer); + } + } + pub fn train_paragraph(&mut self, paragraph: &str) -> () { // split string into words, filter whitespace let words: Vec<&str> = paragraph diff --git a/src/markov/serialize.rs b/src/markov/serialize.rs new file mode 100644 index 0000000..c037310 --- /dev/null +++ b/src/markov/serialize.rs @@ -0,0 +1,115 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Config { + paragraph_delimiter: char, + word_delimiters: HashSet, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct WordStats { + word: String, + occurrences: i32, + next_occurrences: Vec, + next_total: i32, + start_occurrences: i32, + end_occurrences: i32, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Model { + config: Config, + start_words: HashSet, + start_total: i32, + end_words: HashSet, + end_total: i32, + stats: Vec, +} + +impl From for super::Config { + fn from(cfg: Config) -> Self { + Self { + paragraph_delimiter: cfg.paragraph_delimiter, + word_delimiters: cfg.word_delimiters, + } + } +} + +impl From for Config { + fn from(cfg: super::Config) -> Self { + Self { + paragraph_delimiter: cfg.paragraph_delimiter, + word_delimiters: cfg.word_delimiters, + } + } +} + +impl From for super::WordStats { + fn from(stat: WordStats) -> Self { + Self { + word: stat.word, + occurrences: stat.occurrences, + next_occurrences: stat.next_occurrences, + next_total: stat.next_total, + start_occurrences: stat.start_occurrences, + end_occurrences: stat.end_occurrences, + } + } +} + +impl From for WordStats { + fn from(stat: super::WordStats) -> Self { + Self { + word: stat.word, + occurrences: stat.occurrences, + next_occurrences: stat.next_occurrences, + next_total: stat.next_total, + start_occurrences: stat.start_occurrences, + end_occurrences: stat.end_occurrences, + } + } +} + +impl From for super::Model { + fn from(data: Model) -> Self { + super::Model { + config: super::Config::from(data.config), + words: data + .stats + .iter() + .enumerate() + .map(|(idx, stat)| (stat.word.to_owned(), idx)) + .collect(), + start_words: data.start_words, + start_total: data.start_total, + end_words: data.end_words, + end_total: data.end_total, + stats: data + .stats + .iter() + .map(|stat| super::WordStats::from(stat.to_owned())) + .collect(), + } + } +} + +impl From for Model { + fn from(data: super::Model) -> Self { + Model { + config: Config::from(data.config), + start_words: data.start_words, + start_total: data.start_total, + end_words: data.end_words, + end_total: data.end_total, + stats: data + .stats + .iter() + .map(|stat| WordStats::from(stat.to_owned())) + .collect(), + } + } +}