Merge pull request #7 from stereoplegic/issue-3
fix: flashfftconv imports
Zymrael authored Feb 4, 2024
2 parents 53c3b23 + b890de0 commit 9bd94e5
Showing 2 changed files with 23 additions and 14 deletions.
6 changes: 5 additions & 1 deletion requirements.txt
@@ -1,3 +1,7 @@
 tokenizers
 transformers
-flash_attn
+flash_attn
+
+### flashfftconv: uncomment both for use_flash_depthwise or use_flashfft
+# git+https://github.com/HazyResearch/flash-fft-conv.git#subdirectory=csrc/flashfftconv
+# git+https://github.com/HazyResearch/flash-fft-conv.git
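
The two commented requirements install the flashfftconv CUDA extension and its Python package. As a quick sanity check (a sketch, not part of this commit), the optional dependency can be probed before enabling use_flash_depthwise or use_flashfft:

# Sketch: detect whether the optional flashfftconv package is importable.
import importlib.util

HAS_FLASHFFTCONV = importlib.util.find_spec("flashfftconv") is not None
print("flashfftconv available:", HAS_FLASHFFTCONV)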
31 changes: 18 additions & 13 deletions src/model.py
@@ -106,15 +106,19 @@ def __init__(self, config, layer_idx) -> None:
         self.data_dtype = None

         if self.use_flash_depthwise:
-            self.fir_fn = FlashDepthwiseConv1d(
-                channels=3 * self.hidden_size,
-                kernel_size=self.short_filter_length,
-                padding=self.short_filter_length - 1,
-                weights=self.short_filter_weight,
-                bias=self.short_filter_bias,
-                device=None,
-                dtype=self.config.get("depthwise_dtype", torch.bfloat16),
-            )
+            try:
+                from flashfftconv import FlashDepthwiseConv1d
+
+                self.fir_fn = FlashDepthwiseConv1d(
+                    channels=3 * self.hidden_size,
+                    kernel_size=self.short_filter_length,
+                    padding=self.short_filter_length - 1,
+                    weights=self.short_filter_weight,
+                    bias=self.short_filter_bias,
+                    device=None,
+                    dtype=self.config.get("depthwise_dtype", torch.bfloat16),
+                )
+            except ImportError: "flashfftconv not installed"
         else:
             self.fir_fn = F.conv1d

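This hunk moves the FlashDepthwiseConv1d import inside a try block so the model constructs even without flashfftconv installed. A minimal standalone sketch of the same pattern (names and the constructor signature follow the diff; the explicit fallback to F.conv1d when the import fails is an assumption, not part of this commit):

# Sketch of the guarded optional import, with an explicit fallback so the
# result is always defined even when flashfftconv is missing (assumption).
import torch
import torch.nn.functional as F

try:
    from flashfftconv import FlashDepthwiseConv1d  # optional fused kernel
except ImportError:
    FlashDepthwiseConv1d = None  # flashfftconv not installed

def make_fir_fn(hidden_size, filter_length, weight, bias):
    # Use the fused depthwise conv when available; otherwise fall back to
    # the standard PyTorch depthwise convolution entry point.
    if FlashDepthwiseConv1d is not None:
        return FlashDepthwiseConv1d(
            channels=3 * hidden_size,
            kernel_size=filter_length,
            padding=filter_length - 1,
            weights=weight,
            bias=bias,
            device=None,
            dtype=torch.bfloat16,
        )
    return F.conv1d
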
@@ -324,10 +328,11 @@ def __init__(self, config):
         self.norm = RMSNorm(config) if config.get("final_norm", True) else None
         self.unembed = self.emb if config.tie_embeddings else VocabParallelEmbedding(config)

-        if config.get("use_flashfft", "False"):
-            from flashfftconv import FlashFFTConv
-
-            self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
+        if config.get("use_flashfft", "True"):
+            try:
+                from flashfftconv import FlashFFTConv
+                self.flash_fft = FlashFFTConv(config.seqlen, dtype=torch.bfloat16)
+            except ImportError: "flashfftconv not installed"
         else:
             self.flash_fft = None

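The same guard is applied to FlashFFTConv; the hunk also flips the use_flashfft default and drops the FFT size from 2 * config.seqlen to config.seqlen. A standalone sketch of the guarded construction (the helper name and its parameters are illustrative, not from this commit):

# Sketch: construct FlashFFTConv only when the optional package imports.
import torch

try:
    from flashfftconv import FlashFFTConv
except ImportError:
    FlashFFTConv = None  # flashfftconv not installed

def make_flash_fft(seqlen, use_flashfft=True):
    # Returns None when disabled or when flashfftconv is unavailable, so
    # callers can branch on `flash_fft is None` as the model code does.
    if use_flashfft and FlashFFTConv is not None:
        return FlashFFTConv(seqlen, dtype=torch.bfloat16)
    return None
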
