+ from functools import partial
import os
from typing import Sequence, Any

import alpa
import jax
+ from jax import xla
+ from jax import ShapeDtypeStruct, ShapedArray
+ from jax._src.lib import xla_client as xc
+ from jax.core import Primitive
+ from jax.interpreters import pxla
+ from jax.interpreters.pxla import NoSharding, Replicated, ShardingSpec
import jax.numpy as jnp
import numpy as np
import torch
from examples.opt_serving.model.opt_utils import TransformerModelConfig

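+ # Custom JAX primitive: select rows of `input` along `dim` with `index`,
+ # lowered via xla_client's IndexSelect op. It is used below to reorder the
+ # cached attention states on-device during beam search.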
+ index_select_p = Primitive("index-select")
+ def jax_index_select(input, index, dim=0):
+     return index_select_p.bind(input, index, dim=dim)
+
+ def _index_select_eval(input, index, dim):
+     return input
+
+ def _index_select_translation(c, input, index, dim):
+     return xc.ops.IndexSelect(input, index, dim)
+
+ index_select_p.def_abstract_eval(_index_select_eval)
+ index_select_p.def_impl(partial(xla.apply_primitive, index_select_p))
+ xla.translations[index_select_p] = _index_select_translation
+
+
@dataclass
class InferenceFuncOutput(ModelOutput):
    logits: Any = None
@@ -100,8 +122,46 @@ def __call__(self,
            past_key_values = ret.past_key_values
        return ret

-
- def get_hf_gpt_model(model_name, device):
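+     # Called after each beam-search step to reorder every layer's cached
+     # key/value states so they follow the beams selected by the scorer.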
+     def _reorder_cache(self, past, beam_idx):
+         # Currently beam_idx is a torch tensor produced by the beam scorer.
+         # To speed this up, we would need alpa's own beam scorer.
+         cache = {}
+         cpu_idx = beam_idx.to("cpu").numpy()
+         if type(cpu_idx) not in [ShapedArray, ShapeDtypeStruct]:
+             cpu_idx = xla.canonicalize_dtype(cpu_idx)
+
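+         # Replicate the beam indices onto the target device mesh, caching one
+         # copy per mesh so the host-to-device transfer happens only once.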
+         def to_mesh(mesh):
+             if mesh in cache:
+                 return cache[mesh]
+             avals = [ShapedArray(cpu_idx.shape, cpu_idx.dtype)]
+             replicated_spec = ShardingSpec([NoSharding()] * len(cpu_idx.shape),
+                                            [Replicated(mesh.num_devices)])
+             specs = [replicated_spec]
+             indices = [pxla.spec_to_indices(cpu_idx.shape, replicated_spec)]
+             ary = mesh.shard_args_to_arrays(avals, indices, specs, [cpu_idx])[0]
+             cache[mesh] = ary
+             return ary
+
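+         # Torch tensors and alpa distributed arrays expose index_select; plain
+         # jax arrays go through the custom index-select primitive defined above.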
+         def single_element_reorder_cache(ary):
+             if hasattr(ary, "index_select"):
+                 # Torch or Alpa path
+                 device_idx = None
+                 if hasattr(ary, "device"):  # Torch to_device
+                     device_idx = beam_idx.to(ary.device)
+                 else:
+                     device_idx = to_mesh(ary.device_mesh)
+                 return ary.index_select(0, device_idx)
+             else:
+                 # Jax path
+                 return jax_index_select(ary, cpu_idx, 0)
+         return tuple(
+             tuple(
+                 single_element_reorder_cache(past_state)
+                 for past_state in layer_past)
+             for layer_past in past)
+
+
+ def get_hf_gpt_model(model_name, device, num_beams):
    raw_model = GPT2LMHeadModel.from_pretrained(model_name)
    raw_model = raw_model.to(device)
@@ -116,6 +176,7 @@ def inference_func(input_ids,
        return InferenceFuncOutput(out.logits, out.past_key_values)

    inference_func_config = raw_model.config
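+     # Record the requested beam width on the config so the HF generation
+     # utilities can pick it up as the default number of beams.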
+     inference_func_config.num_beams = num_beams
    transformer_config = TransformerModelConfig(
        H=raw_model.config.n_embd,
        L=raw_model.config.n_layer,
@@ -127,7 +188,7 @@ def inference_func(input_ids,
                                executable, transformer_config)


- def get_hf_opt_model(model_name, device):
+ def get_hf_opt_model(model_name, device, num_beams):
    raw_model = OPTForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if "cuda" in device else torch.float32)
@@ -150,7 +211,7 @@ def inference_func(input_ids,
                        output_hidden_states=output_hidden_states)
        return InferenceFuncOutput(out.logits, out.past_key_values)

-     inference_func_config = InferenceFuncConfig()
+     inference_func_config = InferenceFuncConfig(num_beams=num_beams)
    for key in inference_func_config.__dataclass_fields__.keys():
        setattr(inference_func_config, key, getattr(raw_model.config, key))
    transformer_config = TransformerModelConfig(
@@ -171,6 +232,7 @@ def get_model(model_name: str,
              dtype=jnp.float16,
              dummy=False,
              batch_size=1,
+               num_beams=1,
              decoding_length_per_step=1,
              num_micro_batches=1,
              support_output_attentions=False,
@@ -191,9 +253,9 @@ def get_model(model_name: str,
            f"Cannot support num_micro_batches > 1 in autoregressive mode.")

    if "gpt" in model_name:
-         return get_hf_gpt_model(model_name, device)
+         return get_hf_gpt_model(model_name, device, num_beams)
    if "facebook/opt" in model_name:
-         return get_hf_opt_model(model_name, device)
+         return get_hf_opt_model(model_name, device, num_beams)

    assert ("jax/opt" in model_name or "alpa/opt" in model_name)
    name = model_name.split("-")[1].upper()
@@ -220,7 +282,7 @@ def get_model(model_name: str,
        # load params
        params = load_params_np(params_aval, path, config, dummy)
-         init_cache = init_cache_np(config, batch_size=1)
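+         # Beam search keeps one cache entry per beam, so the cache batch
+         # dimension grows to batch_size * num_beams.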
+         init_cache = init_cache_np(config, batch_size=batch_size * num_beams)
        params, init_cache = jax.tree_map(jnp.array, (params, init_cache))
    else:
        assert "alpa/opt" in model_name
@@ -238,9 +300,11 @@ def get_model(model_name: str,
            seq_len=config.max_target_positions,
            vocab_size=config.vocab_size)

+         if autoregressive:
+             assert batch_size == 1, "we only support batch_size = 1 for autoregressive!"
        executable, params_aval = get_pipeshard_executable(
            config,
-             batch_size=batch_size,
+             batch_size=batch_size * num_beams,
            num_micro_batches=num_micro_batches,
            decoding_length_per_step=decoding_length_per_step,
            support_output_attentions=support_output_attentions,
@@ -253,7 +317,7 @@ def get_model(model_name: str,
        if autoregressive:
            init_cache = init_cache_dis_array(executable,
                                              config,
-                                               1,
+                                               batch_size * num_beams,
                                              dummy=dummy)
            set_skip_shard_args_check(init_cache)
            executable.sync()
@@ -292,7 +356,7 @@ def inference_func(input_ids,
        return InferenceFuncOutput(logits_step, output.attention_cache,
                                   output.hidden_states, output.attentions)

-     inference_func_config = InferenceFuncConfig()
+     inference_func_config = InferenceFuncConfig(num_beams=num_beams)
    return WrappedInferenceFunc(inference_func, inference_func_config,
                                executable, transformer_config)