
Overhaul backend function execution for improved performance and flexibility #270

Open

jhalakpatel wants to merge 1 commit into main from jhalakp-use-non-dps-exec-func
Conversation

jhalakpatel (Collaborator) commented on Oct 15, 2024:

This PR replaces the DPS-style calling convention with a non-DPS approach, eliminating the requirement for call sites to preallocate output buffers. This lets us skip computing output shapes and allocating output buffers in advance, laying the groundwork for data-dependent shapes, where network outputs can have dynamic dimensions.
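For illustration, here is a minimal sketch of the difference between the two conventions; the `session` and `client` names are hypothetical stand-ins, not the actual MLIR-TRT runtime API:

    # DPS style: the call site must compute output shapes up front and
    # preallocate device buffers that the executable writes into.
    out_buffers = [client.create_memref(shape, device="gpu") for shape in output_shapes]
    session.execute("main", in_args=input_buffers, out_args=out_buffers)

    # Non-DPS style: the executable allocates and returns its own outputs,
    # so the call site needs neither output shapes nor preallocated buffers.
    # This is what makes dynamic (data-dependent) output shapes possible.
    outputs = session.execute("main", in_args=input_buffers)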

The underlying compiler stack has been enhanced to avoid allocating oversized buffers and eliminate an extra device-to-device copy operation from TensorRT-allocated memory to MLIR-TRT managed memory.

Additionally, we've improved the copy operation to support copying to host memory. This removes the need to track output device allocations for device-to-host copies: previously, copy outputs were restricted to device allocations; now they can be allocated on either the device or the host.
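Sketched below with the same illustrative (not actual) API names: a copy can now target host memory directly, instead of producing a device allocation that must be tracked and transferred to the host afterwards.

    # Before: copy outputs had to be device allocations, so the device buffer
    # had to be tracked and copied device-to-host in a separate step.
    # After: the copy destination can be host memory directly.
    device_result = session.execute("main", in_args=input_buffers)[0]
    host_result = client.copy(device_result, device="cpu")  # hypothetical call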

Tests have been updated to align with the new calling convention, ensuring compatibility and correctness.

@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch from ba2dd98 to b745b6a on October 21, 2024
@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch from b745b6a to 780e18b on November 4, 2024
@jhalakpatel changed the title from "Update MLIR-TRT execution function to use non-DPS style calling convention" to "Overhaul backend function execution for improved performance and flexibility" on November 4, 2024
@jhalakpatel marked this pull request as ready for review on November 4, 2024
@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch 2 times, most recently from ceeeaf6 to 19e841d on November 4, 2024
@@ -188,7 +188,7 @@ def eval(self) -> runtime.MemRefValue:
     executor = Executor(executable)
     # Upon computing the value of this tensor, we switch it to have a `Storage`
     # parameter so that it does not need to be computed again.
-    data = executor.execute([out.device for out in flat_ir.outputs])
+    data = executor.execute()
     executor.stream.synchronize()
     assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor"
     data = data[0]
Collaborator commented:

Could we also remove the hack a few lines below?

    self.trace_tensor.device = flat_ir.outputs[0].device

jhalakpatel (Collaborator, Author) replied on Nov 5, 2024:

There are still a few issues:

  1. MLIR-TensorRT still allocates inputs only on the device.
  2. This PR fixes only the copy operation.

Because inputs are always allocated on the device, there can be a device mismatch between the inputs (always on device) and the outputs (which can now be on either the host or the device). The default device inference then fails:

    def infer_devices(self):
        """
        Infers devices for the operation and updates output tensor devices accordingly.
        """
        assert (
>           self.inputs and len(self.outputs) == 1 and all(inp.device == self.inputs[0].device for inp in self.inputs)
        ), "Default implementation cannot handle cases where there are no inputs, multiple outputs, or multiple inputs with different devices. Please override."
E       AssertionError: Default implementation cannot handle cases where there are no inputs, multiple outputs, or multiple inputs with different devices. Please override.
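For illustration only, an op that intentionally produces host outputs (such as a copy-to-host) could sidestep this assertion by overriding the default inference. A hypothetical sketch, with the op name and `device` helper assumed rather than taken from the codebase:

    class CopyToHost(BaseTraceOp):  # hypothetical op name
        def infer_devices(self):
            # The copy target, not the input device, determines where the
            # output lives, so the same-device assertion does not apply here.
            self.outputs[0].device = device("cpu")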

@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch 2 times, most recently from 5247906 to 1914775 on November 6, 2024
@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch 2 times, most recently from a44e199 to 4986057 on November 9, 2024
Overhaul backend function execution for improved performance and flexibility

Other changes:

- Fix type constraints tests
- Address review comments
@jhalakpatel force-pushed the jhalakp-use-non-dps-exec-func branch from 4986057 to 383c182 on November 9, 2024
@@ -186,10 +186,11 @@ def eval(self) -> runtime.MemRefValue:

     compiler = Compiler(trt_builder_opt_level=0)
     executable = compiler.compile(mlir, flat_ir=flat_ir)
+    # Ensure that session and client are available as long as tensor lives.
jhalakpatel (Collaborator, Author) commented:
Remove this comment.
