Skip to content

Commit

Permalink
refactored model training scripts, added trained models
Browse files Browse the repository at this point in the history
  • Loading branch information
ylytkin committed Jun 9, 2020
1 parent 2fc2112 commit ab6e172
Show file tree
Hide file tree
Showing 40 changed files with 535,917 additions and 1,734 deletions.
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,21 @@ Usage:
2. `python fiddle_tune_generator.py --help`

Current pipeline:
1. (See `scripts/transposing.py`) Transpose all tunes to unified keys :
1. (See `scripts/transpose.py`) Transpose all tunes to unified keys:
* G for major tunes,
* A for dorian tunes,
* E for minor tunes,
* D for mixolydian tunes.
2. (See `scripts/models_training.py`) For each pair of tune type (reel, jig, waltz, etc.) and mode (maj, min, dor, mix):
2. (See `scripts/train_general_model.py`) For each type of tunes (reel, jig, etc.):
1. generate a character-level train and test data,
2. train a neural network with two LSTM layers (256 neurons each with 0.2 and 0.5 dropout rate, respectively) and a dense softmax output.
2. train a neural network with three LSTM layers and a dense softmax output.
3. (See `src/utils.py`) Inference:
1. Based on the current sequence of character, get predictions for the next character,
2. Truncate the distribution of these predictions, leaving only the top probable predictions, such that their probabilities sum up to 0.95,
3. Randomly sample a character from this truncated distribution,
4. If the next character is the `EOS` tag, stop iteration.
1. based on the current sequence of characters, get predictions for the next character,
2. truncate the distribution of these predictions, leaving only the top probable predictions, such that their probabilities sum up to 0.95,
3. randomly sample a character from this truncated distribution,
4. if the next character is the `EOS` tag, stop iteration.

