Skip to content

Commit e8f80fa

Browse files
committed
Apply formatting
1 parent c7bf93c commit e8f80fa

File tree

16 files changed

+120
-77
lines changed

16 files changed

+120
-77
lines changed

ingress/mlir-gen/mlir_gen/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
23
from . import main
34

45
# Invoke on command line with `python -m mlir_gen`.

ingress/mlir-gen/mlir_gen/einsum.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
from typing import Union
22

33
from mlir import ir
4-
from mlir.dialects import arith, linalg, tensor
5-
6-
from . import named, generic
7-
from .utils import get_outputs, get_weights, get_bias, affine_map
4+
from mlir.dialects import arith
5+
from mlir.dialects import linalg
6+
from mlir.dialects import tensor
7+
8+
from . import generic
9+
from . import named
10+
from .utils import affine_map
11+
from .utils import get_bias
12+
from .utils import get_outputs
13+
from .utils import get_weights
814

915

1016
def times_weights(

ingress/mlir-gen/mlir_gen/generic.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from typing import Union
22

33
from mlir import ir
4-
from mlir.dialects import linalg, arith, tensor, math
5-
6-
from .utils import (
7-
affine_map,
8-
get_bias,
9-
get_outputs,
10-
get_weights,
11-
parallel,
12-
reduction,
13-
)
4+
from mlir.dialects import arith
5+
from mlir.dialects import linalg
6+
from mlir.dialects import math
7+
from mlir.dialects import tensor
8+
9+
from .utils import affine_map
10+
from .utils import get_bias
11+
from .utils import get_outputs
12+
from .utils import get_weights
13+
from .utils import parallel
14+
from .utils import reduction
1415

1516

1617
def affine_maps_and_iter_types(rank: int):

ingress/mlir-gen/mlir_gen/main.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1-
import random
2-
import sys
3-
41
from argparse import ArgumentParser
5-
from typing import Sequence, Dict, Any, Optional
62
from collections import namedtuple
3+
import random
4+
import sys
5+
from typing import Any
6+
from typing import Dict
7+
from typing import Optional
8+
from typing import Sequence
79

810
import numpy as np
911

1012
from mlir import ir
1113
from mlir.dialects import func
1214

13-
from . import named, generic, einsum, utils as gen_utils
14-
15+
from . import einsum
16+
from . import generic
17+
from . import named
18+
from . import utils as gen_utils
1519

1620
BlockFactors = namedtuple("BlockFactors", "m n k vnni")
1721

@@ -136,9 +140,9 @@ def weights(
136140
assert k_as_num_inputs % block.k == 0, "invalid tile size for K dim"
137141
assert n_as_num_outputs % block.n == 0, "invalid tile size for N dim"
138142
if block.vnni:
139-
assert (
140-
block.n % block.vnni == 0
141-
), "incompatible tile sizes for N and VNNI dims"
143+
assert block.n % block.vnni == 0, (
144+
"incompatible tile sizes for N and VNNI dims"
145+
)
142146
shape = (
143147
n_as_num_outputs // block.n,
144148
k_as_num_inputs // block.k,

ingress/mlir-gen/mlir_gen/named.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from typing import Union
22

33
from mlir import ir
4-
from mlir.dialects import linalg, tensor, arith
4+
from mlir.dialects import arith
5+
from mlir.dialects import linalg
6+
from mlir.dialects import tensor
57

6-
from .utils import get_outputs, get_weights, get_bias
8+
from .utils import get_bias
9+
from .utils import get_outputs
10+
from .utils import get_weights
711

812

913
def times_weights(

ingress/mlir-gen/mlir_gen/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
import struct
2-
from enum import Enum, auto
31
from collections import abc
2+
from enum import Enum
3+
from enum import auto
4+
import struct
45
from typing import Union
56

67
import numpy as np
78

89
from mlir import ir
9-
from mlir.dialects import arith, linalg, tensor
10+
from mlir.dialects import arith
11+
from mlir.dialects import linalg
12+
from mlir.dialects import tensor
1013

1114

1215
class ConstantInitKind(Enum):

python/examples/ingress/torch/MLPModel/model.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33
import torch
44
import torch.nn as nn
55

6-
import os
76

87
class MLPModel(nn.Module):
98
def __init__(self):
109
super().__init__()
11-
self.net = nn.Sequential(
12-
nn.Linear(10, 32),
13-
nn.ReLU(),
14-
nn.Linear(32, 2)
15-
)
10+
self.net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
1611

1712
def forward(self, x):
1813
return self.net(x)

python/examples/ingress/torch/mlp_from_file.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from pathlib import Path
1818

1919
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
20-
import mlir.dialects.func as func
2120
from mlir import ir
21+
import mlir.dialects.func as func
2222

2323
# Lighthouse imports
2424
from lighthouse.ingress.torch import import_from_file
@@ -34,6 +34,7 @@
3434
# - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()'
3535
# - Calls get_sample_inputs() to get sample input tensors for shape inference
3636
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
37+
# fmt: off
3738
mlir_module_ir: ir.Module = import_from_file(
3839
model_path, # Path to the Python file containing the model
3940
model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert
@@ -42,6 +43,7 @@
4243
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
4344
ir_context=ir_context # MLIR context for the conversion
4445
)
46+
# fmt: on
4547

4648
# The PyTorch model is now converted to MLIR at this point. You can now convert
4749
# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.

python/examples/ingress/torch/mlp_from_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import torch
1616

1717
# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
18-
import mlir.dialects.func as func
1918
from mlir import ir
19+
import mlir.dialects.func as func
2020

2121
# Lighthouse imports
2222
from lighthouse.ingress.torch import import_from_model
@@ -31,9 +31,7 @@
3131
ir_context = ir.Context()
3232
# Step 2: Convert the PyTorch model to MLIR
3333
mlir_module_ir: ir.Module = import_from_model(
34-
model,
35-
sample_args=(sample_input,),
36-
ir_context=ir_context
34+
model, sample_args=(sample_input,), ir_context=ir_context
3735
)
3836

3937
# The PyTorch model is now converted to MLIR at this point. You can now convert

python/examples/mlir/compile_and_run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# RUN: %PYTHON %s
22

3-
import torch
43
import argparse
54

5+
import torch
6+
67
from mlir import ir
78
from mlir.dialects import transform
89
from mlir.dialects.transform import structured

0 commit comments

Comments
 (0)