import dataclasses
import logging
from collections import defaultdict
from typing import Callable

import jax
from jax import numpy as jnp
import numpy as np

from jetstream.engine import engine_api

log = logging.getLogger(__name__)


@dataclasses.dataclass
class InputData:
  """One request: an id, the (padded) token ids, and the unpadded length."""

  id: str
  tokens: jax.Array
  true_length: int

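# Illustrative sketch (not part of the original module): building an InputData
# whose tokens are right-padded to one of the warmup buckets below, so a cached
# prefill executable can be reused. `make_input` is a hypothetical helper.
def make_input(id_: str, token_ids: list[int], bucket: int = 32) -> InputData:
  padded = np.zeros((bucket,), dtype=np.int32)
  padded[: len(token_ids)] = token_ids
  return InputData(id=id_, tokens=jnp.asarray(padded), true_length=len(token_ids))
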
class OfflineInference:

  def __init__(self, engine: engine_api.Engine, params=None):
    self.engine = engine
    self.decode_state = None
    if params is None:
      params = engine.load_params()
    self.params = params

    self.batch_size = engine.env.batch_size
    self.max_decode_length = engine.max_decode_length
    metadata = engine.get_tokenizer()
    self.tokenizer = engine.build_tokenizer(metadata)
    self.dummy = False

    # AOT-compiled executables: prefill keyed by padded input length, plus a
    # single generate (decode) step. Populated by warmup().
    self._cached_pref = {}
    self._cached_generate = None

  def init_decode_state(self):
    if self.decode_state is None:
      self.decode_state = self.engine.init_decode_state()

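  # warmup() ahead-of-time compiles one prefill+insert executable per input
  # length bucket and a single generate executable, so the first real request
  # does not pay XLA compilation cost. donate_argnums donates the previous
  # decode_state buffers, avoiding a second copy of the KV cache in memory.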
  def warmup(self, max_length=2048):
    self.init_decode_state()
    interesting_buckets = [32, 64, 128, 256, 512, 1024, 2048, 4096]
    for length in interesting_buckets:
      if length > max_length:
        break
      log.info(f"Compiling prefill: {length}")
      input_data = jax.ShapeDtypeStruct((length,), jnp.dtype("int32"))
      self._cached_pref[length] = (
          jax.jit(self._prefill_insert, donate_argnums=(4,))
          .lower(
              self.params,
              tokens=input_data,
              slot=0,
              true_length=length - 1,
              decode_state=self.decode_state)
          .compile()
      )

    log.info("Compiling decode")
    self._cached_generate = (
        jax.jit(self.engine.generate, donate_argnums=(1,))
        .lower(self.params, self.decode_state)
        .compile()
    )

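  # Fusing prefill and insert into one jitted function means the insert can
  # update the donated decode_state buffers in place rather than copying them.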
  def _prefill_insert(self, params, tokens, slot, true_length, decode_state):
    """Prefills one request and inserts its KV cache at the given slot.

    Returns the first generated token and the updated decode state.
    """
    prefill_result, first_token = self.engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = self.engine.insert(prefill_result, decode_state, slot=slot)
    return first_token, decode_state

  def batch_inference_with_callback(
      self,
      data: list[InputData],
      emit_first_token: Callable[[str, int], bool],
      emit_token: Callable[[str, int], bool],
  ):
    """Runs continuous-batching inference over `data`.

    The callbacks take (id, token) and are called once per output token;
    returning True marks that request as finished.
    """

    def prefill(slot, tokens, true_length):
      if self.dummy:
        log.debug("dummy prefill")
        return 123

      prefill_fn = self._prefill_insert
      if (cached := self._cached_pref.get(len(tokens))) is not None:
        prefill_fn = cached

      first_token, self.decode_state = prefill_fn(
          self.params, tokens=tokens, slot=slot,
          true_length=true_length, decode_state=self.decode_state
      )
      return first_token.data[0][0].item()

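    # Each row of the batched decode state is a "slot" holding one in-flight
    # request: slot_to_id maps occupied slots to request ids, and empty_slots
    # tracks the free ones.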
    empty_slots = list(range(self.batch_size))
    slot_to_id = {}

    dummy_length = 1

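    # decode() advances every in-flight request by one token: it runs a single
    # generate step, emits each valid token via emit_token, and frees the slots
    # of requests that finished (callback returned True or hit max length).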
    def decode():
      log.debug("decode")
      nonlocal dummy_length
      if self.dummy:
        log.debug("Dummy generate")
        res = engine_api.ResultTokens(
            data=np.array([[123, 1, dummy_length]] * self.batch_size),
            tokens_idx=(0, 0),
            valid_idx=(0, 0),
            length_idx=(0, 0),
            samples_per_slot=(0, 0),
        )
        dummy_length += 1
        result_tokens = res
      else:
        gen_fn = self.engine.generate
        if self._cached_generate is not None:
          gen_fn = self._cached_generate
        self.decode_state, result_tokens = gen_fn(
            self.params, self.decode_state
        )

      result_tokens = result_tokens.convert_to_numpy()

      newly_empty = []
      for slot, id_ in slot_to_id.items():
        token, is_valid, length = result_tokens.data[slot]
        log.debug(f"slot is {slot}, length is {length}")
        should_finish = False
        if is_valid:
          should_finish = emit_token(id_, token.item())
        if should_finish or length >= self.max_decode_length:
          newly_empty.append(slot)

      # Return finished slots to the free pool.
      for slot in newly_empty:
        del slot_to_id[slot]
        empty_slots.append(slot)

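    # Continuous batching: prefill each request into a free slot, running
    # decode steps whenever all slots are occupied, then drain the remaining
    # requests once the input is exhausted.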
    for row in data:
      log.debug(f"empty_slots {len(empty_slots)}")
      while not empty_slots:
        # All slots are busy; decode until one frees up.
        decode()
      # Do one insert.
      log.debug(f"prefill {row.id}")
      slot = empty_slots.pop()
      first_token = prefill(slot, row.tokens, row.true_length)
      should_terminate = emit_first_token(row.id, first_token)
      if not should_terminate:
        slot_to_id[slot] = row.id
      else:
        empty_slots.append(slot)  # Don't occupy the slot.

    # Drain: decode until every in-flight request finishes.
    while slot_to_id:
      log.debug(f"slot to id {len(slot_to_id)}")
      decode()

  def batch_inference(self, data: list[InputData]):
    """Runs inference over `data`; returns {id: [token, ...]} per request."""
    ans = defaultdict(list)

    def callback(id_, token):
      ans[id_].append(token)
      return token == self.tokenizer.eos_id

    self.batch_inference_with_callback(
        data, emit_first_token=callback, emit_token=callback
    )
    return ans
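

# Usage sketch. `create_engine()` is a hypothetical stand-in for whatever
# engine_api.Engine factory your serving stack provides; it is not defined
# in this module.
#
#   engine = create_engine()
#   inference = OfflineInference(engine)
#   inference.warmup(max_length=1024)
#   data = [make_input("req-0", [1, 2, 3])]
#   results = inference.batch_inference(data)  # {"req-0": [token, ...]}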