diff --git a/week10_efficiency/benchmark.ipynb b/week10_efficiency/benchmark.ipynb index f861861c..a614d505 100644 --- a/week10_efficiency/benchmark.ipynb +++ b/week10_efficiency/benchmark.ipynb @@ -50,7 +50,7 @@ "\n", "import torch\n", "from torch import Tensor\n", - "from torch import nn\n", + "import torch.nn as nn\n", "\n", "from transformers.models.llama.modeling_llama import LlamaForCausalLM\n", "from transformers.models.llama.configuration_llama import LlamaConfig\n",