Feature request: Extend the remove broadcast + squeeze pass #3635

Open
naoyam opened this issue Dec 23, 2024 · 4 comments · May be fixed by #3643
naoyam (Collaborator) commented Dec 23, 2024

Here's a pattern in the Mistral RoPE backward function:

Inputs:
  T3_g___bfloat[bS11{1}, iS12{32}, iS13{4096}, iS14{128}]
Outputs:
  T89_g___bfloat[bS371{1}, iS372{4096}, iS377{1024}rf]

%kernel_math {
T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}] = view( T3_g___bfloat[bS11{1}, iS12{32}, iS13{4096}, iS14{128}] )
T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}]
   = __bfloat2float(T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}]);
T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}]
   = squeeze( T42_l_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}] )
T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]
   = reduction( T46_l_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}], op = add, initial value = float(0), allreduce = false )
T54_l___bfloat[iS221{8}, iS222{4096}, iS223{128}]
   = __float2bfloat(T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]);
T63_l___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]
   = broadcast( T54_l___bfloat[iS221{8}, iS222{4096}, iS223{128}] )
T64_l___bfloat[bS265{1}, iS266{8}, bS267{1}, iS268{4096}, iS269{128}]
   = Set( T63_l___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}], cache_op=Streaming )
T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}]
   = __bfloat2float(T64_l___bfloat[bS265{1}, iS266{8}, bS267{1}, iS268{4096}, iS269{128}]);
T76_l_float[iS315{8}, iS316{4096}, iS317{128}]
   = squeeze( T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}] )
T79_l___bfloat[iS326{8}, iS327{4096}, iS328{128}]
   = __float2bfloat(T76_l_float[iS315{8}, iS316{4096}, iS317{128}]);
T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}]
   = broadcast( T79_l___bfloat[iS326{8}, iS327{4096}, iS328{128}] )
T83_l___bfloat[bS341{1}, iS342{8}, iS343{4096}, iS344{128}]
   = Set( T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}], cache_op=Streaming )
T86_l___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}]
   = Set.Permute( T83_l___bfloat[bS341{1}, iS342{8}, iS343{4096}, iS344{128}], cache_op=Streaming )
T89_g___bfloat[bS371{1}, iS372{4096}, iS377{1024}rf] = view( T86_l___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}] )

This is currently split into two segments: one reduction and one pointwise.

g{(reduction)
group id: 5
inputs:
  T3_g___bfloat[bS11{1}, iS12{32}, iS13{4096}, iS14{128}] __bfloat
outputs:
  T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] __bfloat


T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}] = view( T3_g___bfloat[bS11{1}, iS12{32}, iS13{4096}, iS14{128}] )
(43)
T42_g_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}]
   = __bfloat2float(T38_l___bfloat[bS153{1}, iS158{8}rf, iS159{4}rf, iS155{4096}, iS156{128}]);
(47)
T46_g_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}]
   = squeeze( T42_g_float[bS171{1}, iS172{8}, iS173{4}, iS174{4096}, iS175{128}] )
(51)
T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]
   = reduction( T46_g_float[iS187{8}, iS188{4}, iS189{4096}, iS190{128}], op = add, initial value = float(0), allreduce = false )
(52)
T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}]
   = __float2bfloat(T47_l_float[iS191{8}, rS192{4}, iS193{4096}, iS194{128}]);
(61)
}

g{(pointwise)
group id: 6
inputs:
  T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] __bfloat
outputs:
  T89_g___bfloat[bS371{1}, iS372{4096}, iS377{1024}rf] __bfloat


T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]
   = broadcast( T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] )
(74)
T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}]
   = __bfloat2float(T63_g___bfloat[bS260{1}, iS261{8}, bS262{1}, iS263{4096}, iS264{128}]);
(162)
T76_g_float[iS315{8}, iS316{4096}, iS317{128}]
   = squeeze( T71_l_float[bS294{1}, iS295{8}, bS296{1}, iS297{4096}, iS298{128}] )
(87)
T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}]
   = __float2bfloat(T76_g_float[iS315{8}, iS316{4096}, iS317{128}]);
(90)
T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}]
   = broadcast( T79_g___bfloat[iS326{8}, iS327{4096}, iS328{128}] )
(93)
T86_g___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}]
   = Set.Permute( T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}], cache_op=Streaming )
(161)
T89_g___bfloat[bS371{1}, iS372{4096}, iS377{1024}rf] = view( T86_g___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}] )
(103)
}

It seems the second segment should consist only of meta operations, but it's probably not detected as such because of the type cast ops. I think it should be safe to ignore the type cast ops (the bfloat-to-float-to-bfloat round trip is lossless) and remove the broadcast and squeeze ops. With that, this segment would become a no-op segment.
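A minimal shape-level sketch of why the broadcast/squeeze pair in the second segment is an identity. It models broadcast and squeeze as pure shape transforms driven by per-dim flags; the helper names (`broadcast_shape`, `squeeze_shape`, `bcast_squeeze_cancels`) are hypothetical and not nvFuser's API. The cast between T63 and T76 is elementwise and does not change the shape, so it is left out.

```python
# Toy shape model of broadcast/squeeze (hypothetical helpers, not nvFuser's API).
# Broadcast flags index the output dims; squeeze flags index the input dims.
from typing import List, Sequence


def broadcast_shape(shape: Sequence[int], is_bcast: Sequence[bool]) -> List[int]:
    """Insert a size-1 dim wherever is_bcast is True."""
    assert len(shape) == sum(not b for b in is_bcast), "rank mismatch"
    it = iter(shape)
    return [1 if b else next(it) for b in is_bcast]


def squeeze_shape(shape: Sequence[int], is_squeezed: Sequence[bool]) -> List[int]:
    """Drop the dims flagged in is_squeezed; each of them must have size 1."""
    assert len(shape) == len(is_squeezed), "rank mismatch"
    assert all(size == 1 for size, s in zip(shape, is_squeezed) if s)
    return [size for size, s in zip(shape, is_squeezed) if not s]


def bcast_squeeze_cancels(is_bcast: Sequence[bool], is_squeezed: Sequence[bool]) -> bool:
    """In this sketch the pair is an identity iff squeeze drops exactly the broadcast dims."""
    return list(is_bcast) == list(is_squeezed)


# T54 -> T63 (broadcast) -> T71 (cast, shape unchanged) -> T76 (squeeze)
t54 = [8, 4096, 128]
flags = [True, False, True, False, False]
t63 = broadcast_shape(t54, flags)  # [1, 8, 1, 4096, 128]
t76 = squeeze_shape(t63, flags)    # [8, 4096, 128]
```

Since the squeeze removes exactly the two dims the broadcast inserted, T76 has the same shape as T54 and the pair can be dropped once the casts in between are out of the way.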

Note that while this is part of the Mistral RoPE backward function, the perf impact is likely small since it's only a small part of the overall fusion, as shown in the attached graph. The section above corresponds to the vertical sequence in the upper right, from T3 to T89.

mistral_bwd.pdf

naoyam added the rope label Dec 23, 2024
jjsjann123 (Collaborator) commented

The Thunder definition has the cast ops explicit in the trace. Those are currently not cancelled out, since they are separated by the squeeze op. But we should be able to extend the consecutive cast pass to handle that: https://github.com/NVIDIA/Fuser/blob/main/csrc/preseg_passes/consecutive_cast.cpp
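A sketch of why only this direction of the round trip can be folded, assuming bfloat16 is the upper 16 bits of an IEEE float32 and that float-to-bfloat16 uses round-to-nearest-even (NaN handling omitted). The helpers are hypothetical bit-level emulations, not the code in `consecutive_cast.cpp`.

```python
import struct


def f32_bits(x: float) -> int:
    """Raw IEEE-754 bits of x rounded to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(b: int) -> float:
    """Float32 value of a 32-bit pattern."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def bf16_to_f32(h: int) -> float:
    # bfloat16 is the top half of a float32, so widening is exact.
    return bits_f32((h & 0xFFFF) << 16)


def f32_to_bf16(x: float) -> int:
    # Round to nearest even on the 16 dropped bits (NaN handling omitted).
    b = f32_bits(x)
    lsb = (b >> 16) & 1
    return ((b + 0x7FFF + lsb) >> 16) & 0xFFFF


# float -> bfloat -> float is lossy: bfloat16 keeps only 7 mantissa bits.
x = 1.0 + 2.0**-10
```

For every non-NaN bfloat16 bit pattern `h`, `f32_to_bf16(bf16_to_f32(h)) == h`, because the widened value has all-zero low bits and rounding leaves it untouched. The reverse trip is not an identity: `bf16_to_f32(f32_to_bf16(x))` rounds `x` above to `1.0`. So a bfloat-to-float-to-bfloat pair can be dropped, while a float-to-bfloat-to-float pair cannot.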

Naoya also mentioned that the broadcast/squeeze pairs could cancel each other out: https://github.com/NVIDIA/Fuser/blob/main/csrc/preseg_passes/remove_bcast_squeeze.cpp

jacobhinkle (Collaborator) commented Dec 24, 2024

Since broadcast and squeeze don't affect the values computed in the fusion directly, I think they'll commute with most ops, and we should be able to just move all of the broadcasts and squeezes toward the inputs or outputs as a pass before we combine bcast+squeeze. That way they'd be adjacent, and we could remove these and the consecutive casts as normal afterward.
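A toy sketch of the reordering idea on a linear op chain. It assumes ops are `(kind, payload)` tuples and that a squeeze is only moved across rank-preserving elementwise ops (here just casts), so its flags stay valid. Moving across permutes or views would need the flags remapped, which this sketch does not do. `hoist_squeezes` and `cancel_bcast_squeeze` are hypothetical names, not nvFuser functions.

```python
# Toy op-chain IR (hypothetical, not nvFuser's): each op is (kind, payload).
from typing import List, Tuple

Op = Tuple[str, tuple]

# Ops that keep rank and dim order, so a squeeze's flags stay valid across them.
ELEMENTWISE = {"cast"}


def hoist_squeezes(chain: List[Op]) -> List[Op]:
    """Move every squeeze toward the inputs past elementwise ops."""
    out: List[Op] = []
    for op in chain:
        out.append(op)
        i = len(out) - 1
        while op[0] == "squeeze" and i > 0 and out[i - 1][0] in ELEMENTWISE:
            out[i - 1], out[i] = out[i], out[i - 1]
            i -= 1
    return out


def cancel_bcast_squeeze(chain: List[Op]) -> List[Op]:
    """Drop an adjacent broadcast -> squeeze pair whose flags match."""
    out: List[Op] = []
    for op in chain:
        if op[0] == "squeeze" and out and out[-1][0] == "broadcast" and out[-1][1] == op[1]:
            out.pop()
        else:
            out.append(op)
    return out


BC = (True, False, True, False, False)
# The pointwise segment, starting from T54 (bfloat).
segment: List[Op] = [
    ("broadcast", BC),                           # T63
    ("cast", ("float",)),                        # T71
    ("squeeze", BC),                             # T76
    ("cast", ("bfloat",)),                       # T79
    ("broadcast", (True, False, False, False)),  # T82
    ("permute", (0, 2, 1, 3)),                   # T86
    ("view", ((1, 4096, 1024),)),                # T89
]
simplified = cancel_bcast_squeeze(hoist_squeezes(segment))
```

After hoisting, T76's squeeze sits directly behind T63's broadcast, the pair cancels, and what remains starts with two consecutive casts (float, then bfloat) that a consecutive-cast pass could then fold.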

jjsjann123 (Collaborator) commented

An orthogonal note.

T86_g___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}]
   = Set.Permute( T82_l___bfloat[bS337{1}, iS338{8}, iS339{4096}, iS340{128}], cache_op=Streaming )
(161)
T89_g___bfloat[bS371{1}, iS372{4096}, iS377{1024}rf] = view( T86_g___bfloat[bS353{1}, iS355{4096}, iS354{8}, iS356{128}] )
(103)

We'll need the input to the second kernel T54_g___bfloat[iS221{8}, iS222{4096}, iS223{128}] to have a view-compatible stride in order to be able to handle the view as a meta operation.
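A sketch of the stride condition. It assumes dense tensors and the simple rule that dims `i` and `i+1` merge without a copy iff `stride[i] == stride[i+1] * size[i+1]`. Size-1 and broadcast dims, which can be merged more freely, are not special-cased. The broadcast dim is dropped, so the T82-to-T86 permute becomes `(1, 0, 2)` on T54's `[8, 4096, 128]`, and the view merges the last two dims (8 and 128) into 1024. The `[4096, 8, 128]` memory layout for T54 is a hypothetical example of a layout that would work; the helper names are also hypothetical.

```python
from typing import List, Sequence


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides for a dense tensor."""
    strides, acc = [], 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return strides[::-1]


def strides_from_allocation(shape: Sequence[int], alloc_order: Sequence[int]) -> List[int]:
    """Strides of a dense tensor whose logical dims are laid out in memory
    outermost-to-innermost in the order given by alloc_order."""
    alloc_strides = contiguous_strides([shape[d] for d in alloc_order])
    strides = [0] * len(shape)
    for k, d in enumerate(alloc_order):
        strides[d] = alloc_strides[k]
    return strides


def permute(seq: Sequence[int], perm: Sequence[int]) -> List[int]:
    """Output dim i takes input dim perm[i]."""
    return [seq[p] for p in perm]


def can_merge_as_view(sizes: Sequence[int], strides: Sequence[int], start: int, stop: int) -> bool:
    """True if dims [start, stop) can be merged into one dim without a copy."""
    return all(strides[i] == strides[i + 1] * sizes[i + 1] for i in range(start, stop - 1))


sizes = [8, 4096, 128]  # T54
perm = (1, 0, 2)        # T54 -> T86, ignoring the broadcast dim
row_major = contiguous_strides(sizes)            # [524288, 128, 1]
alloc = strides_from_allocation(sizes, (1, 0, 2))  # [128, 1024, 1]
```

With a plain row-major T54, the permuted 8-dim has stride 524288 instead of 128, so merging it with the 128-dim would require a copy. If T54 were instead laid out as `[4096, 8, 128]`, the two dims are adjacent in memory and the view is a pure meta operation.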

jjsjann123 (Collaborator) commented

> Since broadcast and squeeze don't affect the values computed in the fusion directly I think they'll commute with most ops and we should be able to just move all of the broadcasts and squeezes toward the inputs or outputs as a pass before we combine bcast+squeeze. That way they'd be adjacent and we could remove these and the consecutive casts as normal afterward.

I think we can do the same for the consecutive cast pass as well, i.e., casts should be able to move across meta operations.
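A toy sketch of moving casts across meta operations and then folding consecutive casts. It assumes the segment is modeled as a linear chain of `(kind, payload)` tuples where a cast's payload is its output dtype, and that a pair `src -> mid -> dst` can be folded to `src -> dst` only when `src -> mid` is exact. `hoist_casts`, `fold_casts`, and `EXACT_WIDENING` are hypothetical names, not nvFuser's.

```python
from typing import List, Tuple

Op = Tuple[str, object]  # (kind, payload); a cast's payload is its output dtype

META = {"broadcast", "squeeze", "permute", "view"}
# (src, dst) pairs where every src value is exactly representable in dst.
EXACT_WIDENING = {("bfloat", "float"), ("half", "float"), ("float", "double"),
                  ("bfloat", "double"), ("half", "double")}


def hoist_casts(chain: List[Op]) -> List[Op]:
    """Move every cast toward the inputs past meta ops (they commute)."""
    out: List[Op] = []
    for op in chain:
        out.append(op)
        i = len(out) - 1
        while op[0] == "cast" and i > 0 and out[i - 1][0] in META:
            out[i - 1], out[i] = out[i], out[i - 1]
            i -= 1
    return out


def fold_casts(chain: List[Op], input_dtype: str) -> List[Op]:
    """Fold cast(src->mid) followed by cast(mid->dst) when src->mid is exact."""
    out: List[Tuple[Op, str]] = []  # (op, dtype of the op's input)
    dtype = input_dtype
    for op in chain:
        if op[0] == "cast" and out and out[-1][0][0] == "cast":
            (_, mid), src = out[-1]
            if (src, mid) in EXACT_WIDENING:
                out.pop()
                if op[1] != src:
                    out.append((op, src))  # single cast src -> dst
                dtype = op[1]
                continue
        out.append((op, dtype))
        if op[0] == "cast":
            dtype = op[1]
    return [op for op, _ in out]


BC = (True, False, True, False, False)
# The pointwise segment, starting from T54 (bfloat).
segment: List[Op] = [
    ("broadcast", BC), ("cast", "float"), ("squeeze", BC), ("cast", "bfloat"),
    ("broadcast", (True, False, False, False)), ("permute", (0, 2, 1, 3)),
    ("view", (1, 4096, 1024)),
]
result = fold_casts(hoist_casts(segment), "bfloat")
```

On the pointwise segment, both casts bubble to the front, the bfloat-to-float-to-bfloat pair folds away, and the first broadcast and squeeze end up adjacent, which is the pattern the remove-broadcast-squeeze pass already handles. A lossy pair such as float-to-bfloat-to-float is left untouched.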
