Skip to content

Commit

Permalink
download model change (#14)
Browse files Browse the repository at this point in the history
* PR part 1: download changes

* using SAFE_MODEL_DIRECTORY func per review

* copy to base/small model folder without overwriting openai models

---------

Co-authored-by: Keith Ha <[email protected]>
  • Loading branch information
keith4ever and Keith Ha authored Dec 23, 2024
1 parent 067fca7 commit a858a31
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 21 deletions.
64 changes: 49 additions & 15 deletions scripts/download_models.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,61 @@ SOURCE_DIR="$CURRENT_DIR/.."
ARIA_OPTIONS="-x 8 -s 8 --continue --file-allocation=none"

# Set directories
MODELS_DIR="$SOURCE_DIR/models"
INPUTS_DIR="$SOURCE_DIR/inputs"
TINY_MODELS_DIR="$SOURCE_DIR/openai_whisper-tiny"
BASE_MODELS_DIR="$SOURCE_DIR/openai_whisper-base"
SMALL_MODELS_DIR="$SOURCE_DIR/openai_whisper-small"

# Make sure folders exist
mkdir -p "$MODELS_DIR"
mkdir -p "$INPUTS_DIR"
function SAFE_MODEL_DIRECTORY(){
if [ ! -d "${1}" ]; then
echo "mkdir ${1} .."
mkdir -p "${1}"
fi
}

SAFE_MODEL_DIRECTORY $TINY_MODELS_DIR
SAFE_MODEL_DIRECTORY $BASE_MODELS_DIR
SAFE_MODEL_DIRECTORY $SMALL_MODELS_DIR

# Download Whisper auxiliary models
HF_ARGMAX_URL="https://huggingface.co/argmaxinc/whisperkit-android/resolve/main"

aria2c $ARIA_OPTIONS -d "$INPUTS_DIR" -o converted_vocab.json $HF_ARGMAX_URL/converted_vocab.json
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o melspectrogram.tflite $HF_ARGMAX_URL/melspectrogram.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o postproc.tflite $HF_ARGMAX_URL/postproc.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o voice_activity_detection.tflite $HF_ARGMAX_URL/voice_activity_detection.tflite
if [ ! -f $TINY_MODELS_DIR/converted_vocab.json ]; then
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o converted_vocab.json $HF_ARGMAX_URL/converted_vocab.json
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o MelSpectrogram.tflite $HF_ARGMAX_URL/melspectrogram.tflite
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o postproc.tflite $HF_ARGMAX_URL/postproc.tflite
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o voice_activity_detection.tflite $HF_ARGMAX_URL/voice_activity_detection.tflite
fi
if [ ! -f $BASE_MODELS_DIR/converted_vocab.json ]; then
cp $TINY_MODELS_DIR/converted_vocab.json $BASE_MODELS_DIR/.
cp $TINY_MODELS_DIR/MelSpectrogram.tflite $BASE_MODELS_DIR/.
cp $TINY_MODELS_DIR/postproc.tflite $BASE_MODELS_DIR/.
cp $TINY_MODELS_DIR/voice_activity_detection.tflite $BASE_MODELS_DIR/.
fi
if [ ! -f $SMALL_MODELS_DIR/converted_vocab.json ]; then
cp $TINY_MODELS_DIR/converted_vocab.json $SMALL_MODELS_DIR/.
cp $TINY_MODELS_DIR/MelSpectrogram.tflite $SMALL_MODELS_DIR/.
cp $TINY_MODELS_DIR/postproc.tflite $SMALL_MODELS_DIR/.
cp $TINY_MODELS_DIR/voice_activity_detection.tflite $SMALL_MODELS_DIR/.
fi

# Download Qualcomm models
HF_QUALCOMM_URL="https://huggingface.co/qualcomm"

aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o decoder_tiny.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperDecoder.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o encoder_tiny.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperEncoder.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o decoder_base.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperDecoder.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o encoder_base.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperEncoder.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o decoder_small.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperDecoder.tflite
aria2c $ARIA_OPTIONS -d "$MODELS_DIR" -o encoder_small.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperEncoder.tflite
if [ ! -f $TINY_MODELS_DIR/TextDecoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperDecoder.tflite
fi
if [ ! -f $TINY_MODELS_DIR/AudioEncoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$TINY_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Tiny-En/resolve/main/WhisperEncoder.tflite
fi
if [ ! -f $BASE_MODELS_DIR/TextDecoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperDecoder.tflite
fi
if [ ! -f $BASE_MODELS_DIR/AudioEncoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$BASE_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Base-En/resolve/main/WhisperEncoder.tflite
fi
if [ ! -f $SMALL_MODELS_DIR/TextDecoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o TextDecoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperDecoder.tflite
fi
if [ ! -f $SMALL_MODELS_DIR/AudioEncoder.tflite ]; then
aria2c $ARIA_OPTIONS -d "$SMALL_MODELS_DIR" -o AudioEncoder.tflite $HF_QUALCOMM_URL/Whisper-Small-En/resolve/main/WhisperEncoder.tflite
fi
18 changes: 12 additions & 6 deletions src/whisperax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,18 @@ int tflite_init(string argstr){
(int)args["freq"], (int)args["ch"], format
);

string tokenizer_json = root_path + "/inputs/converted_vocab.json";
string audio_model = root_path + "/models/voice_activity_detection.tflite";
string melspectro_model = root_path + "/models/melspectrogram.tflite";
string encoder_model = root_path + "/models/encoder_" + model_size + ".tflite";
string decoder_model = root_path + "/models/decoder_" + model_size + ".tflite";
string postproc_model = root_path + "/models/postproc.tflite";
std::string tokenizer_json = root_path +
"/openai_whisper-" + model_size + "/converted_vocab.json";
std::string audio_model = root_path +
"/openai_whisper-" + model_size + "/voice_activity_detection.tflite";
std::string melspectro_model = root_path +
"/openai_whisper-" + model_size + "/MelSpectrogram.tflite";
std::string encoder_model = root_path +
"/openai_whisper-" + model_size + "/AudioEncoder.tflite";
std::string decoder_model = root_path +
"/openai_whisper-" + model_size + "/TextDecoder.tflite";
std::string postproc_model = root_path +
"/openai_whisper-" + model_size + "/postproc.tflite";

melspectro = make_unique<MODEL_SUPER_CLASS>("mel_spectrogram");
encoder = make_unique<MODEL_SUPER_CLASS>("whisper_encoder");
Expand Down

0 comments on commit a858a31

Please sign in to comment.