Commit d4e5b53 (parent 02771fa)
Showing 3 changed files with 79 additions and 1 deletion.
@@ -0,0 +1,34 @@
import keras_core
from keras_core.ops import *
from keras_core.ops import concatenate as concat
from keras_core.ops import mean as reduce_mean
from keras_core.ops import sum as reduce_sum
from keras_core.ops import max as reduce_max
from keras_core.ops import min as reduce_min
from keras_core.ops import power as pow
from keras_core.ops import clip as clip_by_value
from keras_core.ops.image import extract_patches


def resize(images, size, method="bilinear", preserve_aspect_ratio=False, antialias=False, name=None):
    # TF-style wrapper over keras_core resize. `preserve_aspect_ratio` and `name` are
    # accepted for signature compatibility with tf.image.resize but are ignored.
    return keras_core.ops.image.resize(images, size, interpolation=method, antialias=antialias, data_format=keras_core.backend.image_data_format())


def split(inputs, num_or_size_splits, axis=0, num=None, name="split"):
    # TF-style split: an int means that many equal pieces; a list gives per-piece
    # sizes, with at most one None / -1 entry inferred from the axis size.
    if isinstance(num_or_size_splits, int):
        return keras_core.ops.split(inputs, num_or_size_splits, axis=axis)

    axis = (len(inputs.shape) + axis) if axis < 0 else axis
    split_axis_shape = inputs.shape[axis]
    assert split_axis_shape is not None

    size_splits = [0 if ii is None or ii == -1 else ii for ii in num_or_size_splits]
    num_unknown_dim = sum([ii == 0 for ii in size_splits])
    assert num_unknown_dim < 2, "At most one unknown dimension in num_or_size_splits: {}".format(num_or_size_splits)

    if num_unknown_dim == 1:
        size_splits = [(split_axis_shape - sum(size_splits)) if ii == 0 else ii for ii in size_splits]

    # keras_core.ops.split expects cumulative split indices, not piece sizes. Use the
    # resolved size_splits here, not the raw argument, which may still contain -1 / None.
    cum_split = [sum(size_splits[: idx + 1]) for idx, _ in enumerate(size_splits[:-1])]
    return keras_core.ops.split(inputs, cum_split, axis=axis)
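
A quick usage sketch of the two wrappers above. It assumes keras_core is installed with any backend, and that this shim file is importable as `tf_compat` (a hypothetical name; the diff does not show the file path):

```python
import numpy as np
import tf_compat  # hypothetical module name for the shim above

x = np.arange(24, dtype="float32").reshape(2, 12)

# Equal split, TF style: 3 pieces of width 4 along the last axis.
a, b, c = tf_compat.split(x, 3, axis=-1)

# Size splits with one inferred (-1) entry: widths 2, 4 and the remaining 6.
p, q, r = tf_compat.split(x, [2, 4, -1], axis=-1)
assert q.shape[-1] == 4 and r.shape[-1] == 6

# resize follows tf.image.resize's argument names, but preserve_aspect_ratio
# is ignored; output is [1, 8, 8, 3] for a channels_last backend config.
images = np.random.uniform(size=[1, 4, 4, 3]).astype("float32")
resized = tf_compat.resize(images, [8, 8], method="bilinear")
```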
@@ -0,0 +1,41 @@
from keras_cv_attention_models.llama2.llama2 import Llama2, Llama2_7B, RunPrediction, PositionalEncodingFourierRot1D, RMSNorm

__head_doc__ = """
Keras implementation of [Github facebookresearch/llama](https://github.com/facebookresearch/llama).
Paper [Llama 2: Open Foundation and Fine-Tuned Chat Models](https://arxiv.org/abs/2307.09288).
"""

__tail_doc__ = """  vocab_size: model vocab size.
  max_block_size: number of tokens generated in each sample.
  include_top: boolean value if including output Dense head layer. Set False to exclude the head layer.
  dropout: float value for dropout rate for Embedding layer and attention blocks.
  activation: activation used in whole model, default `gelu/app`.
  pretrained: None or one of ["webtext", "huggingface"].
      - if "webtext", will try to download and load ported weights if available.
      - if "huggingface", will try converting and loading weights from huggingface `transformers` package.
      - if None, will initialize model with random weights.

Returns:
    A `keras.Model` instance.
"""

Llama2.__doc__ = __head_doc__ + """
Args:
  num_blocks: number of transformer blocks.
  embedding_size: embedding dimension for each token.
  num_heads: number of attention heads in each block.
  block_use_bias: boolean, whether Dense layers inside blocks use a bias term.
  model_name: string, model name.
""" + __tail_doc__ + """
Model architectures:
  | Model       | Params  | FLOPs   | vocab_size | LAMBADA PPL |
  | ----------- | ------- | ------- | ---------- | ----------- |
  | GPT2_Base   | 163.04M | 146.42G | 50257      | 35.13       |
  | GPT2_Medium | 406.29M | 415.07G | 50257      | 15.60       |
  | GPT2_Large  | 838.36M | 890.28G | 50257      | 10.87       |
  | GPT2_XLarge | 1.638B  | 1758.3G | 50257      | 8.63        |
"""

Llama2_7B.__doc__ = __head_doc__ + """
Args:
""" + __tail_doc__