Skip to content

Commit

Permalink
Fix/avoiding keyboard settings (#115)
Browse files Browse the repository at this point in the history
* fix: text inputs in model screen avoid keyboard

* fix: add workaround for multiline input text not avoiding the keyboard

* fix: refactor stop words input -> chips
  • Loading branch information
a-ghorbani authored Nov 30, 2024
1 parent 600918b commit b5afeca
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 89 deletions.
1 change: 1 addition & 0 deletions src/hooks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ export * from './usePrevious';
export * from './useTheme';
export * from './useChatSession';
export * from './useMemoryCheck';
export * from './useMoveScroll';
export * from './useStorageCheck';
38 changes: 38 additions & 0 deletions src/hooks/useMoveScroll.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// This is a workaround for multiline input text not avoiding the keyboard..
// https://github.com/facebook/react-native/issues/16826#issuecomment-2254322144

import {useRef, useEffect} from 'react';
import {FlatList, Keyboard, KeyboardEvent, Platform} from 'react-native';

export const useMoveScroll = () => {
const scrollRef = useRef<FlatList>(null);
const keyboardHeight = useRef(Platform.OS === 'ios' ? 320 : 280);
const visibleAreaOffset = 300; // very arbitrary number. TODO: fix this

useEffect(() => {
const keyboardDidShowListener = Keyboard.addListener(
'keyboardDidShow',
(event: KeyboardEvent) => {
keyboardHeight.current = event.endCoordinates.height;
},
);

return () => {
keyboardDidShowListener.remove();
};
}, []);

const moveScrollToDown = (inputY?: number) => {
if (scrollRef.current) {
setTimeout(() => {
const offset = inputY ?? keyboardHeight.current + visibleAreaOffset;
scrollRef.current?.scrollToOffset({
offset: Math.max(0, offset),
animated: true,
});
}, 600);
}
};

return {scrollRef, moveScrollToDown};
};
136 changes: 78 additions & 58 deletions src/screens/ModelsScreen/CompletionSettings/CompletionSettings.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, {useState} from 'react';
import {ScrollView, View} from 'react-native';
import {View} from 'react-native';

import {CompletionParams} from '@pocketpalai/llama.rn';
import Slider from '@react-native-community/slider';
Expand All @@ -16,6 +16,7 @@ interface Props {

export const CompletionSettings: React.FC<Props> = ({settings, onChange}) => {
const [localSliderValues, setLocalSliderValues] = useState({});
const [newStopWord, setNewStopWord] = useState('');
const {colors} = useTheme();

const handleOnChange = (name, value) => {
Expand Down Expand Up @@ -88,64 +89,83 @@ export const CompletionSettings: React.FC<Props> = ({settings, onChange}) => {
</View>
);

const renderStopWords = () => (
<View style={styles.settingItem}>
<View style={styles.stopLabel}>
<Text style={styles.settingLabel}>stop</Text>
</View>

{/* Display existing stop words as chips */}
<View style={styles.stopWordsContainer}>
{(settings.stop ?? []).map((word, index) => (
<Chip
key={index}
onClose={() => {
const newStops = (settings.stop ?? []).filter(
(_, i) => i !== index,
);
onChange('stop', newStops);
}}
style={styles.stopChip}>
{word}
</Chip>
))}
</View>

{/* Input for new stop words */}
<TextInput
value={newStopWord}
placeholder="Add new stop word"
onChangeText={setNewStopWord}
onSubmitEditing={() => {
if (newStopWord.trim()) {
onChange('stop', [...(settings.stop ?? []), newStopWord.trim()]);
setNewStopWord('');
}
}}
style={styles.textInput}
testID="stop-input"
/>
</View>
);

return (
<ScrollView style={styles.container}>
<View style={styles.card}>
<Card.Content>
{renderIntegerInput('n_predict', 0, 2048)}
{renderSlider('temperature', 0, 1)}
{renderSlider('top_k', 1, 128, 1)}
{renderSlider('top_p', 0, 1)}
{renderSlider('min_p', 0, 1)}
{renderSlider('xtc_threshold', 0, 1)}
{renderSlider('xtc_probability', 0, 1)}
{renderSlider('typical_p', 0, 2)}
{renderSlider('penalty_last_n', 0, 256, 1)}
{renderSlider('penalty_repeat', 0, 2)}
{renderSlider('penalty_freq', 0, 2)}
{renderSlider('penalty_present', 0, 2)}
<Divider style={styles.divider} />
<View style={styles.settingItem}>
<Text style={styles.settingLabel}>mirostat</Text>
<View style={styles.chipContainer}>
{[0, 1, 2].map(value => (
<Chip
key={value}
selected={settings.mirostat === value}
onPress={() => onChange('mirostat', value)}
style={styles.chip}>
{value.toString()}
</Chip>
))}
</View>
</View>
{renderSlider('mirostat_tau', 0, 10, 1)}
{renderSlider('mirostat_eta', 0, 1)}
{renderSwitch('penalize_nl')}
{renderIntegerInput('seed', 0, Number.MAX_SAFE_INTEGER)}
{renderIntegerInput('n_probs', 0, 100)}
<View style={styles.settingItem}>
<View style={styles.stopLabel}>
<Text style={styles.settingLabel}>stop</Text>
<Text>(comma separated)</Text>
</View>
<TextInput
value={settings.stop?.join(', ')}
onChangeText={value =>
onChange(
'stop',
value
.split(',')
.map(s => s.trim())
.filter(s => s.length > 0),
)
}
style={styles.textInput}
testID="stop-input"
/>
<View>
<Card.Content>
{renderIntegerInput('n_predict', 0, 2048)}
{renderSlider('temperature', 0, 1)}
{renderSlider('top_k', 1, 128, 1)}
{renderSlider('top_p', 0, 1)}
{renderSlider('min_p', 0, 1)}
{renderSlider('xtc_threshold', 0, 1)}
{renderSlider('xtc_probability', 0, 1)}
{renderSlider('typical_p', 0, 2)}
{renderSlider('penalty_last_n', 0, 256, 1)}
{renderSlider('penalty_repeat', 0, 2)}
{renderSlider('penalty_freq', 0, 2)}
{renderSlider('penalty_present', 0, 2)}
<Divider style={styles.divider} />
<View style={styles.settingItem}>
<Text style={styles.settingLabel}>mirostat</Text>
<View style={styles.chipContainer}>
{[0, 1, 2].map(value => (
<Chip
key={value}
selected={settings.mirostat === value}
onPress={() => onChange('mirostat', value)}
style={styles.chip}>
{value.toString()}
</Chip>
))}
</View>
</Card.Content>
</View>
</ScrollView>
</View>
{renderSlider('mirostat_tau', 0, 10, 1)}
{renderSlider('mirostat_eta', 0, 1)}
{renderSwitch('penalize_nl')}
{renderIntegerInput('seed', 0, Number.MAX_SAFE_INTEGER)}
{renderIntegerInput('n_probs', 0, 100)}
{renderStopWords()}
</Card.Content>
</View>
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jest.useFakeTimers();

describe('CompletionSettings', () => {
it('renders all settings correctly', async () => {
const {getByDisplayValue, getByTestId} = render(
const {getByDisplayValue, getByTestId, getByText} = render(
<CompletionSettings
settings={mockCompletionParams}
onChange={jest.fn()}
Expand Down Expand Up @@ -82,8 +82,8 @@ describe('CompletionSettings', () => {
expect(nProbsInput.props.value).toBe('0');

expect(getByTestId('stop-input')).toBeTruthy();
const stopInput = getByTestId('stop-input');
expect(stopInput.props.value).toBe('<stop1>, <stop2>');
expect(getByText('<stop1>')).toBeTruthy();
expect(getByText('<stop2>')).toBeTruthy();
});

it('handles slider changes', () => {
Expand Down Expand Up @@ -144,4 +144,33 @@ describe('CompletionSettings', () => {
fireEvent.press(mirostatChip);
expect(mockOnChange).toHaveBeenCalledWith('mirostat', 2);
});

it('handles stop words additions and removals', () => {
const mockOnChange = jest.fn();
const {getByTestId, getAllByRole} = render(
<CompletionSettings
settings={mockCompletionParams}
onChange={mockOnChange}
/>,
);

// Test adding new stop word
const stopInput = getByTestId('stop-input');
fireEvent.changeText(stopInput, 'newstop');
fireEvent(stopInput, 'submitEditing');

expect(mockOnChange).toHaveBeenCalledWith('stop', [
...(mockCompletionParams.stop ?? []),
'newstop',
]);

// Test removing stop word
const closeButtons = getAllByRole('button', {name: /close/i});
fireEvent.press(closeButtons[0]);

expect(mockOnChange).toHaveBeenCalledWith(
'stop',
(mockCompletionParams.stop ?? []).filter(word => word !== '<stop1>'),
);
});
});
14 changes: 10 additions & 4 deletions src/screens/ModelsScreen/CompletionSettings/styles.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import {StyleSheet} from 'react-native';

export const styles = StyleSheet.create({
container: {
flex: 1,
},
card: {},
row: {
flexDirection: 'row',
alignItems: 'center',
Expand Down Expand Up @@ -52,4 +48,14 @@ export const styles = StyleSheet.create({
fontSize: 16,
marginRight: 8,
},
stopWordsContainer: {
flexDirection: 'row',
flexWrap: 'wrap',
gap: 8,
marginBottom: 8,
},
stopChip: {
marginRight: 4,
marginBottom: 4,
},
});
6 changes: 5 additions & 1 deletion src/screens/ModelsScreen/ModelCard/ModelCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ type ChatScreenNavigationProp = DrawerNavigationProp<RootDrawerParamList>;
interface ModelCardProps {
model: Model;
activeModelId?: string;
onFocus?: () => void;
}

export const ModelCard: React.FC<ModelCardProps> = observer(
({model, activeModelId}) => {
({model, activeModelId, onFocus}) => {
const l10n = React.useContext(L10nContext);
const {colors} = useTheme();
const navigation = useNavigation<ChatScreenNavigationProp>();
Expand Down Expand Up @@ -315,6 +316,9 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
isActive={isActiveModel}
onChange={handleSettingsUpdate}
onCompletionSettingsChange={handleCompletionSettingsUpdate}
onFocus={() => {
onFocus && onFocus();
}}
/>
)}
</View>
Expand Down
34 changes: 18 additions & 16 deletions src/screens/ModelsScreen/ModelSettings/ModelSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ interface ModelSettingsProps {
isActive: boolean;
onChange: (name: string, value: any) => void;
onCompletionSettingsChange: (name: string, value: any) => void;
onFocus?: () => void;
}

export const ModelSettings: React.FC<ModelSettingsProps> = ({
Expand All @@ -42,13 +43,14 @@ export const ModelSettings: React.FC<ModelSettingsProps> = ({
isActive,
onChange,
onCompletionSettingsChange,
onFocus,
}) => {
const [isDialogVisible, setDialogVisible] = useState<boolean>(false);
const [localChatTemplate, setLocalChatTemplate] = useState(
chatTemplate.chatTemplate,
);
const [localSystemPrompt, setLocalSystemPrompt] = useState(
chatTemplate.systemPrompt,
chatTemplate.systemPrompt ?? '',
);
const [selectedTemplateName, setSelectedTemplateName] = useState(
chatTemplate.name,
Expand All @@ -73,7 +75,7 @@ export const ModelSettings: React.FC<ModelSettingsProps> = ({
}, [localChatTemplate]);

useEffect(() => {
setLocalSystemPrompt(chatTemplate.systemPrompt);
setLocalSystemPrompt(chatTemplate.systemPrompt ?? '');
}, [chatTemplate.systemPrompt]);

useEffect(() => {
Expand Down Expand Up @@ -176,20 +178,20 @@ export const ModelSettings: React.FC<ModelSettingsProps> = ({
</Button>
</View>
<View>
{chatTemplate.systemPrompt !== undefined &&
chatTemplate.systemPrompt !== null && (
<TextInput
testID="system-prompt-input"
ref={systemPromptTextInputRef}
defaultValue={localSystemPrompt}
onChangeText={text => setLocalSystemPrompt(text)}
onBlur={() => handleSaveSystemPrompt()}
multiline
numberOfLines={3}
style={styles.textArea}
label={'System prompt'}
/>
)}
<TextInput
testID="system-prompt-input"
ref={systemPromptTextInputRef}
defaultValue={localSystemPrompt}
onChangeText={text => setLocalSystemPrompt(text)}
onBlur={() => handleSaveSystemPrompt()}
multiline
numberOfLines={3}
style={styles.textArea}
label={'System prompt'}
onFocus={() => {
onFocus && onFocus();
}}
/>
</View>
{/** Completion Settings */}
<List.Accordion
Expand Down
Loading

0 comments on commit b5afeca

Please sign in to comment.