Skip to content

Commit

Permalink
Memset prefill input tensors in text_generator_main binary.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697068488
  • Loading branch information
ai-edge-bot authored and copybara-github committed Nov 16, 2024
1 parent 9de82a7 commit cad366d
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ai_edge_torch/generative/examples/cpp/text_generator_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <ios>
#include <iterator>
Expand Down Expand Up @@ -284,6 +285,8 @@ int main(int argc, char* argv[]) {
// NOTE: We skip the last token and use that during decode.
int prefill_seq_size =
std::min(static_cast<int>(prompt_tokens.size()), max_seq_size);
std::memset(prefill_input->data.i32, 0, prefill_input->bytes);
std::memset(prefill_input_pos->data.i32, 0, prefill_input_pos->bytes);
for (int i = 0; i < prefill_seq_size - 1; ++i) {
prefill_input->data.i32[i] = prompt_tokens[i];
prefill_input_pos->data.i32[i] = i;
Expand Down

0 comments on commit cad366d

Please sign in to comment.