This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 14b3153

[FEATURE] Support beam sample (#491)

1 parent bd47646 commit 14b3153

File tree

5 files changed: +145 −25

alpa/device_mesh.py

Lines changed: 41 additions & 4 deletions
@@ -32,7 +32,7 @@
 import jax
 from jax import core, xla, device_put
 from jax._src.api import ShapeDtypeStruct
-from jax._src.lib import xla_bridge as xb, xla_extension as xe
+from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
 from jax._src.tree_util import tree_leaves
 from jax.abstract_arrays import array_types
 from jax.core import ShapedArray
@@ -48,11 +48,14 @@
 import alpa.collective as col
 from alpa.global_env import global_config
 from alpa.monkey_patch import set_override_backend
-from alpa.shard_parallel.auto_sharding import LogicalDeviceMesh
-from alpa.parallel_plan import PlacementSpec
+from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
+                                               LogicalDeviceMesh,
+                                               run_spmd_partitioner_pass)
+from alpa.parallel_plan import PlacementSpec, StagePlan
 from alpa.timer import timers
 from alpa.util import (benchmark_func, list_gpu_info, OrderedSet,
-                       update_jax_platform, is_ray_node_resource)
+                       update_jax_platform, is_ray_node_resource,
+                       get_index_select_computation)

 if global_config.nccl_mode == "cupy":
     import alpa.collective.worker_nccl_util_cupy as worker_nccl_util
@@ -608,6 +611,7 @@ class PhysicalDeviceMesh(ABC):
     num_hosts: int
     num_devices_per_host: int
     mesh_id: int
+    operation_executables: dict

     def get_signature(self) -> str:
         """Return a signature string that contains the mesh shape and GPU
@@ -810,6 +814,7 @@ def __init__(self, devices: Sequence["Device"] = None):
         self.num_devices_per_host = len(self.devices)
         self.mesh_id = 0
         self.device_strs = []
+        self.operation_executables = {}

         self.set_runtime_random_seed(global_config.runtime_random_seed)

@@ -898,6 +903,7 @@ def sync_workers(self):

     def shutdown(self, forced=False):
         self.sync_workers()
+        self.operation_executables.clear()


 def device_id_to_str(host_ip, device_id, device_type="gpu"):
@@ -934,6 +940,7 @@ def __init__(self,
         self.workers = None
         self.launched = False
         self.service_server = None
+        self.operation_executables = {}

         if devices is not None:
             if len(devices) != len(host_ids):
@@ -1301,6 +1308,7 @@ def shutdown(self, forced=False):
         if not self.launched:
             return
         if not forced:
+            self.operation_executables.clear()
             ray.get([w.shutdown.remote() for w in self.workers])
             for worker in self.workers:
                 ray.kill(worker)
@@ -1309,6 +1317,7 @@ def shutdown(self, forced=False):
             self.service_server.shutdown()
             self.service_server = None
         self.launched = False
+        self.operation_executables.clear()  # clear with forced shutdown


 ########################################
@@ -1525,6 +1534,34 @@ def __float__(self):

     # TODO(lmzheng): copy more functions from DeviceArray
     # (jax/_src/device_array.py)
+    def index_select(self, dim, index):
+        """Compile and run index select operation."""
+        # pylint: disable=import-outside-toplevel
+        from alpa.mesh_executable import NormalMeshDriverExecutable
+        if type(index) not in [ShapedArray, ShapeDtypeStruct]:
+            index = xla.canonicalize_dtype(index)
+        index_shape = xc.shape_from_pyval(index)
+        key = hash(("index_select", self.aval, dim, index_shape))
+        if key in self.device_mesh.operation_executables:
+            executable = self.device_mesh.operation_executables[key]
+        else:
+            index_aval = ShapedArray(index.shape, index.dtype)
+            c = get_index_select_computation(self.sharding_spec, dim, self.aval,
+                                             index_shape).as_hlo_module()
+            hlo_module = run_spmd_partitioner_pass(c,
+                                                   self.device_mesh.num_devices)

+            as_option = AutoShardingOption()
+            strategy_config = StagePlan(global_config.compile_random_seed,
+                                        self.device_mesh.shape, 1 << 60,
+                                        as_option.all_reduce_threshold, None,
+                                        -1)
+            executable = NormalMeshDriverExecutable(self.device_mesh,
+                                                    hlo_module, strategy_config,
+                                                    [self.aval, index_aval],
+                                                    [self.aval], [False, False])
+            self.device_mesh.operation_executables[key] = executable
+        return executable.launch_on_driver(self, index)

     def __str__(self):
         return (f"DistributedArray(sharding_spec={self.sharding_spec}, "

alpa/util.py

Lines changed: 17 additions & 0 deletions
@@ -560,6 +560,23 @@ def compile_concatenate(backend, mesh_shape, sharding_spec, batch_size,
     return hlo_proto


+def get_index_select_computation(sharding_spec, dim, aval, index_shape):
+    sharding = pxla.sharding_spec_sharding_proto(sharding_spec)
+    c = xc.XlaBuilder("index_select")
+    c.set_sharding(sharding)
+    operand = xc.ops.Parameter(
+        c, 0, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))
+    c.clear_sharding()
+    index = xc.ops.Parameter(c, 1, index_shape)
+    index_selected = xc.ops.IndexSelect(operand, index, dim)
+    sharding2 = xc.OpSharding()
+    sharding2.type = sharding.type.TUPLE
+    sharding2.tuple_shardings = [sharding]
+    c.set_sharding(sharding2)
+    c = c.build(xc.ops.Tuple(c, [index_selected]))
+    return c
+
+
 def get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec):
     """Return the shape of a shard."""
     shape = []
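
Note: this helper hand-builds an HLO computation: the operand parameter carries the array's current sharding annotation, the index parameter is left unannotated (so it stays replicated), and the output tuple is re-annotated so the SPMD partitioner keeps the result sharded like the input. A standalone sketch of the same `XlaBuilder` flow, minus the sharding annotations, assuming the `xla_client` API of the jaxlib version this repo pins:

# Build a tiny index-select computation and inspect its HLO.
import numpy as np
from jax._src.lib import xla_client as xc

c = xc.XlaBuilder("index_select_demo")
operand = xc.ops.Parameter(
    c, 0, xc.shape_from_pyval(np.ones((4, 3), np.float32)))
index = xc.ops.Parameter(
    c, 1, xc.shape_from_pyval(np.zeros((2,), np.int32)))
computation = c.build(xc.ops.Tuple(c, [xc.ops.IndexSelect(operand, index, 0)]))
print(computation.as_hlo_text())  # inspect the generated HLO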

examples/opt_serving/model/opt_model.py

Lines changed: 6 additions & 9 deletions
@@ -57,7 +57,6 @@ class OPTConfig:
     decoder_attention_heads: int = 12
     decoder_input_dim: int = 768
     decoder_ffn_embed_dim: int = 3072
-    batch_size: int = 1
     pad: int = 1
     activation_fn: str = 'relu'
     dtype: any = jnp.float16
@@ -530,17 +529,18 @@ def init_model_aval(config):


 def init_cache_aval(config, batch_size):
+    dtype = config.dtype
     head_dim = config.decoder_embed_dim // config.decoder_attention_heads

     all_cache = []
     for i in range(config.decoder_layers):
         layer_cache = (
             jax.core.ShapedArray((batch_size, config.max_target_positions,
                                   config.decoder_attention_heads, head_dim),
-                                 config.dtype),
+                                 dtype),
             jax.core.ShapedArray((batch_size, config.max_target_positions,
                                   config.decoder_attention_heads, head_dim),
-                                 config.dtype),
+                                 dtype),
             jax.core.ShapedArray((batch_size,), jnp.int32),
         )
         all_cache.append(layer_cache)
@@ -679,9 +679,6 @@ def get_pipeshard_executable(config,
                              support_output_attentions=False,
                              support_output_hidden_states=False,
                              autoregressive=True):
-    if autoregressive:
-        assert num_micro_batches == 1, "we only support num_micro_batches=1 for autoregressive!"
-        assert batch_size == 1, "we only support batch_sie = 1 for autoregressive!"

     # Init model
     model, params = init_model_aval(config)
@@ -708,9 +705,9 @@ def inference_step_with_cache(params, batch):
         alpa.global_config.always_donate_micro_batch_vars = False
         executable = inference_step_with_cache.get_executable(
             params, {
-                "input_ids": jax.core.ShapedArray((1, 1), jnp.int32),
-                "position_ids": jax.core.ShapedArray((1, 1), jnp.int32),
-                "cache": init_cache_aval(config, 1),
+                "input_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "position_ids": jax.core.ShapedArray((batch_size, 1), jnp.int32),
+                "cache": init_cache_aval(config, batch_size),
             })
     else:
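
Note: with the hard-coded batch size gone, the cache avals scale with the effective batch, which for beam search is batch_size * num_beams, since every beam keeps its own key/value cache. Rough shape arithmetic under illustrative config values (not taken from a real OPT checkpoint):

# Illustrative cache-shape arithmetic; the config numbers are examples.
batch_size, num_beams = 1, 4
max_target_positions = 2048
decoder_attention_heads = 12
decoder_embed_dim = 768
head_dim = decoder_embed_dim // decoder_attention_heads  # 64

effective_batch = batch_size * num_beams  # one cache row per beam
kv_shape = (effective_batch, max_target_positions,
            decoder_attention_heads, head_dim)
print(kv_shape)  # (4, 2048, 12, 64); one K and one V buffer per layer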

examples/opt_serving/model/wrapper.py

Lines changed: 74 additions & 10 deletions
@@ -1,8 +1,15 @@
+from functools import partial
 import os
 from typing import Sequence, Any

 import alpa
 import jax
+from jax import xla
+from jax import ShapeDtypeStruct, ShapedArray
+from jax._src.lib import xla_client as xc
+from jax.core import Primitive
+from jax.interpreters import pxla
+from jax.interpreters.pxla import NoSharding, Replicated, ShardingSpec
 import jax.numpy as jnp
 import numpy as np
 import torch
@@ -15,6 +22,21 @@
 from examples.opt_serving.model.opt_utils import TransformerModelConfig


+index_select_p = Primitive("index-select")
+def jax_index_select(input, index, dim=0):
+    return index_select_p.bind(input, index, dim=dim)
+
+def _index_select_eval(input, index, dim):
+    return input
+
+def _index_select_translation(c, input, index, dim):
+    return xc.ops.IndexSelect(input, index, dim)
+
+index_select_p.def_abstract_eval(_index_select_eval)
+index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
+xla.translations[index_select_p] = _index_select_translation
+
+
 @dataclass
 class InferenceFuncOutput(ModelOutput):
     logits: Any = None
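
Note: this registers a custom JAX primitive whose abstract evaluation returns the input aval unchanged and whose lowering emits `IndexSelect` directly. A hypothetical usage sketch, relying on the legacy `xla.translations` lowering table present in this repo's pinned JAX version:

# Hypothetical use of the primitive registered above.
import numpy as np
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(4, 3)           # one row of state per beam
beam_idx = np.array([2, 2, 0, 1], np.int32)  # survivors chosen by the scorer
y = jax_index_select(x, beam_idx, dim=0)     # rows become x[2], x[2], x[0], x[1]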
@@ -100,8 +122,46 @@ def __call__(self,
         past_key_values = ret.past_key_values
         return ret

-
-def get_hf_gpt_model(model_name, device):
+    def _reorder_cache(self, past, beam_idx):
+        # The current beam_idx is a torch tensor from the beam scorer. To
+        # speed this up, we need alpa's own beam scorer.
+        cache = {}
+        cpu_idx = beam_idx.to("cpu").numpy()
+        if type(cpu_idx) not in [ShapedArray, ShapeDtypeStruct]:
+            cpu_idx = xla.canonicalize_dtype(cpu_idx)
+
+        def to_mesh(mesh):
+            if mesh in cache:
+                return cache[mesh]
+            avals = [ShapedArray(cpu_idx.shape, cpu_idx.dtype)]
+            replicated_spec = ShardingSpec([NoSharding()] * len(cpu_idx.shape),
+                                           [Replicated(mesh.num_devices)])
+            specs = [replicated_spec]
+            indices = [pxla.spec_to_indices(cpu_idx.shape, replicated_spec)]
+            ary = mesh.shard_args_to_arrays(avals, indices, specs, [cpu_idx])[0]
+            cache[mesh] = ary
+            return ary
+
+        def single_element_reorder_cache(ary):
+            if hasattr(ary, "index_select"):
+                # Torch or Alpa path
+                device_idx = None
+                if hasattr(ary, "device"):  # Torch to_device
+                    device_idx = beam_idx.to(ary.device)
+                else:
+                    device_idx = to_mesh(ary.device_mesh)
+                return ary.index_select(0, device_idx)
+            else:
+                # Jax path
+                return jax_index_select(ary, cpu_idx, 0)
+        return tuple(
+            tuple(
+                single_element_reorder_cache(past_state)
+                for past_state in layer_past)
+            for layer_past in past)
+
+
+def get_hf_gpt_model(model_name, device, num_beams):
     raw_model = GPT2LMHeadModel.from_pretrained(model_name)
     raw_model = raw_model.to(device)
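
Note: `_reorder_cache` is the hook HuggingFace's beam search calls after each step to permute the per-beam KV caches; here it dispatches to Torch's `index_select`, the new `DistributedArray.index_select`, or the `jax_index_select` primitive, and memoizes the replicated index array per mesh. In plain numpy, the permutation it performs looks like this (shapes and values are made up):

# What _reorder_cache accomplishes: after each beam-search step, per-beam
# KV caches are permuted to follow the surviving beams.
import numpy as np

num_beams, seq_len, heads, head_dim = 4, 8, 2, 3
past = ((np.random.rand(num_beams, seq_len, heads, head_dim),   # keys
         np.random.rand(num_beams, seq_len, heads, head_dim)),)  # values
beam_idx = np.array([2, 2, 0, 1])  # beam 2 spawned two continuations

reordered = tuple(
    tuple(state[beam_idx] for state in layer) for layer in past)
assert np.allclose(reordered[0][0][0], past[0][0][2])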

@@ -116,6 +176,7 @@ def inference_func(input_ids,
         return InferenceFuncOutput(out.logits, out.past_key_values)

     inference_func_config = raw_model.config
+    inference_func_config.num_beams = num_beams
     transformer_config = TransformerModelConfig(
         H=raw_model.config.n_embd,
         L=raw_model.config.n_layer,
@@ -127,7 +188,7 @@ def inference_func(input_ids,
                                executable, transformer_config)


-def get_hf_opt_model(model_name, device):
+def get_hf_opt_model(model_name, device, num_beams):
     raw_model = OPTForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16 if "cuda" in device else torch.float32)
@@ -150,7 +211,7 @@ def inference_func(input_ids,
             output_hidden_states=output_hidden_states)
         return InferenceFuncOutput(out.logits, out.past_key_values)

-    inference_func_config = InferenceFuncConfig()
+    inference_func_config = InferenceFuncConfig(num_beams=num_beams)
     for key in inference_func_config.__dataclass_fields__.keys():
         setattr(inference_func_config, key, getattr(raw_model.config, key))
     transformer_config = TransformerModelConfig(
@@ -171,6 +232,7 @@ def get_model(model_name: str,
               dtype=jnp.float16,
               dummy=False,
               batch_size=1,
+              num_beams=1,
               decoding_length_per_step=1,
               num_micro_batches=1,
               support_output_attentions=False,
@@ -191,9 +253,9 @@ def get_model(model_name: str,
         f"Cannot support num_micro_batches > 1 in autoregressive mode.")

     if "gpt" in model_name:
-        return get_hf_gpt_model(model_name, device)
+        return get_hf_gpt_model(model_name, device, num_beams)
     if "facebook/opt" in model_name:
-        return get_hf_opt_model(model_name, device)
+        return get_hf_opt_model(model_name, device, num_beams)

     assert ("jax/opt" in model_name or "alpa/opt" in model_name)
     name = model_name.split("-")[1].upper()
@@ -220,7 +282,7 @@ def get_model(model_name: str,

         # load params
         params = load_params_np(params_aval, path, config, dummy)
-        init_cache = init_cache_np(config, batch_size=1)
+        init_cache = init_cache_np(config, batch_size=batch_size * num_beams)
         params, init_cache = jax.tree_map(jnp.array, (params, init_cache))
     else:
         assert "alpa/opt" in model_name
@@ -238,9 +300,11 @@ def get_model(model_name: str,
             seq_len=config.max_target_positions,
             vocab_size=config.vocab_size)

+        if autoregressive:
+            assert batch_size == 1, "we only support batch_size = 1 for autoregressive!"
         executable, params_aval = get_pipeshard_executable(
             config,
-            batch_size=batch_size,
+            batch_size=batch_size * num_beams,
             num_micro_batches=num_micro_batches,
             decoding_length_per_step=decoding_length_per_step,
             support_output_attentions=support_output_attentions,
@@ -253,7 +317,7 @@ def get_model(model_name: str,
         if autoregressive:
             init_cache = init_cache_dis_array(executable,
                                               config,
-                                              1,
+                                              batch_size * num_beams,
                                               dummy=dummy)
             set_skip_shard_args_check(init_cache)
             executable.sync()
@@ -292,7 +356,7 @@ def inference_func(input_ids,
         return InferenceFuncOutput(logits_step, output.attention_cache,
                                    output.hidden_states, output.attentions)

-    inference_func_config = InferenceFuncConfig()
+    inference_func_config = InferenceFuncConfig(num_beams=num_beams)
     return WrappedInferenceFunc(inference_func, inference_func_config,
                                 executable, transformer_config)

examples/opt_serving/textgen_demo.py

Lines changed: 7 additions & 2 deletions
@@ -7,16 +7,21 @@
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
 tokenizer.add_bos_token = False

+num_beams = 1
 # Load the model
 model = get_model(model_name="alpa/opt-2.7b",
                   device="cuda",
-                  path="/home/ubuntu/opt_weights/")
+                  path="/home/ubuntu/efs/parax-proj/",
+                  num_beams=num_beams)

 # Generate
 prompt = "Paris is the capital city of"

 input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
-output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
+output = model.generate(input_ids=input_ids,
+                        max_length=256,
+                        do_sample=True,
+                        num_beams=num_beams)
 generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

 print(generated_string)
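
Note: `generate` with `do_sample=True` and `num_beams > 1` is HuggingFace's beam-sample mode, which is what this commit enables on the Alpa backend; each decoding step then exercises the `_reorder_cache` path above. A hypothetical variant of the demo that turns it on:

# Hypothetical beam-sample invocation of the same demo; num_beams=4
# triggers the new cache-reordering path on the distributed arrays.
num_beams = 4
model = get_model(model_name="alpa/opt-2.7b",
                  device="cuda",
                  path="/home/ubuntu/efs/parax-proj/",
                  num_beams=num_beams)
output = model.generate(input_ids=input_ids,
                        max_length=256,
                        do_sample=True,
                        num_beams=num_beams)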
