Async Stream Generator? #604

Open
KingBipo opened this issue Aug 28, 2024 · 3 comments

@KingBipo

Using Stream Generator Like Async Dynamic Generator for Concurrent Text Generation

Description:

Hello,

I have been working with the stream generator and I am wondering if there is a way to use it in a manner similar to the Async Dynamic generator to generate two texts simultaneously using the same model.

I've attempted the following approach with the example code; however, the generation keeps waiting for the first text to complete before starting the next one:

    async def start_myServer(self):
        ip = "127.0.0.1"
        port = 8765
        async with websockets.serve(self.handle_client, ip, port):
            print(f'Start Server -> {ip}:{port}')
            await asyncio.Future()  # Run forever

    async def handle_client(self, websocket, path):
        message = await websocket.recv()
        client_id = str(id(websocket))  # Unique identifier for each client
        self.user_prompts[client_id] = []
        self.responses_ids[client_id] = []
        await self.process_message(websocket, client_id, message)
          
    async def process_message(self, ws, client_id, message):
        self.user_prompts[client_id].append(message)
                
        async with self.model_lock:
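            # NOTE: this lock serializes every client, so only one generation can run at a time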
            print(f'Start processing for: {client_id}')
            torch.cuda.empty_cache()
            max_len = self.model.config.max_seq_len - self.min_space_in_ctx
            a_ctx = self.get_tokenized_context(max_len, client_id)
            
            # Run the blocking code in a separate thread using asyncio.to_thread
            response_text = await asyncio.to_thread(self._generate_response, a_ctx, client_id)

            # Now handle the WebSocket communication asynchronously
            await ws.send(response_text)
            self.busy = False
            
    def _generate_response(self, a_ctx, client_id):
        self.generator.begin_stream_ex(a_ctx, self.sampler_Settings)
        
        response_tokens = 0
        response_text = ""
        self.responses_ids[client_id].append(torch.empty((1, 0), dtype=torch.long))
        
        while True:
            # Get response stream
            res = self.generator.stream_ex()
            chunk = res["chunk"]
            eos = res["eos"]
            tokens = res["chunk_token_ids"]
                
            if len(response_text) == 0: 
                chunk = chunk.lstrip()
            response_text += chunk
            self.responses_ids[client_id][-1] = torch.cat([self.responses_ids[client_id][-1], tokens], dim=-1)

            if self.generator.full():
                active_context = self.get_tokenized_context(self.model.config.max_seq_len - self.min_space_in_ctx, client_id)
                self.generator.begin_stream(active_context, self.sampler_Settings)

            response_tokens += 1
            if response_tokens == self.max_response_tokens:
                if self.tokenizer.eos_token_id in self.generator.stop_tokens:
                    self.responses_ids[client_id][-1] = torch.cat([self.responses_ids[client_id][-1], self.tokenizer.single_token(self.tokenizer.eos_token_id)], dim=-1)

                print(f"\n!! Response exceeded {self.max_response_tokens} tokens and was cut short.")
                break

            if eos:
                break

        print(f'End processing for: {client_id}')
        return response_text
        
    def get_tokenized_context(self, max_len, client_id):
        print(f'Start CTX for: {client_id}')
        while True:
            ctx = torch.empty((1, 0), dtype=torch.long)
            
            for i in range(len(self.user_prompts[client_id])):  # TODO: an embedding is needed for each input, but only ever the newest one
                up_text = self.format_prompt(self.user_prompts[client_id][i], ctx.shape[-1] == 0)
                    
                ids = self.tokenizer.encode(
                    up_text, 
                    encode_special_tokens=True, 
                    add_bos=True
                )
                ctx = torch.cat((ctx, ids), dim=-1)
                        
                if i < len(self.responses_ids[client_id]):
                    ctx = torch.cat([ctx, self.responses_ids[client_id][i]], dim=-1)
                    
            if ctx.shape[-1] < max_len:
                print(f'End CTX for: {client_id}')
                return ctx

            # If the context is too long, remove the first Q/A pair and try again
            self.user_prompts[client_id] = self.user_prompts[client_id][1:]
            self.responses_ids[client_id] = self.responses_ids[client_id][1:]
            
    def format_prompt(self, user_prompt, first):
        # Llama 3 chat template, built by concatenation so that indentation from a
        # triple-quoted string does not leak extra whitespace into the prompt
        if first:
            return (
                "<|start_header_id|>system<|end_header_id|>\n\n"
                + self.system_prompt
                + "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
                + user_prompt
                + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            )
        else:
            return (
                "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
                + user_prompt
                + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            )

My Question:

  • Is it possible to use the stream generator like the Async Dynamic generator to generate two texts at the same time on the same model?
  • If yes, could you provide any help or guidance on how to implement this?

Any help or advice would be greatly appreciated!

Thank you for your time and assistance!

@KingBipo
Author

I am using the stream generator mainly because of the input_embedding implementation when starting a stream. If there is a way to recreate something like that with an asynchronous dynamic generator, that would be helpful too.

This is what I have for the embedding in the StreamingGenerator, and it works fine:

self.embedding_layer = next(m for m in model.modules if isinstance(m, ExLlamaV2Embedding))
self.emb = self.prepare_embedding_tensor(e_str)

def prepare_embedding_tensor(self, s: str) -> torch.Tensor:
        s_tokens = self.tokenizer.encode(
            s,
            add_bos=True, 
            encode_special_tokens=True
        )
        print("s_tokens Shape:", s_tokens.shape)
        embedding_tensor = self.embedding_layer.forward(hidden_states=s_tokens)
        print("Embedding Tensor Shape:", embedding_tensor.shape)
        return embedding_tensor

self.generator.begin_stream_ex(a_ctx, sampler_settings, input_embeddings=self.emb)

If something similar is possible with the async dynamic generator, please help.

@turboderp
Owner

The regular streaming generator doesn't support batching, let alone continuous batching like you'd want here. It would be much simpler to add indexed embeddings to the dynamic generator, which I suppose I do want to get to. The main challenge is that the indices would need to persist and be unique over the lifetime of the model to not cause errors with caching.

E.g., if a context starts with [EMBEDDING_INDEX + i for i in range(1024)], that will create four pages of keys/values that could persist indefinitely and end up being reused the next time you use those same indices, even if you want to use different underlying embeddings.

So I need to cook up some kind of system for creating indices and managing the lifetime of the corresponding embeddings.
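
To make the concern concrete, a rough sketch of the arithmetic (the page size and the EMBEDDING_INDEX value here are assumptions for illustration, not a committed design):

PAGE_SIZE = 256                          # assumed page size of the paged cache
EMBEDDING_INDEX = 1_000_000_000          # assumed base value for reserved embedding token IDs

# A context that begins with 1024 reserved embedding positions:
prefix_ids = [EMBEDDING_INDEX + i for i in range(1024)]
num_pages = len(prefix_ids) // PAGE_SIZE  # -> 4 cache pages keyed by these token IDs

# Those pages can later be matched and reused by any context that starts with the
# same IDs, even if the embeddings behind the IDs have changed in the meantime,
# which is why the indices and their embeddings need a managed lifetime.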

@KingBipo
Author

Thank you for your response! I'm using the dynamic generator for now, so I can implement the embeddings with it when the time comes. I wish you success with this implementation and once again, thank you!
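
For reference, this is roughly how I am running two prompts at the same time with the dynamic generator now. It is only a sketch based on the async example in the repo, and the exact class names and arguments (ExLlamaV2DynamicGeneratorAsync, ExLlamaV2DynamicJobAsync, generator.close()) are my assumptions:

import asyncio

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicGeneratorAsync, ExLlamaV2DynamicJobAsync

async def generate(generator, tokenizer, prompt):
    # Each job is scheduled by the dynamic generator, so several jobs can run
    # concurrently (continuous batching) instead of waiting on one another.
    job = ExLlamaV2DynamicJobAsync(
        generator,
        input_ids=tokenizer.encode(prompt, add_bos=True),
        max_new_tokens=256,
    )
    text = ""
    async for result in job:
        text += result.get("text", "")
    return text

async def main():
    config = ExLlamaV2Config("/path/to/model")   # placeholder model directory
    model = ExLlamaV2(config)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache)
    tokenizer = ExLlamaV2Tokenizer(config)

    generator = ExLlamaV2DynamicGeneratorAsync(model=model, cache=cache, tokenizer=tokenizer)

    # Two texts generated at the same time on the same model
    text_a, text_b = await asyncio.gather(
        generate(generator, tokenizer, "Prompt for client A"),
        generate(generator, tokenizer, "Prompt for client B"),
    )
    print(text_a)
    print(text_b)

    await generator.close()

asyncio.run(main())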
