diff --git a/src/hooks/index.ts b/src/hooks/index.ts index 4fa3c1a..1c0d3b7 100644 --- a/src/hooks/index.ts +++ b/src/hooks/index.ts @@ -2,4 +2,5 @@ export * from './usePrevious'; export * from './useTheme'; export * from './useChatSession'; export * from './useMemoryCheck'; +export * from './useMoveScroll'; export * from './useStorageCheck'; diff --git a/src/hooks/useMoveScroll.ts b/src/hooks/useMoveScroll.ts new file mode 100644 index 0000000..056e75a --- /dev/null +++ b/src/hooks/useMoveScroll.ts @@ -0,0 +1,38 @@ +// This is a workaround for multiline input text not avoiding the keyboard.. +// https://github.com/facebook/react-native/issues/16826#issuecomment-2254322144 + +import {useRef, useEffect} from 'react'; +import {FlatList, Keyboard, KeyboardEvent, Platform} from 'react-native'; + +export const useMoveScroll = () => { + const scrollRef = useRef(null); + const keyboardHeight = useRef(Platform.OS === 'ios' ? 320 : 280); + const visibleAreaOffset = 300; // very arbitrary number. TODO: fix this + + useEffect(() => { + const keyboardDidShowListener = Keyboard.addListener( + 'keyboardDidShow', + (event: KeyboardEvent) => { + keyboardHeight.current = event.endCoordinates.height; + }, + ); + + return () => { + keyboardDidShowListener.remove(); + }; + }, []); + + const moveScrollToDown = (inputY?: number) => { + if (scrollRef.current) { + setTimeout(() => { + const offset = inputY ?? keyboardHeight.current + visibleAreaOffset; + scrollRef.current?.scrollToOffset({ + offset: Math.max(0, offset), + animated: true, + }); + }, 600); + } + }; + + return {scrollRef, moveScrollToDown}; +}; diff --git a/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx b/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx index 6f3b95d..3c12c21 100644 --- a/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx +++ b/src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx @@ -1,5 +1,5 @@ import React, {useState} from 'react'; -import {ScrollView, View} from 'react-native'; +import {View} from 'react-native'; import {CompletionParams} from '@pocketpalai/llama.rn'; import Slider from '@react-native-community/slider'; @@ -16,6 +16,7 @@ interface Props { export const CompletionSettings: React.FC = ({settings, onChange}) => { const [localSliderValues, setLocalSliderValues] = useState({}); + const [newStopWord, setNewStopWord] = useState(''); const {colors} = useTheme(); const handleOnChange = (name, value) => { @@ -88,64 +89,83 @@ export const CompletionSettings: React.FC = ({settings, onChange}) => { ); + const renderStopWords = () => ( + + + stop + + + {/* Display existing stop words as chips */} + + {(settings.stop ?? []).map((word, index) => ( + { + const newStops = (settings.stop ?? []).filter( + (_, i) => i !== index, + ); + onChange('stop', newStops); + }} + style={styles.stopChip}> + {word} + + ))} + + + {/* Input for new stop words */} + { + if (newStopWord.trim()) { + onChange('stop', [...(settings.stop ?? []), newStopWord.trim()]); + setNewStopWord(''); + } + }} + style={styles.textInput} + testID="stop-input" + /> + + ); + return ( - - - - {renderIntegerInput('n_predict', 0, 2048)} - {renderSlider('temperature', 0, 1)} - {renderSlider('top_k', 1, 128, 1)} - {renderSlider('top_p', 0, 1)} - {renderSlider('min_p', 0, 1)} - {renderSlider('xtc_threshold', 0, 1)} - {renderSlider('xtc_probability', 0, 1)} - {renderSlider('typical_p', 0, 2)} - {renderSlider('penalty_last_n', 0, 256, 1)} - {renderSlider('penalty_repeat', 0, 2)} - {renderSlider('penalty_freq', 0, 2)} - {renderSlider('penalty_present', 0, 2)} - - - mirostat - - {[0, 1, 2].map(value => ( - onChange('mirostat', value)} - style={styles.chip}> - {value.toString()} - - ))} - - - {renderSlider('mirostat_tau', 0, 10, 1)} - {renderSlider('mirostat_eta', 0, 1)} - {renderSwitch('penalize_nl')} - {renderIntegerInput('seed', 0, Number.MAX_SAFE_INTEGER)} - {renderIntegerInput('n_probs', 0, 100)} - - - stop - (comma separated) - - - onChange( - 'stop', - value - .split(',') - .map(s => s.trim()) - .filter(s => s.length > 0), - ) - } - style={styles.textInput} - testID="stop-input" - /> + + + {renderIntegerInput('n_predict', 0, 2048)} + {renderSlider('temperature', 0, 1)} + {renderSlider('top_k', 1, 128, 1)} + {renderSlider('top_p', 0, 1)} + {renderSlider('min_p', 0, 1)} + {renderSlider('xtc_threshold', 0, 1)} + {renderSlider('xtc_probability', 0, 1)} + {renderSlider('typical_p', 0, 2)} + {renderSlider('penalty_last_n', 0, 256, 1)} + {renderSlider('penalty_repeat', 0, 2)} + {renderSlider('penalty_freq', 0, 2)} + {renderSlider('penalty_present', 0, 2)} + + + mirostat + + {[0, 1, 2].map(value => ( + onChange('mirostat', value)} + style={styles.chip}> + {value.toString()} + + ))} - - - + + {renderSlider('mirostat_tau', 0, 10, 1)} + {renderSlider('mirostat_eta', 0, 1)} + {renderSwitch('penalize_nl')} + {renderIntegerInput('seed', 0, Number.MAX_SAFE_INTEGER)} + {renderIntegerInput('n_probs', 0, 100)} + {renderStopWords()} + + ); }; diff --git a/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx b/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx index f8574b0..89fe6f8 100644 --- a/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx +++ b/src/screens/ModelsScreen/CompletionSettings/__tests__/CompletionSettings.test.tsx @@ -7,7 +7,7 @@ jest.useFakeTimers(); describe('CompletionSettings', () => { it('renders all settings correctly', async () => { - const {getByDisplayValue, getByTestId} = render( + const {getByDisplayValue, getByTestId, getByText} = render( { expect(nProbsInput.props.value).toBe('0'); expect(getByTestId('stop-input')).toBeTruthy(); - const stopInput = getByTestId('stop-input'); - expect(stopInput.props.value).toBe(', '); + expect(getByText('')).toBeTruthy(); + expect(getByText('')).toBeTruthy(); }); it('handles slider changes', () => { @@ -144,4 +144,33 @@ describe('CompletionSettings', () => { fireEvent.press(mirostatChip); expect(mockOnChange).toHaveBeenCalledWith('mirostat', 2); }); + + it('handles stop words additions and removals', () => { + const mockOnChange = jest.fn(); + const {getByTestId, getAllByRole} = render( + , + ); + + // Test adding new stop word + const stopInput = getByTestId('stop-input'); + fireEvent.changeText(stopInput, 'newstop'); + fireEvent(stopInput, 'submitEditing'); + + expect(mockOnChange).toHaveBeenCalledWith('stop', [ + ...(mockCompletionParams.stop ?? []), + 'newstop', + ]); + + // Test removing stop word + const closeButtons = getAllByRole('button', {name: /close/i}); + fireEvent.press(closeButtons[0]); + + expect(mockOnChange).toHaveBeenCalledWith( + 'stop', + (mockCompletionParams.stop ?? []).filter(word => word !== ''), + ); + }); }); diff --git a/src/screens/ModelsScreen/CompletionSettings/styles.ts b/src/screens/ModelsScreen/CompletionSettings/styles.ts index 1285820..3eed4e0 100644 --- a/src/screens/ModelsScreen/CompletionSettings/styles.ts +++ b/src/screens/ModelsScreen/CompletionSettings/styles.ts @@ -1,10 +1,6 @@ import {StyleSheet} from 'react-native'; export const styles = StyleSheet.create({ - container: { - flex: 1, - }, - card: {}, row: { flexDirection: 'row', alignItems: 'center', @@ -52,4 +48,14 @@ export const styles = StyleSheet.create({ fontSize: 16, marginRight: 8, }, + stopWordsContainer: { + flexDirection: 'row', + flexWrap: 'wrap', + gap: 8, + marginBottom: 8, + }, + stopChip: { + marginRight: 4, + marginBottom: 4, + }, }); diff --git a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx index b7e393c..52ac51d 100644 --- a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx +++ b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx @@ -33,10 +33,11 @@ type ChatScreenNavigationProp = DrawerNavigationProp; interface ModelCardProps { model: Model; activeModelId?: string; + onFocus?: () => void; } export const ModelCard: React.FC = observer( - ({model, activeModelId}) => { + ({model, activeModelId, onFocus}) => { const l10n = React.useContext(L10nContext); const {colors} = useTheme(); const navigation = useNavigation(); @@ -315,6 +316,9 @@ export const ModelCard: React.FC = observer( isActive={isActiveModel} onChange={handleSettingsUpdate} onCompletionSettingsChange={handleCompletionSettingsUpdate} + onFocus={() => { + onFocus && onFocus(); + }} /> )} diff --git a/src/screens/ModelsScreen/ModelSettings/ModelSettings.tsx b/src/screens/ModelsScreen/ModelSettings/ModelSettings.tsx index 8db46e7..33c9bdd 100644 --- a/src/screens/ModelsScreen/ModelSettings/ModelSettings.tsx +++ b/src/screens/ModelsScreen/ModelSettings/ModelSettings.tsx @@ -34,6 +34,7 @@ interface ModelSettingsProps { isActive: boolean; onChange: (name: string, value: any) => void; onCompletionSettingsChange: (name: string, value: any) => void; + onFocus?: () => void; } export const ModelSettings: React.FC = ({ @@ -42,13 +43,14 @@ export const ModelSettings: React.FC = ({ isActive, onChange, onCompletionSettingsChange, + onFocus, }) => { const [isDialogVisible, setDialogVisible] = useState(false); const [localChatTemplate, setLocalChatTemplate] = useState( chatTemplate.chatTemplate, ); const [localSystemPrompt, setLocalSystemPrompt] = useState( - chatTemplate.systemPrompt, + chatTemplate.systemPrompt ?? '', ); const [selectedTemplateName, setSelectedTemplateName] = useState( chatTemplate.name, @@ -73,7 +75,7 @@ export const ModelSettings: React.FC = ({ }, [localChatTemplate]); useEffect(() => { - setLocalSystemPrompt(chatTemplate.systemPrompt); + setLocalSystemPrompt(chatTemplate.systemPrompt ?? ''); }, [chatTemplate.systemPrompt]); useEffect(() => { @@ -176,20 +178,20 @@ export const ModelSettings: React.FC = ({ - {chatTemplate.systemPrompt !== undefined && - chatTemplate.systemPrompt !== null && ( - setLocalSystemPrompt(text)} - onBlur={() => handleSaveSystemPrompt()} - multiline - numberOfLines={3} - style={styles.textArea} - label={'System prompt'} - /> - )} + setLocalSystemPrompt(text)} + onBlur={() => handleSaveSystemPrompt()} + multiline + numberOfLines={3} + style={styles.textArea} + label={'System prompt'} + onFocus={() => { + onFocus && onFocus(); + }} + /> {/** Completion Settings */} { uiStore.setValue('modelsScreen', 'expandedGroups', updatedExpandedGroups); }; + const {scrollRef, moveScrollToDown} = useMoveScroll(); + const renderGroupHeader = ({item: group}) => { const isExpanded = expandedGroups[group.type]; return ( @@ -175,7 +183,16 @@ export const ModelsScreen: React.FC = observer(() => { data={group.items} keyExtractor={subItem => subItem.id} renderItem={({item: subItem}) => ( - + { + if (Platform.OS === 'ios') { + // Workaround for multiline input text not avoiding the keyboard. + moveScrollToDown(); + } + }} + /> )} /> @@ -183,7 +200,16 @@ export const ModelsScreen: React.FC = observer(() => { }; const renderItem = ({item}) => ( - + { + if (Platform.OS === 'ios') { + // Workaround for multiline input text not avoiding the keyboard. + moveScrollToDown(); + } + }} + /> ); const flatListModels = Object.keys(groupedModels) @@ -194,10 +220,16 @@ export const ModelsScreen: React.FC = observer(() => { .filter(group => group.items.length > 0); return ( - + { onAddHFModel={() => setHFSearchVisible(true)} onAddLocalModel={handleAddLocalModel} /> - + ); });