File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -1317,6 +1317,7 @@ def forward(
13171317 unit_locations : Optional [Dict ] = None ,
13181318 source_representations : Optional [Dict ] = None ,
13191319 subspaces : Optional [List ] = None ,
1320+ labels : Optional [torch .LongTensor ] = None ,
13201321 output_original_output : Optional [bool ] = False ,
13211322 return_dict : Optional [bool ] = None ,
13221323 ):
@@ -1438,7 +1439,10 @@ def forward(
14381439 )
14391440
14401441 # run intervened forward
1441- counterfactual_outputs = self .model (** base )
1442+ if labels is not None :
1443+ counterfactual_outputs = self .model (** base , labels = labels )
1444+ else :
1445+ counterfactual_outputs = self .model (** base )
14421446 set_handlers_to_remove .remove ()
14431447
14441448 self ._output_validation ()
Original file line number Diff line number Diff line change 1010
1111setup (
1212 name = "pyvene" ,
13- version = "0.0.8dev " ,
13+ version = "0.0.8 " ,
1414 description = "Use Activation Intervention to Interpret Causal Mechanism of Model" ,
1515 long_description = long_description ,
1616 long_description_content_type = 'text/markdown' ,
You can’t perform that action at this time.
0 commit comments