Skip to content

Commit

Permalink
[Feat] File integrity check for downloaded models (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
a-ghorbani authored Jan 5, 2025
1 parent 030eba7 commit 097b08d
Show file tree
Hide file tree
Showing 10 changed files with 237 additions and 25 deletions.
2 changes: 1 addition & 1 deletion __mocks__/stores/hfStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export const mockHFStore = {
// Methods
setSearchQuery: jest.fn(),
fetchAndSetGGUFSpecs: jest.fn().mockResolvedValue(undefined),
fetchModelFileSizes: jest.fn().mockResolvedValue(undefined),
fetchModelFileDetails: jest.fn().mockResolvedValue(undefined),
getModelById: jest.fn(id =>
mockHFStore.models.find(model => model.id === id),
),
Expand Down
58 changes: 45 additions & 13 deletions src/screens/ModelsScreen/ModelCard/ModelCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,16 @@ import {ModelSettings} from '../ModelSettings';
import {uiStore, modelStore} from '../../../store';

import {chatTemplates} from '../../../utils/chat';
import {getModelDescription, L10nContext} from '../../../utils';
import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types';
import {
getModelDescription,
L10nContext,
checkModelFileIntegrity,
} from '../../../utils';
import {
COMPLETION_PARAMS_METADATA,
validateCompletionSettings,
} from '../../../utils/modelSettings';
import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types';

type ChatScreenNavigationProp = DrawerNavigationProp<RootDrawerParamList>;

Expand All @@ -52,6 +56,7 @@ export const ModelCard: React.FC<ModelCardProps> = observer(

const [snackbarVisible, setSnackbarVisible] = useState(false); // Snackbar visibility
const [settingsModalVisible, setSettingsModalVisible] = useState(false);
const [integrityError, setIntegrityError] = useState<string | null>(null);

const {memoryWarning, shortMemoryWarning} = useMemoryCheck(model);
const {isOk: storageOk, message: storageNOkMessage} =
Expand All @@ -76,6 +81,17 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
setTempCompletionSettings(model.completionSettings);
}, [model]);

// Check integrity when model is downloaded
useEffect(() => {
if (isDownloaded) {
checkModelFileIntegrity(model, modelStore).then(({errorMessage}) => {
setIntegrityError(errorMessage);
});
} else {
setIntegrityError(null);
}
}, [isDownloaded, model]);

const handleSettingsUpdate = useCallback((name: string, value: any) => {
setTempChatTemplate(prev => {
const newTemplate =
Expand Down Expand Up @@ -286,20 +302,17 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
);
}

const handlePress = () => {
const handlePress = async () => {
if (isActiveModel) {
modelStore.manualReleaseContext();
} else {
modelStore
.initContext(model)
.then(() => {
console.log('initialized');
})
.catch(e => {
console.log(`Error: ${e}`);
});
if (uiStore.autoNavigatetoChat) {
navigation.navigate('Chat');
try {
await modelStore.initContext(model);
if (uiStore.autoNavigatetoChat) {
navigation.navigate('Chat');
}
} catch (e) {
console.log(`Error: ${e}`);
}
}
};
Expand All @@ -310,6 +323,7 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
icon={isActiveModel ? 'eject' : 'play-circle-outline'}
mode="text"
onPress={handlePress}
// disabled={!!integrityError} // for now integrity check is experimental. So won't disable the button
style={styles.actionButton}>
{isActiveModel ? l10n.offload : l10n.load}
</Button>
Expand Down Expand Up @@ -390,6 +404,24 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
</TouchableRipple>
)}

{/* Display integrity warning if check fails */}
{integrityError && (
<TouchableRipple
testID="integrity-warning-button"
//onPress={handleWarningPress}
style={styles.warningContainer}>
<View style={styles.warningContent}>
<IconButton
icon="alert-circle-outline"
iconColor={theme.colors.error}
size={20}
style={styles.warningIcon}
/>
<Text style={styles.warningText}>{integrityError}</Text>
</View>
</TouchableRipple>
)}

{isDownloading && (
<>
<ProgressBar
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ describe('ModelCard', () => {
act(() => {
fireEvent.press(getByTestId('load-button'));
});
expect(mockNavigate).toHaveBeenCalledWith('Chat');
await waitFor(() => {
expect(mockNavigate).toHaveBeenCalledWith('Chat');
});
});

it('handles model offload', async () => {
Expand Down
4 changes: 4 additions & 0 deletions src/screens/ModelsScreen/ModelCard/styles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,10 @@ export const createStyles = (theme: Theme) =>
flexDirection: 'row',
alignItems: 'center',
margin: 0,
marginTop: 8,
},
warningContent: {
flex: 1,
flexDirection: 'row',
alignItems: 'center',
},
Expand All @@ -127,6 +129,8 @@ export const createStyles = (theme: Theme) =>
warningText: {
color: theme.colors.error,
fontSize: 12,
flex: 1,
flexWrap: 'wrap',
},
overlayButtons: {
flex: 1,
Expand Down
9 changes: 5 additions & 4 deletions src/store/HFStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class HFStore {
...file,
size: details.size,
oid: details.oid,
lfs: details.lfs,
};

return {
Expand All @@ -72,10 +73,10 @@ class HFStore {
);
}

// Fetch the sizes of the model files
async fetchModelFileSizes(modelId: string) {
// Fetch the details (sizes, oid, lfs, ...) of the model files
async fetchModelFileDetails(modelId: string) {
try {
console.log('Fetching model file sizes for', modelId);
console.log('Fetching model file details for', modelId);
const fileDetails = await fetchModelFilesDetails(modelId);
const model = this.models.find(m => m.id === modelId);

Expand Down Expand Up @@ -103,7 +104,7 @@ class HFStore {
async fetchModelData(modelId: string) {
try {
await this.fetchAndSetGGUFSpecs(modelId);
await this.fetchModelFileSizes(modelId);
await this.fetchModelFileDetails(modelId);
} catch (error) {
console.error('Error fetching model data:', error);
}
Expand Down
55 changes: 53 additions & 2 deletions src/store/ModelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,18 @@ import AsyncStorage from '@react-native-async-storage/async-storage';
import {computed, makeAutoObservable, ObservableMap, runInAction} from 'mobx';
import {CompletionParams, LlamaContext, initLlama} from '@pocketpalai/llama.rn';

import {fetchModelFilesDetails} from '../api/hf';

import {uiStore} from './UIStore';
import {chatSessionStore} from './ChatSessionStore';
import {defaultModels, MODEL_LIST_VERSION} from './defaultModels';
import {deepMerge, formatBytes, hasEnoughSpace, hfAsModel} from '../utils';
import {
deepMerge,
formatBytes,
getSHA256Hash,
hasEnoughSpace,
hfAsModel,
} from '../utils';

import {
getHFDefaultSettings,
Expand Down Expand Up @@ -410,7 +418,17 @@ class ModelStore {
};

async checkFileExists(model: Model) {
const exists = await RNFS.exists(await this.getModelFullPath(model));
const filePath = await this.getModelFullPath(model);
const exists = await RNFS.exists(filePath);
if (exists) {
// Only calculate hash if it's not already stored
if (!model.hash) {
const hash = await getSHA256Hash(filePath);
runInAction(() => {
model.hash = hash;
});
}
}
runInAction(() => {
model.isDownloaded = exists;
});
Expand Down Expand Up @@ -536,8 +554,12 @@ class ModelStore {

const result = await ret.promise;
if (result.statusCode === 200) {
// Calculate hash after successful download
const hash = await getSHA256Hash(downloadDest);

runInAction(() => {
model.progress = 100; // Ensure progress is set to 100 upon completion
model.hash = hash;
this.refreshDownloadStatuses();
});

Expand Down Expand Up @@ -1014,6 +1036,35 @@ class ModelStore {
setIsStreaming(value: boolean) {
this.isStreaming = value;
}

/**
* Fetches and updates model file details from HuggingFace.
* This is used when we need to get the lfs.oid for integrity checks.
* @param model - The model to update
* @returns Promise<void>
*/
async fetchAndUpdateModelFileDetails(model: Model): Promise<void> {
if (!model.hfModel?.id) {
return;
}

try {
const fileDetails = await fetchModelFilesDetails(model.hfModel.id);
const matchingFile = fileDetails.find(
file => file.path === model.hfModelFile?.rfilename,
);

if (matchingFile && matchingFile.lfs) {
runInAction(() => {
if (model.hfModelFile) {
model.hfModelFile.lfs = matchingFile.lfs;
}
});
}
} catch (error) {
console.error('Failed to fetch model file details:', error);
}
}
}

export const modelStore = new ModelStore();
6 changes: 3 additions & 3 deletions src/store/__tests__/HFStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ describe('HFStore', () => {
});
});

describe('fetchModelFileSizes', () => {
describe('fetchModelFileDetails', () => {
it('should update model siblings with file sizes', async () => {
hfStore.models = [mockHFModel1];
const fileDetails = [
Expand All @@ -161,7 +161,7 @@ describe('HFStore', () => {

(fetchModelFilesDetails as jest.Mock).mockResolvedValueOnce(fileDetails);

await hfStore.fetchModelFileSizes(mockHFModel1.id);
await hfStore.fetchModelFileDetails(mockHFModel1.id);

expect(hfStore.models[0].siblings[0].size).toBe(1111);
expect(hfStore.models[0].siblings[0].oid).toBe('abc123');
Expand All @@ -171,7 +171,7 @@ describe('HFStore', () => {
hfStore.models = [];
(fetchModelFilesDetails as jest.Mock).mockResolvedValueOnce([]);

await hfStore.fetchModelFileSizes('non-existent-id');
await hfStore.fetchModelFileDetails('non-existent-id');

expect(fetchModelFilesDetails).toHaveBeenCalled();
// Should not throw error
Expand Down
Loading

0 comments on commit 097b08d

Please sign in to comment.