A multi-backend (PyTorch, TensorFlow, JAX) implementation of LLaMA using keras-core.
- Install your backend of choice (PyTorch, TensorFlow, or JAX).
- Then install `llama_lite` from source:
  git clone https://github.com/abdeladim-s/llama-lite && cd llama-lite
  pip install -e .
- Get the TinyLlama model weights from Hugging Face (HF).
# Usage example: load a LLaMA checkpoint and generate text.
import os

# NOTE: the backend must be selected before importing llama_lite —
# keras reads KERAS_BACKEND at import time. Uncomment exactly one line.
os.environ["KERAS_BACKEND"] = "torch"
# os.environ["KERAS_BACKEND"] = "tensorflow"
# os.environ["KERAS_BACKEND"] = "jax"

from llama_lite.model import get_model_from_ckpt

# Build the model from a local checkpoint file, then generate a completion.
llm = get_model_from_ckpt('stories15M.pt')
completion = llm.generate(prompt="Once upon a time,", max_new_tokens=50)
print(completion)
License: MIT