Ideas for further improvement.
* Currently, we train a general model for all tunes of a given type, regardless of the mode (maj, min, dor, mix, etc.). It is not obvious, whether this is a problem, since all mode information is set in the preamble to the tune, and not in the abc code of the tune. Thus, most likely, everything will be fine with the current approach.
Nevertheless, in case it turns out to be a problem, it would be a cool possibility to try transfer learning on such general models (i.e. taking a general reels model as a base and transfer it to a major reels model or a mixolydian reels model, for example).
Therefore, it's a good idea to try such approach and see if it improves accuracy.
66 changes: 66 additions & 0 deletions data/general_training_history_hornpipe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"loss": [
2.2027834272638755,
1.4709336349423943,
1.3236833429792374,
1.2539887479203635,
1.2055175764498494,
1.170422483856396,
1.1405980348886129,
1.1162082077475415,
1.0948683269362565,
1.0783531257399097,
1.0608448799103642,
1.0481862285098433,
1.0344981938673865,
1.0239652900923668,
1.0123700644369644,
1.0036355976908586,
0.995490107373693,
0.9860053734147739,
0.9790697431955239,
0.9722387210354396,
0.9646777928557752,
0.9575336323698767,
0.952657629668923,
0.945080463151119,
0.9400030484889849,
0.9330590541429699,
0.9298332058929305,
0.9244280153269565,
0.919524573430733,
0.9143426884545839
],
"accuracy": [
0.37121686339378357,
0.5442774891853333,
0.5876496434211731,
0.6078925728797913,
0.6219408512115479,
0.6323307156562805,
0.640771746635437,
0.6479368805885315,
0.6540265083312988,
0.6591864228248596,
0.6644980311393738,
0.6681720614433289,
0.672454833984375,
0.6753780245780945,
0.6786727905273438,
0.6810027956962585,
0.6843714714050293,
0.686382532119751,
0.6883760690689087,
0.6902374029159546,
0.6926588416099548,
0.6952008605003357,
0.6967334747314453,
0.6986531019210815,
0.7011037468910217,
0.7025682926177979,
0.7038383483886719,
0.704777717590332,
0.7065885066986084,
0.7082961201667786
]
}
46 changes: 46 additions & 0 deletions data/general_training_history_jig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
{
"loss": [
1.7485593905054546,
1.232539892932018,
1.1555106971925593,
1.1116407526276904,
1.0808908686779255,
1.0566199864585142,
1.0419783018683748,
1.0291121112496444,
1.0121226673010995,
0.9953259065454297,
0.9841143867321538,
0.9739004774349398,
0.9651796275213597,
0.9560457347075108,
0.9480758148386335,
0.9420823400357955,
0.9345112734435304,
0.9485676076639102,
0.939601549915876,
0.9307886066455506
],
"accuracy": [
0.4846493899822235,
0.6078300476074219,
0.6311549544334412,
0.6444598436355591,
0.6540873050689697,
0.6615901589393616,
0.6664486527442932,
0.6714232563972473,
0.676124095916748,
0.6814731359481812,
0.684691846370697,
0.6880208253860474,
0.690861165523529,
0.6937748193740845,
0.6960694193840027,
0.6979149580001831,
0.7007066011428833,
0.6972887516021729,
0.7001186609268188,
0.7024385929107666
]
}
56 changes: 56 additions & 0 deletions data/general_training_history_polka.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"loss": [
1.914830084064797,
1.3660089044428942,
1.2609644384996572,
1.194659333289817,
1.151160910194242,
1.1167465416875593,
1.0882923511679128,
1.0645893417671957,
1.046438196781454,
1.0282729951399754,
1.0126073673173552,
0.9992953752529438,
0.9871039348484825,
0.9749189536754925,
0.9608236847448675,
0.9546419406923475,
0.9482410713047635,
0.9383804985539869,
0.9302145606996141,
0.9291163670627184,
0.9183714502667596,
0.9128819188764656,
0.9053957809167267,
0.8984473749170218,
0.8946406133049507
],
"accuracy": [
0.4612747132778168,
0.5725358724594116,
0.6034190654754639,
0.6229808330535889,
0.6361965537071228,
0.6467113494873047,
0.6552852988243103,
0.6621569991111755,
0.6674497127532959,
0.6731629967689514,
0.6778088212013245,
0.6817491054534912,
0.6849029064178467,
0.6884934306144714,
0.6921465992927551,
0.6943926811218262,
0.696394145488739,
0.6990931034088135,
0.7018749117851257,
0.702501654624939,
0.705584704875946,
0.7069594264030457,
0.708720326423645,
0.7112211585044861,
0.712719202041626
]
}
36 changes: 36 additions & 0 deletions data/general_training_history_reel.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"loss": [
1.4810838146640952,
1.2000247257323786,
1.1349456660699544,
1.0964310678069296,
1.0662369179621065,
1.042173972670099,
1.0227506263879442,
1.0061142553282334,
0.9906122513578767,
0.9774652934192335,
0.9665163805008757,
0.9554989706092435,
0.945794614407412,
0.9364177599225922,
0.9286806484007543
],
"accuracy": [
0.5412585139274597,
0.6212437748908997,
0.6411371231079102,
0.6527940630912781,
0.6622454524040222,
0.6698986291885376,
0.6757311820983887,
0.6807018518447876,
0.6854258179664612,
0.6895868182182312,
0.6932461857795715,
0.6965733170509338,
0.6994356513023376,
0.7024344205856323,
0.7049398422241211
]
}
86 changes: 86 additions & 0 deletions data/general_training_history_slip_jig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
{
"loss": [
2.8437124311352084,
1.767414824590843,
1.5417460197410682,
1.4461139430497476,
1.3810912463501586,
1.3308511423243412,
1.2887538057919934,
1.2546608403511255,
1.2255080007260948,
1.1985697584253598,
1.1764657489259585,
1.1538989294420374,
1.1337424985055904,
1.116631184209658,
1.1023308574729573,
1.0847670502763707,
1.0725138204277298,
1.0587112321537326,
1.0471602826168829,
1.0332949370579863,
1.0237513088271544,
1.0178735726424575,
1.0031604604081863,
0.9969825221159662,
0.9847881694408466,
0.979301877645361,
0.9709734268593987,
0.9646497185435808,
0.9538076126428177,
0.9464841457251609,
0.9419594740261982,
0.9383753186930466,
0.929688434962235,
0.9233357441008128,
0.9176087392827217,
0.9124171673580916,
0.9060364770676982,
0.900536821891928,
0.897823082122808,
0.8941678422335192
],
"accuracy": [
0.24754953384399414,
0.4637649357318878,
0.5214890837669373,
0.5481681823730469,
0.5649504661560059,
0.5798517465591431,
0.5916683673858643,
0.6023072600364685,
0.6109917759895325,
0.6199696063995361,
0.6271830201148987,
0.6329036355018616,
0.6387364268302917,
0.645052433013916,
0.6491207480430603,
0.6539656519889832,
0.6574256420135498,
0.6630556583404541,
0.6657649874687195,
0.6704717874526978,
0.6725598573684692,
0.6741518378257751,
0.6789363026618958,
0.6811020374298096,
0.6854680180549622,
0.6861668825149536,
0.6886648535728455,
0.6916934251785278,
0.6940403580665588,
0.6970171332359314,
0.6977419257164001,
0.699722170829773,
0.7012494206428528,
0.7037343978881836,
0.7044764161109924,
0.7068837285041809,
0.7083333134651184,
0.7108743786811829,
0.7109348177909851,
0.7126259803771973
]
}
56 changes: 56 additions & 0 deletions data/general_training_history_waltz.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
{
"loss": [
2.685673989264583,
1.7188711989661887,
1.4924569414973246,
1.401771911228768,
1.3454527150858102,
1.3063072291509499,
1.274768285599294,
1.2517376461174199,
1.2289602055723214,
1.210415361542043,
1.1933982292099583,
1.1796947502508595,
1.1652406830839075,
1.1551744893119613,
1.1439671116885866,
1.1327622177552517,
1.1249285683723138,
1.114917850068356,
1.1075156324314015,
1.0997151801820362,
1.0938059545828738,
1.0879304353425938,
1.0767206167510899,
1.0743636452631695,
1.065599986282323
],
"accuracy": [
0.29798388481140137,
0.5022055506706238,
0.5474862456321716,
0.5695353746414185,
0.5842868089675903,
0.5945073962211609,
0.6027480363845825,
0.6107384562492371,
0.6153762340545654,
0.6210232377052307,
0.6263281106948853,
0.6290820837020874,
0.634014904499054,
0.6365615129470825,
0.6402884125709534,
0.6437458992004395,
0.6458755135536194,
0.6489951610565186,
0.6505752801895142,
0.6537846922874451,
0.6546164751052856,
0.6566563248634338,
0.6599854826927185,
0.661076009273529,
0.663205623626709
]
}
Loading

0 comments on commit ab6e172

Please sign in to comment.