
[ISSUE] Why is dtype hard-coded to float32 when loading the model? #94

Closed
3 tasks done
wenyangchou opened this issue Jul 7, 2024 · 2 comments

Comments

@wenyangchou
Contributor

Read README.md and dependencies.md

  • I have read the README.md and dependencies.md files

Search issues and discussions

  • I have confirmed that no previous issue or discussion covers this bug

Check the Forge version

  • I have confirmed that the problem occurs in the latest code or a stable release

Your issue

https://github.com/lenML/ChatTTS-Forge/blob/main/modules/models.py#L34

```python
def load_chat_tts_in_thread():
    global chat_tts
    if chat_tts:
        return

    logger.info("Loading ChatTTS models")
    chat_tts = ChatTTS.Chat()
    device = devices.get_device_for("chattts")
    dtype = devices.dtype
    chat_tts.load(
        compile=config.runtime_env_vars.compile,
        use_flash_attn=config.runtime_env_vars.flash_attn,
        source="custom",
        custom_path="./models/ChatTTS",
        device=device,
        dtype=dtype,
        # dtype_vocos=devices.dtype_vocos,
        # dtype_dvae=devices.dtype_dvae,
        # dtype_gpt=devices.dtype_gpt,
        # dtype_decoder=devices.dtype_decoder,
    )
```

Why are the per-component dtype arguments here commented out, so that everything is effectively hard-coded to float32? Is there some pitfall? Written this way, flash_attn cannot be used at all.
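For context, flash-attention kernels only run in half precision (float16/bfloat16) on CUDA, so a model kept in float32 cannot use them. Below is a minimal sketch of how a caller might pick a flash-attn-compatible dtype, assuming plain PyTorch; `pick_dtype` is a hypothetical helper for illustration, not part of this repo:

```python
import torch

def pick_dtype(device: torch.device, use_flash_attn: bool) -> torch.dtype:
    """Hypothetical helper: choose a dtype compatible with flash attention.

    Flash-attention kernels only support float16/bfloat16 on CUDA, so
    combining use_flash_attn with float32 weights leaves them unused.
    """
    if use_flash_attn and device.type == "cuda":
        # Prefer bfloat16 on GPUs that support it (Ampere and newer),
        # otherwise fall back to float16.
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    return torch.float32

# Example: dtype = pick_dtype(torch.device("cuda"), use_flash_attn=True)
```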

@wenyangchou
Contributor Author

When the model is driven through transformers there is a bug where the dtype is not loaded correctly. Is that the reason?
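For reference, a dtype request normally reaches Hugging Face transformers through the `torch_dtype` argument; the sketch below is an assumption about where such a bug would surface, not ChatTTS-Forge's actual loader, and the model id is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical illustration: if a loader drops torch_dtype, or calls
# .float() afterwards, the weights silently end up in float32 and flash
# attention is disabled.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model id for illustration
    torch_dtype=torch.float16,
)
print(model.dtype)  # torch.float16 when the request is honored
```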

@zhzLuke96
Member

Please provide the error message; otherwise the problem cannot be located.
