Skip to content

Commit

Permalink
more fixes for shorting
Browse files Browse the repository at this point in the history
And more tests
  • Loading branch information
brettelliot committed Nov 18, 2024
1 parent a68bf3c commit d1a3e69
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 29 deletions.
61 changes: 33 additions & 28 deletions lumibot/components/drift_rebalancer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,41 +211,46 @@ def _calculate_drift(self) -> pd.DataFrame:
total_value = self.df["current_value"].sum()
self.df["current_weight"] = self.df["current_value"] / total_value
self.df["target_value"] = self.df["target_weight"] * total_value
self.df["drift"] = self.df.apply(self._calculate_drift_row, axis=1)
return self.df.copy()

def calculate_drift_row(row: pd.Series) -> Decimal:
if row["is_quote_asset"]:
# We can never buy or sell the quote asset
return Decimal(0)
def _calculate_drift_row(self, row: pd.Series) -> Decimal:

elif row["current_weight"] == Decimal(0) and row["target_weight"] == Decimal(0):
return Decimal(0)
if row["is_quote_asset"]:
# We can never buy or sell the quote asset
return Decimal(0)

# Check if we should sell everything
elif row["current_quantity"] > Decimal(0) and row["target_weight"] == Decimal(0):
return Decimal(-1)
elif row["current_weight"] == Decimal(0) and row["target_weight"] == Decimal(0):
# Should nothing change?
return Decimal(0)

# Check if we need to buy for the first time
elif row["current_quantity"] == Decimal(0) and row["target_weight"] > Decimal(0):
return Decimal(1)
elif row["current_quantity"] > Decimal(0) and row["target_weight"] == Decimal(0):
# Should we sell everything
return Decimal(-1)

# Check if we need to short everything
elif row["current_quantity"] == Decimal(0) and row["target_weight"] == Decimal(-1):
return Decimal(-1)
elif row["current_quantity"] == Decimal(0) and row["target_weight"] > Decimal(0):
# We don't have any of this asset but we wanna buy some.
return Decimal(1)

# Otherwise we just need to adjust our holding. Calculate the drift.
else:
if self.drift_type == DriftType.ABSOLUTE:
return row["target_weight"] - row["current_weight"]
elif self.drift_type == DriftType.RELATIVE:
# Relative drift is calculated by: difference / target_weight.
# Example: target_weight=0.20 and current_weight=0.23
# The drift is (0.20 - 0.23) / 0.20 = -0.15
return (row["target_weight"] - row["current_weight"]) / row["target_weight"]
else:
raise ValueError(f"Invalid drift_type: {self.drift_type}")
elif row["current_quantity"] == Decimal(0) and row["target_weight"] == Decimal(-1):
# Should we short everything we have
return Decimal(-1)

self.df["drift"] = self.df.apply(calculate_drift_row, axis=1)
return self.df.copy()
elif row["current_quantity"] == Decimal(0) and row["target_weight"] < Decimal(0):
# We don't have any of this asset but we wanna short some.
return Decimal(-1)

# Otherwise we just need to adjust our holding. Calculate the drift.
else:
if self.drift_type == DriftType.ABSOLUTE:
return row["target_weight"] - row["current_weight"]
elif self.drift_type == DriftType.RELATIVE:
# Relative drift is calculated by: difference / target_weight.
# Example: target_weight=0.20 and current_weight=0.23
# The drift is (0.20 - 0.23) / 0.20 = -0.15
return (row["target_weight"] - row["current_weight"]) / row["target_weight"]
else:
raise ValueError(f"Invalid drift_type: {self.drift_type}")


class DriftOrderLogic:
Expand Down
64 changes: 63 additions & 1 deletion tests/test_drift_rebalancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,68 @@ def mock_add_positions(self):
check_names=False
)

def test_drift_is_negative_one_when_we_have_none_of_an_asset_and_target_weights_says_we_should_short_some(self, mocker):
strategy = MockStrategyWithDriftCalculationLogic(
broker=self.backtesting_broker,
drift_threshold=Decimal("0.05"),
drift_type=DriftType.ABSOLUTE,
shorting=True
)
target_weights = {
"AAPL": Decimal("0.25"),
"GOOGL": Decimal("0.25"),
"MSFT": Decimal("0.25"),
"AMZN": Decimal("-0.25")
}

def mock_add_positions(self):
self._add_position(
symbol="AAPL",
is_quote_asset=False,
current_quantity=Decimal("10"),
current_value=Decimal("1500")
)
self._add_position(
symbol="GOOGL",
is_quote_asset=False,
current_quantity=Decimal("5"),
current_value=Decimal("1000")
)
self._add_position(
symbol="MSFT",
is_quote_asset=False,
current_quantity=Decimal("8"),
current_value=Decimal("800")
)

mocker.patch.object(DriftCalculationLogic, "_add_positions", mock_add_positions)
df = strategy.drift_rebalancer_logic.calculate(target_weights=target_weights)

pd.testing.assert_series_equal(
df["current_weight"],
pd.Series([
Decimal('0.4545454545454545454545454545'),
Decimal('0.3030303030303030303030303030'),
Decimal('0.2424242424242424242424242424'),
Decimal('0')
]),
check_names=False
)

assert df["target_value"].tolist() == [Decimal("825"), Decimal("825"), Decimal("825"), Decimal("-825")]

pd.testing.assert_series_equal(
df["drift"],
pd.Series([
Decimal('-0.2045454545454545454545454545'),
Decimal('-0.0530303030303030303030303030'),
Decimal('0.0075757575757575757575757576'),
Decimal('-1')
]),
check_names=False
)


def test_drift_is_zero_when_current_weight_and_target_weight_are_zero(self, mocker):
strategy = MockStrategyWithDriftCalculationLogic(
broker=self.backtesting_broker,
Expand Down Expand Up @@ -586,7 +648,7 @@ def mock_add_positions(self):

assert df["current_weight"].tolist() == [Decimal("0.0"), Decimal("1.0")]
assert df["target_value"].tolist() == [Decimal("-500"), Decimal("500")]
assert df["drift"].tolist() == [Decimal("-0.50"), Decimal("0")]
assert df["drift"].tolist() == [Decimal("-1.0"), Decimal("0")]

def test_calculate_absolute_drift_when_we_want_a_100_percent_short_position(self, mocker):
strategy = MockStrategyWithDriftCalculationLogic(
Expand Down

0 comments on commit d1a3e69

Please sign in to comment.