diff --git a/imodelsx/qaemb/qaemb.py b/imodelsx/qaemb/qaemb.py index 3c0d720..1b96d96 100644 --- a/imodelsx/qaemb/qaemb.py +++ b/imodelsx/qaemb/qaemb.py @@ -125,7 +125,12 @@ def __call__(self, examples: List[str], verbose=True, debug_answering_correctly= for i in range(min(30, len(programs))): print(programs[i], '->', answers[i], end='\n\n\n') - answers = list(map(lambda x: 'yes' in x.lower(), answers)) + def _check_for_yes(s): + if isinstance(s, str): + return 'yes' in s.lower() + else: + return False + answers = list(map(_check_for_yes, answers)) answers = np.array(answers).reshape(len(examples), len(self.questions)) embeddings = np.array(answers, dtype=float)