Skip to content

Commit

Permalink
Add migraphx ep for llama2 scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
turneram committed Aug 28, 2024
1 parent c604fe9 commit 29f9c36
Show file tree
Hide file tree
Showing 7 changed files with 922 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
"GlobalMaxPool",
"Greater",
"GreaterOrEqual",
"GroupQueryAttention",
"HardSigmoid",
"HardSwish",
"Identity",
Expand Down Expand Up @@ -903,8 +904,10 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
"Shape",
"Sigmoid",
"Sign",
"SimplifiedLayerNormalization",
"Sin",
"Sinh",
"SkipSimplifiedLayerNormalization",
"Slice",
"Softmax",
"Softplus",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def get_args(rank=0):
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
choices=["cpu", "cuda", "rocm"],
choices=["cpu", "cuda", "rocm", "migraphx"],
)
parser.add_argument("-id", "--device-id", type=int, default=0)
parser.add_argument("-w", "--warmup-runs", type=int, default=5)
Expand Down Expand Up @@ -622,9 +622,11 @@ def get_args(rank=0):
# Set runtime properties
if "ort" in args.benchmark_type:
setattr(args, "execution_provider", f"{args.device.upper()}ExecutionProvider") # noqa: B010
if args.device == "migraphx":
setattr(args, "execution_provider", "MIGraphXExecutionProvider")
if args.execution_provider == "CUDAExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
elif args.execution_provider == "ROCMExecutionProvider":
elif args.execution_provider == "ROCMExecutionProvider" or args.execution_provider == "MIGraphXExecutionProvider":
args.execution_provider = (args.execution_provider, {"device_id": rank})
args.device = "cuda"

Expand Down
Loading

0 comments on commit 29f9c36

Please sign in to comment.