Skip to content

Commit

Permalink
Store more Assistant state in ChatContext to uncouple it from compo…
Browse files Browse the repository at this point in the history
…nent tree (#788)

* Merge ChatInputContext with ChatContext

* Move state from useAssistant to ChatProvider

* Move isLoading state from useAssistantApi to ChatProvider

* Correct nextMessageIdRef value when loading

* Rename nextMessageIdRef

* WIP - Fix scroll issue

* Don't memoize AuthenticatedView

* Fix tests

* Bring back memo

* Add selectedSite.id to key prop

* Smaller diff

* Add comment about React strict mode
  • Loading branch information
fredrikekelund authored Jan 14, 2025
1 parent 4d3a4b2 commit 08a8b3e
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 77 deletions.
48 changes: 25 additions & 23 deletions src/components/content-tab-assistant.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ interface AuthenticatedViewProps {
>[ 'markMessageAsFeedbackReceived' ];
}

export const AuthenticatedView = memo(
const AuthenticatedView = memo(
( {
messages,
isAssistantThinking,
Expand All @@ -125,7 +125,7 @@ export const AuthenticatedView = memo(
}: AuthenticatedViewProps ) => {
const endOfMessagesRef = useRef< HTMLDivElement >( null );
const lastMessageRef = useRef< HTMLDivElement >( null );
const [ showThinking, setShowThinking ] = useState( false );
const [ showThinking, setShowThinking ] = useState( isAssistantThinking );
const lastMessage = useMemo(
() =>
showThinking
Expand All @@ -135,23 +135,27 @@ export const AuthenticatedView = memo(
);
const messagesToRender =
messages[ messages.length - 1 ]?.role === 'assistant' ? messages.slice( 0, -1 ) : messages;
const showLastMessage = showThinking || messages[ messages.length - 1 ]?.role === 'assistant';
const previousMessagesLength = useRef( messages?.length );
const previousSiteId = useRef( siteId );

const showLastMessage = lastMessage?.role === 'assistant';
const previousMessagesLength = useRef( messages.length );
const isInitialRenderRef = useRef( true );

// This effect may run twice when the component is mounted, which makes the viewport scroll
// to the wrong position. This happens because the app runs in React strict mode, meaning
// it only affects the development environment. For more details, see
// https://github.com/Automattic/studio/pull/788#issuecomment-2586644007
useEffect( () => {
if ( ! messages?.length ) {
previousSiteId.current = siteId;
if ( ! messages.length ) {
return;
}

let timer: NodeJS.Timeout;
// Scroll to the end of the messages when the tab is opened or site ID changes
if ( previousMessagesLength.current === 0 || previousSiteId.current !== siteId ) {
if ( isInitialRenderRef.current ) {
endOfMessagesRef.current?.scrollIntoView( { behavior: 'instant' } );
isInitialRenderRef.current = false;
}
// Scroll when a new message is added
else if ( messages?.length > previousMessagesLength.current || showLastMessage ) {
else if ( messages.length > previousMessagesLength.current || showLastMessage ) {
// Scroll to the beginning of last message received from the assistant
if ( showLastMessage ) {
timer = setTimeout( () => {
Expand All @@ -166,11 +170,10 @@ export const AuthenticatedView = memo(
}
}

previousMessagesLength.current = messages?.length;
previousSiteId.current = siteId;
previousMessagesLength.current = messages.length;

return () => clearTimeout( timer );
}, [ messages?.length, showLastMessage, siteId ] );
}, [ messages.length, showLastMessage ] );

useEffect( () => {
let timer: NodeJS.Timeout;
Expand Down Expand Up @@ -486,16 +489,15 @@ export function ContentTabAssistant( { selectedSite }: ContentTabAssistantProps
siteId={ selectedSite.id }
disabled={ disabled }
/>
{
<AuthenticatedView
messages={ messages }
isAssistantThinking={ isAssistantThinking }
updateMessage={ updateMessage }
markMessageAsFeedbackReceived={ markMessageAsFeedbackReceived }
siteId={ selectedSite.id }
submitPrompt={ submitPrompt }
/>
}

<AuthenticatedView
messages={ messages }
isAssistantThinking={ isAssistantThinking }
updateMessage={ updateMessage }
markMessageAsFeedbackReceived={ markMessageAsFeedbackReceived }
siteId={ selectedSite.id }
submitPrompt={ submitPrompt }
/>
</>
) : (
! isOffline && <UnauthenticatedView onAuthenticate={ authenticate } />
Expand Down
8 changes: 7 additions & 1 deletion src/components/site-content-tabs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ export function SiteContentTabs() {
{ name === 'share' && <ContentTabSnapshots selectedSite={ selectedSite } /> }
{ name === 'sync' && <ContentTabSync selectedSite={ selectedSite } /> }
{ name === 'settings' && <ContentTabSettings selectedSite={ selectedSite } /> }
{ name === 'assistant' && <ContentTabAssistant selectedSite={ selectedSite } /> }
{ name === 'assistant' && (
<ContentTabAssistant
// TODO: Remove this key once https://github.com/Automattic/dotcom-forge/issues/10219 is fixed
key={ selectedTab + selectedSite.id }
selectedSite={ selectedSite }
/>
) }
{ name === 'import-export' && (
<ContentTabImportExport selectedSite={ selectedSite } />
) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
import { renderHook, act } from '@testing-library/react';
import { ReactNode } from 'react';
import { getIpcApi } from '../../lib/get-ipc-api';
import { useAssistant } from '../use-assistant';
import { ChatProvider } from '../use-chat-context';
import { useGetWpVersion } from '../use-get-wp-version';
import { ThemeDetailsProvider } from '../use-theme-details';

jest.mock( '../../lib/get-ipc-api' );
jest.mock( '../use-get-wp-version' );

function ContextWrapper( { children }: { children: ReactNode } ) {
return (
<ThemeDetailsProvider>
<ChatProvider>{ children }</ChatProvider>
</ThemeDetailsProvider>
);
}

interface Message {
content: string;
Expand All @@ -15,6 +31,11 @@ describe( 'useAssistant', () => {
localStorage.clear();
jest.useFakeTimers();
jest.setSystemTime( MOCKED_TIME );
( getIpcApi as jest.Mock ).mockReturnValue( {
showMessageBox: jest.fn().mockResolvedValue( { response: 0, checkboxChecked: false } ),
executeWPCLiInline: jest.fn().mockResolvedValue( { stdout: '', stderr: 'Error' } ),
} );
( useGetWpVersion as jest.Mock ).mockReturnValue( '6.4.3' );
} );

afterEach( () => {
Expand All @@ -31,13 +52,17 @@ describe( 'useAssistant', () => {
JSON.stringify( { [ selectedSiteId ]: initialMessages } )
);

const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

expect( result.current.messages ).toEqual( initialMessages );
} );

it( 'should add a message correctly', () => {
const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

act( () => {
result.current.addMessage( 'Hello', 'user' );
Expand Down Expand Up @@ -72,7 +97,9 @@ describe( 'useAssistant', () => {
} );

it( 'should clear messages correctly', () => {
const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

act( () => {
result.current.addMessage( 'Hello', 'user' );
Expand Down
21 changes: 21 additions & 0 deletions src/hooks/tests/use-chat-context.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down Expand Up @@ -214,6 +221,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down Expand Up @@ -372,6 +386,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down
16 changes: 7 additions & 9 deletions src/hooks/use-assistant-api.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { useCallback, useState } from 'react';
import { useCallback } from 'react';
import { Message } from './use-assistant';
import { useAuth } from './use-auth';
import { ChatContextType } from './use-chat-context';
import { ChatContextType, useChatContext } from './use-chat-context';
import { usePromptUsage } from './use-prompt-usage';

const contextMapper = ( context?: ChatContextType ) => {
Expand All @@ -25,17 +25,15 @@ const contextMapper = ( context?: ChatContextType ) => {

export function useAssistantApi( selectedSiteId: string ) {
const { client } = useAuth();
const [ isLoading, setIsLoading ] = useState< Record< string, boolean > >( {
[ selectedSiteId ]: false,
} );
const { setIsLoadingDict, isLoadingDict } = useChatContext();
const { updatePromptUsage } = usePromptUsage();

const fetchAssistant = useCallback(
async ( chatId: string | undefined, messages: Message[], context?: ChatContextType ) => {
if ( ! client ) {
throw new Error( 'WPcom client not initialized' );
}
setIsLoading( ( prev ) => ( { ...prev, [ selectedSiteId ]: true } ) );
setIsLoadingDict( ( prev ) => ( { ...prev, [ selectedSiteId ]: true } ) );
const body = {
messages,
chat_id: chatId,
Expand Down Expand Up @@ -69,7 +67,7 @@ export function useAssistantApi( selectedSiteId: string ) {
response = data;
headers = response_headers;
} finally {
setIsLoading( ( prev ) => ( { ...prev, [ selectedSiteId ]: false } ) );
setIsLoadingDict( ( prev ) => ( { ...prev, [ selectedSiteId ]: false } ) );
}

const message = response?.choices?.[ 0 ]?.message?.content;
Expand All @@ -82,8 +80,8 @@ export function useAssistantApi( selectedSiteId: string ) {

return { message, messageApiId, chatId: response?.id };
},
[ client, selectedSiteId, updatePromptUsage ]
[ client, selectedSiteId, setIsLoadingDict, updatePromptUsage ]
);

return { fetchAssistant, isLoading: isLoading[ selectedSiteId ] };
return { fetchAssistant, isLoading: isLoadingDict[ selectedSiteId ] };
}
54 changes: 17 additions & 37 deletions src/hooks/use-assistant.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import { useCallback } from 'react';
import { CHAT_MESSAGES_STORE_KEY } from '../constants';
import { CHAT_ID_STORE_KEY, useChatContext } from './use-chat-context';
import { useSendFeedback } from './use-send-feedback';

export type Message = {
Expand All @@ -19,42 +20,21 @@ export type Message = {
feedbackReceived?: boolean;
};

export type MessageDict = { [ key: string ]: Message[] };
export type ChatIdDict = { [ key: string ]: string | undefined };

const chatIdStoreKey = 'ai_chat_ids';
const EMPTY_MESSAGES: Message[] = [];

export const useAssistant = ( instanceId: string ) => {
const [ messagesDict, setMessagesDict ] = useState< MessageDict >( {} );
const [ chatIdDict, setChatIdDict ] = useState< ChatIdDict >( {
[ instanceId ]: undefined,
} );
const { messagesDict, setMessagesDict, chatIdDict, setChatIdDict, lastMessageIdDictRef } =
useChatContext();
const chatId = chatIdDict[ instanceId ];
const nextMessageIdRef = useRef< { [ key: string ]: number } >( {
[ instanceId ]: -1, // The first message should have id 0, as we do +1 when we add message
} );

useEffect( () => {
const storedMessages = localStorage.getItem( CHAT_MESSAGES_STORE_KEY );
const storedChatIds = localStorage.getItem( chatIdStoreKey );

if ( storedMessages ) {
const parsedMessages: MessageDict = JSON.parse( storedMessages );
setMessagesDict( parsedMessages );
Object.entries( parsedMessages ).forEach( ( [ key, messages ] ) => {
nextMessageIdRef.current[ key ] = messages.length;
} );
}
if ( storedChatIds ) {
setChatIdDict( JSON.parse( storedChatIds ) );
}
}, [] );

const addMessage = useCallback(
( content: string, role: 'user' | 'assistant', chatId?: string, messageApiId?: number ) => {
const newMessageId = nextMessageIdRef.current[ instanceId ] + 1;
nextMessageIdRef.current[ instanceId ] = newMessageId;
if ( lastMessageIdDictRef.current[ instanceId ] === undefined ) {
lastMessageIdDictRef.current[ instanceId ] = -1;
}

const newMessageId = lastMessageIdDictRef.current[ instanceId ] + 1;
lastMessageIdDictRef.current[ instanceId ] = newMessageId;

setMessagesDict( ( prevDict ) => {
const prevMessages = prevDict[ instanceId ] || [];
Expand All @@ -78,15 +58,15 @@ export const useAssistant = ( instanceId: string ) => {
setChatIdDict( ( prevDict ) => {
if ( prevDict[ instanceId ] !== chatId && chatId ) {
const newDict = { ...prevDict, [ instanceId ]: chatId };
localStorage.setItem( chatIdStoreKey, JSON.stringify( newDict ) );
localStorage.setItem( CHAT_ID_STORE_KEY, JSON.stringify( newDict ) );
return newDict;
}
return prevDict;
} );

return newMessageId; // Return the new message ID
},
[ instanceId ]
[ instanceId, setMessagesDict, setChatIdDict, lastMessageIdDictRef ]
);

const updateMessage = useCallback(
Expand Down Expand Up @@ -119,7 +99,7 @@ export const useAssistant = ( instanceId: string ) => {
return newDict;
} );
},
[ instanceId ]
[ instanceId, setMessagesDict ]
);

const markMessageAsFailed = useCallback(
Expand All @@ -135,7 +115,7 @@ export const useAssistant = ( instanceId: string ) => {
return newDict;
} );
},
[ instanceId ]
[ instanceId, setMessagesDict ]
);

const sendFeedback = useSendFeedback();
Expand Down Expand Up @@ -183,11 +163,11 @@ export const useAssistant = ( instanceId: string ) => {

setChatIdDict( ( prevDict ) => {
const { [ instanceId ]: _, ...rest } = prevDict;
localStorage.setItem( chatIdStoreKey, JSON.stringify( rest ) );
localStorage.setItem( CHAT_ID_STORE_KEY, JSON.stringify( rest ) );
return rest;
} );
nextMessageIdRef.current[ instanceId ] = 0;
}, [ instanceId ] );
lastMessageIdDictRef.current[ instanceId ] = -1;
}, [ instanceId, setMessagesDict, setChatIdDict ] );

return {
messages: messagesDict[ instanceId ] || EMPTY_MESSAGES,
Expand Down
Loading

0 comments on commit 08a8b3e

Please sign in to comment.