Alpa

Documentation | Slack

Alpa is a system for training and serving large-scale neural networks.

Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks require complicated distributed system techniques. Alpa aims to automate large-scale distributed training and serving with just a few lines of code.

The key features of Alpa include:

💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism.

🚀 Excellent Performance. Alpa achieves linear scaling when training models with billions of parameters on distributed clusters.

Tight Integration with the Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.

Serving

Alpa provides a free, unlimited OPT-175B text generation service. Try the service at https://opt.alpa.ai/ and share your prompting results!

The code below shows how to use the huggingface/transformers interface with the Alpa distributed backend for large-model inference. Detailed documentation is in Serving OPT-175B using Alpa.

from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# Load the model. Alpa automatically downloads the weights to the specified path
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

# Generate
prompt = "Paris is the capital city of"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)
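
The generate call follows the familiar huggingface interface, so standard decoding arguments apply. As a minimal sketch, assuming the wrapper forwards these standard arguments (as the example above suggests), greedy decoding instead of sampling would look like:

# Hedged sketch: greedy decoding with the same interface as above.
# do_sample=False and max_length are standard huggingface generate() arguments.
output = model.generate(input_ids=input_ids, max_length=128, do_sample=False)
print(tokenizer.batch_decode(output, skip_special_tokens=True))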

Training

Use Alpa's @alpa.parallelize decorator to scale your single-device training code to distributed clusters. Check out the documentation site and the examples folder for installation instructions, tutorials, examples, and more.

import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
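
By default, @alpa.parallelize chooses a parallelization method automatically. The sketch below shows how a Ray cluster is typically attached and how a method can be selected explicitly; alpa.init and PipeshardParallel are part of the Alpa API, but treat the exact arguments as a hedged sketch and consult the documentation for the full set of options.

import alpa

# Connect Alpa to a running Ray cluster before parallelized functions execute.
alpa.init(cluster="ray")

# Hedged sketch: explicitly request pipeline parallelism combined with
# intra-operator sharding; num_micro_batches controls pipelined execution.
method = alpa.PipeshardParallel(num_micro_batches=16)

@alpa.parallelize(method=method)
def train_step(model_state, batch):
    ...  # same body as the train_step above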

Learning more

Getting Involved

License

Alpa is licensed under the Apache-2.0 license.
