Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store more Assistant state in ChatContext to uncouple it from component tree #788

Merged
merged 13 commits into from
Jan 14, 2025
Merged
48 changes: 25 additions & 23 deletions src/components/content-tab-assistant.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ interface AuthenticatedViewProps {
>[ 'markMessageAsFeedbackReceived' ];
}

export const AuthenticatedView = memo(
const AuthenticatedView = memo(
( {
messages,
isAssistantThinking,
Expand All @@ -125,7 +125,7 @@ export const AuthenticatedView = memo(
}: AuthenticatedViewProps ) => {
const endOfMessagesRef = useRef< HTMLDivElement >( null );
const lastMessageRef = useRef< HTMLDivElement >( null );
const [ showThinking, setShowThinking ] = useState( false );
const [ showThinking, setShowThinking ] = useState( isAssistantThinking );
const lastMessage = useMemo(
() =>
showThinking
Expand All @@ -135,23 +135,27 @@ export const AuthenticatedView = memo(
);
const messagesToRender =
messages[ messages.length - 1 ]?.role === 'assistant' ? messages.slice( 0, -1 ) : messages;
const showLastMessage = showThinking || messages[ messages.length - 1 ]?.role === 'assistant';
const previousMessagesLength = useRef( messages?.length );
const previousSiteId = useRef( siteId );

const showLastMessage = lastMessage?.role === 'assistant';
const previousMessagesLength = useRef( messages.length );
const isInitialRenderRef = useRef( true );

// This effect may run twice when the component is mounted, which makes the viewport scroll
// to the wrong position. This happens because the app runs in React strict mode, meaning
// it only affects the development environment. For more details, see
// https://github.com/Automattic/studio/pull/788#issuecomment-2586644007
useEffect( () => {
if ( ! messages?.length ) {
previousSiteId.current = siteId;
if ( ! messages.length ) {
return;
}

let timer: NodeJS.Timeout;
// Scroll to the end of the messages when the tab is opened or site ID changes
if ( previousMessagesLength.current === 0 || previousSiteId.current !== siteId ) {
if ( isInitialRenderRef.current ) {
Comment on lines -150 to +153
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a key prop specifically for ContentTabAssistant in src/components/site-content-tabs.tsx instead.

endOfMessagesRef.current?.scrollIntoView( { behavior: 'instant' } );
isInitialRenderRef.current = false;
}
// Scroll when a new message is added
else if ( messages?.length > previousMessagesLength.current || showLastMessage ) {
else if ( messages.length > previousMessagesLength.current || showLastMessage ) {
// Scroll to the beginning of last message received from the assistant
if ( showLastMessage ) {
timer = setTimeout( () => {
Expand All @@ -166,11 +170,10 @@ export const AuthenticatedView = memo(
}
}

previousMessagesLength.current = messages?.length;
previousSiteId.current = siteId;
previousMessagesLength.current = messages.length;

return () => clearTimeout( timer );
}, [ messages?.length, showLastMessage, siteId ] );
}, [ messages.length, showLastMessage ] );

useEffect( () => {
let timer: NodeJS.Timeout;
Expand Down Expand Up @@ -486,16 +489,15 @@ export function ContentTabAssistant( { selectedSite }: ContentTabAssistantProps
siteId={ selectedSite.id }
disabled={ disabled }
/>
{
<AuthenticatedView
messages={ messages }
isAssistantThinking={ isAssistantThinking }
updateMessage={ updateMessage }
markMessageAsFeedbackReceived={ markMessageAsFeedbackReceived }
siteId={ selectedSite.id }
submitPrompt={ submitPrompt }
/>
}

<AuthenticatedView
messages={ messages }
isAssistantThinking={ isAssistantThinking }
updateMessage={ updateMessage }
markMessageAsFeedbackReceived={ markMessageAsFeedbackReceived }
siteId={ selectedSite.id }
submitPrompt={ submitPrompt }
/>
</>
) : (
! isOffline && <UnauthenticatedView onAuthenticate={ authenticate } />
Expand Down
8 changes: 7 additions & 1 deletion src/components/site-content-tabs.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@ export function SiteContentTabs() {
{ name === 'share' && <ContentTabSnapshots selectedSite={ selectedSite } /> }
{ name === 'sync' && <ContentTabSync selectedSite={ selectedSite } /> }
{ name === 'settings' && <ContentTabSettings selectedSite={ selectedSite } /> }
{ name === 'assistant' && <ContentTabAssistant selectedSite={ selectedSite } /> }
{ name === 'assistant' && (
<ContentTabAssistant
// TODO: Remove this key once https://github.com/Automattic/dotcom-forge/issues/10219 is fixed
key={ selectedTab + selectedSite.id }
selectedSite={ selectedSite }
/>
) }
{ name === 'import-export' && (
<ContentTabImportExport selectedSite={ selectedSite } />
) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
import { renderHook, act } from '@testing-library/react';
import { ReactNode } from 'react';
import { getIpcApi } from '../../lib/get-ipc-api';
import { useAssistant } from '../use-assistant';
import { ChatProvider } from '../use-chat-context';
import { useGetWpVersion } from '../use-get-wp-version';
import { ThemeDetailsProvider } from '../use-theme-details';

jest.mock( '../../lib/get-ipc-api' );
jest.mock( '../use-get-wp-version' );

function ContextWrapper( { children }: { children: ReactNode } ) {
return (
<ThemeDetailsProvider>
<ChatProvider>{ children }</ChatProvider>
</ThemeDetailsProvider>
);
}

interface Message {
content: string;
Expand All @@ -15,6 +31,11 @@ describe( 'useAssistant', () => {
localStorage.clear();
jest.useFakeTimers();
jest.setSystemTime( MOCKED_TIME );
( getIpcApi as jest.Mock ).mockReturnValue( {
showMessageBox: jest.fn().mockResolvedValue( { response: 0, checkboxChecked: false } ),
executeWPCLiInline: jest.fn().mockResolvedValue( { stdout: '', stderr: 'Error' } ),
} );
( useGetWpVersion as jest.Mock ).mockReturnValue( '6.4.3' );
} );

afterEach( () => {
Expand All @@ -31,13 +52,17 @@ describe( 'useAssistant', () => {
JSON.stringify( { [ selectedSiteId ]: initialMessages } )
);

const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

expect( result.current.messages ).toEqual( initialMessages );
} );

it( 'should add a message correctly', () => {
const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

act( () => {
result.current.addMessage( 'Hello', 'user' );
Expand Down Expand Up @@ -72,7 +97,9 @@ describe( 'useAssistant', () => {
} );

it( 'should clear messages correctly', () => {
const { result } = renderHook( () => useAssistant( selectedSiteId ) );
const { result } = renderHook( () => useAssistant( selectedSiteId ), {
wrapper: ContextWrapper,
} );

act( () => {
result.current.addMessage( 'Hello', 'user' );
Expand Down
21 changes: 21 additions & 0 deletions src/hooks/tests/use-chat-context.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down Expand Up @@ -214,6 +221,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down Expand Up @@ -372,6 +386,13 @@ describe( 'useChatContext hook', () => {
os: 'darwin',
getChatInput: expect.any( Function ),
saveChatInput: expect.any( Function ),
messagesDict: {},
setMessagesDict: expect.any( Function ),
chatIdDict: {},
setChatIdDict: expect.any( Function ),
lastMessageIdDictRef: expect.any( Object ),
isLoadingDict: {},
setIsLoadingDict: expect.any( Function ),
} );
} );

Expand Down
16 changes: 7 additions & 9 deletions src/hooks/use-assistant-api.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { useCallback, useState } from 'react';
import { useCallback } from 'react';
import { Message } from './use-assistant';
import { useAuth } from './use-auth';
import { ChatContextType } from './use-chat-context';
import { ChatContextType, useChatContext } from './use-chat-context';
import { usePromptUsage } from './use-prompt-usage';

const contextMapper = ( context?: ChatContextType ) => {
Expand All @@ -25,17 +25,15 @@ const contextMapper = ( context?: ChatContextType ) => {

export function useAssistantApi( selectedSiteId: string ) {
const { client } = useAuth();
const [ isLoading, setIsLoading ] = useState< Record< string, boolean > >( {
[ selectedSiteId ]: false,
} );
const { setIsLoadingDict, isLoadingDict } = useChatContext();
const { updatePromptUsage } = usePromptUsage();

const fetchAssistant = useCallback(
async ( chatId: string | undefined, messages: Message[], context?: ChatContextType ) => {
if ( ! client ) {
throw new Error( 'WPcom client not initialized' );
}
setIsLoading( ( prev ) => ( { ...prev, [ selectedSiteId ]: true } ) );
setIsLoadingDict( ( prev ) => ( { ...prev, [ selectedSiteId ]: true } ) );
const body = {
messages,
chat_id: chatId,
Expand Down Expand Up @@ -69,7 +67,7 @@ export function useAssistantApi( selectedSiteId: string ) {
response = data;
headers = response_headers;
} finally {
setIsLoading( ( prev ) => ( { ...prev, [ selectedSiteId ]: false } ) );
setIsLoadingDict( ( prev ) => ( { ...prev, [ selectedSiteId ]: false } ) );
}

const message = response?.choices?.[ 0 ]?.message?.content;
Expand All @@ -82,8 +80,8 @@ export function useAssistantApi( selectedSiteId: string ) {

return { message, messageApiId, chatId: response?.id };
},
[ client, selectedSiteId, updatePromptUsage ]
[ client, selectedSiteId, setIsLoadingDict, updatePromptUsage ]
);

return { fetchAssistant, isLoading: isLoading[ selectedSiteId ] };
return { fetchAssistant, isLoading: isLoadingDict[ selectedSiteId ] };
}
54 changes: 17 additions & 37 deletions src/hooks/use-assistant.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import { useCallback } from 'react';
import { CHAT_MESSAGES_STORE_KEY } from '../constants';
import { CHAT_ID_STORE_KEY, useChatContext } from './use-chat-context';
import { useSendFeedback } from './use-send-feedback';

export type Message = {
Expand All @@ -19,42 +20,21 @@ export type Message = {
feedbackReceived?: boolean;
};

export type MessageDict = { [ key: string ]: Message[] };
export type ChatIdDict = { [ key: string ]: string | undefined };

const chatIdStoreKey = 'ai_chat_ids';
const EMPTY_MESSAGES: Message[] = [];

export const useAssistant = ( instanceId: string ) => {
const [ messagesDict, setMessagesDict ] = useState< MessageDict >( {} );
const [ chatIdDict, setChatIdDict ] = useState< ChatIdDict >( {
[ instanceId ]: undefined,
} );
const { messagesDict, setMessagesDict, chatIdDict, setChatIdDict, lastMessageIdDictRef } =
useChatContext();
const chatId = chatIdDict[ instanceId ];
const nextMessageIdRef = useRef< { [ key: string ]: number } >( {
[ instanceId ]: -1, // The first message should have id 0, as we do +1 when we add message
} );

useEffect( () => {
const storedMessages = localStorage.getItem( CHAT_MESSAGES_STORE_KEY );
const storedChatIds = localStorage.getItem( chatIdStoreKey );

if ( storedMessages ) {
const parsedMessages: MessageDict = JSON.parse( storedMessages );
setMessagesDict( parsedMessages );
Object.entries( parsedMessages ).forEach( ( [ key, messages ] ) => {
nextMessageIdRef.current[ key ] = messages.length;
} );
}
if ( storedChatIds ) {
setChatIdDict( JSON.parse( storedChatIds ) );
}
}, [] );

const addMessage = useCallback(
( content: string, role: 'user' | 'assistant', chatId?: string, messageApiId?: number ) => {
const newMessageId = nextMessageIdRef.current[ instanceId ] + 1;
nextMessageIdRef.current[ instanceId ] = newMessageId;
if ( lastMessageIdDictRef.current[ instanceId ] === undefined ) {
lastMessageIdDictRef.current[ instanceId ] = -1;
}

const newMessageId = lastMessageIdDictRef.current[ instanceId ] + 1;
lastMessageIdDictRef.current[ instanceId ] = newMessageId;

setMessagesDict( ( prevDict ) => {
const prevMessages = prevDict[ instanceId ] || [];
Expand All @@ -78,15 +58,15 @@ export const useAssistant = ( instanceId: string ) => {
setChatIdDict( ( prevDict ) => {
if ( prevDict[ instanceId ] !== chatId && chatId ) {
const newDict = { ...prevDict, [ instanceId ]: chatId };
localStorage.setItem( chatIdStoreKey, JSON.stringify( newDict ) );
localStorage.setItem( CHAT_ID_STORE_KEY, JSON.stringify( newDict ) );
return newDict;
}
return prevDict;
} );

return newMessageId; // Return the new message ID
},
[ instanceId ]
[ instanceId, setMessagesDict, setChatIdDict, lastMessageIdDictRef ]
);

const updateMessage = useCallback(
Expand Down Expand Up @@ -119,7 +99,7 @@ export const useAssistant = ( instanceId: string ) => {
return newDict;
} );
},
[ instanceId ]
[ instanceId, setMessagesDict ]
);

const markMessageAsFailed = useCallback(
Expand All @@ -135,7 +115,7 @@ export const useAssistant = ( instanceId: string ) => {
return newDict;
} );
},
[ instanceId ]
[ instanceId, setMessagesDict ]
);

const sendFeedback = useSendFeedback();
Expand Down Expand Up @@ -183,11 +163,11 @@ export const useAssistant = ( instanceId: string ) => {

setChatIdDict( ( prevDict ) => {
const { [ instanceId ]: _, ...rest } = prevDict;
localStorage.setItem( chatIdStoreKey, JSON.stringify( rest ) );
localStorage.setItem( CHAT_ID_STORE_KEY, JSON.stringify( rest ) );
return rest;
} );
nextMessageIdRef.current[ instanceId ] = 0;
}, [ instanceId ] );
lastMessageIdDictRef.current[ instanceId ] = -1;
}, [ instanceId, setMessagesDict, setChatIdDict ] );

return {
messages: messagesDict[ instanceId ] || EMPTY_MESSAGES,
Expand Down
Loading
Loading