Feat/logreg retrieval #81
base: dev
Conversation
self.k = model_data["k"]
self.classifier.coef_ = [model_data["coef"]]
self.classifier.intercept_ = model_data["intercept"]
self.label_encoder = LabelEncoder()
Wouldn't it be better to do this via get_params/set_params?
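If loading went through get_params/set_params as suggested, it could look roughly like the sketch below. Note that set_params only covers constructor hyperparameters, so the fitted attributes (coef_, intercept_, classes_) still have to be assigned by hand. The model_data keys here are illustrative, not the project's actual serialization format.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical serialized payload; the real model_data keys may differ.
model_data = {
    "params": {"C": 1.0, "max_iter": 200},  # as returned by get_params()
    "coef": [[0.5, -0.3]],
    "intercept": [0.1],
    "classes": [0, 1],
}

clf = LogisticRegression()
# set_params restores the hyperparameters saved via get_params() ...
clf.set_params(**model_data["params"])
# ... but fitted attributes must still be assigned manually, as numpy arrays.
clf.coef_ = np.asarray(model_data["coef"])
clf.intercept_ = np.asarray(model_data["intercept"])
clf.classes_ = np.asarray(model_data["classes"])

print(clf.predict([[1.0, 0.0]]))
```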
Module for managing classification operations using logistic regression.

LogRegEmbedding provides methods for indexing, training, and predicting based on embeddings
for classification tasks.
This module is a proxy for embedder selection; it is not used for prediction.
def db_dir(self) -> str:
    """
    Get the directory for storing data.

    :return: Path to the database directory.
    """
    if self._db_dir is None:
        self._db_dir = str(get_db_dir())
    return self._db_dir
def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
    """
    Train the logistic regression model using the provided utterances and labels.

    :param utterances: List of text data to index.
    :param labels: List of corresponding labels for the utterances.
    """
    vector_index_client = VectorIndexClient(
        self.embedder_device,
        self.db_dir,
        embedder_batch_size=self.batch_size,
        embedder_max_length=self.max_length,
        embedder_use_cache=self.embedder_use_cache,
    )
    self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
Getting embeddings works differently now: it is enough to just initialize the embedder, and if caching is enabled and the embeddings were computed earlier, they will be picked up automatically.
For details, look at the LinearScorer code and/or ask Egor.
So does VectorIndexClient still need to be initialized after all? The KNN Scorer just refuses to work without it.
self.label_encoder.fit(labels)
encoded_labels = self.label_encoder.transform(labels)
self.classifier.fit(embeddings, encoded_labels)
It looks like the multilabel case is not handled; see how LinearScorer deals with it.
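One common way to cover the multilabel case (not necessarily what LinearScorer does) is to swap LabelEncoder for MultiLabelBinarizer and wrap the classifier in MultiOutputClassifier, one logistic regression per label column. A minimal sketch with toy data:

```python
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.preprocessing import MultiLabelBinarizer

# Toy embeddings and multilabel targets; shapes are illustrative only.
embeddings = [[0.9, 0.1], [0.1, 0.9], [0.8, 0.7], [0.2, 0.1]]
labels = [[0], [1], [0, 1], []]

# LabelEncoder does not apply to multilabel data; MultiLabelBinarizer
# turns each label set into a binary indicator row.
binarizer = MultiLabelBinarizer()
encoded = binarizer.fit_transform(labels)

# One LogisticRegression is fitted per label column.
clf = MultiOutputClassifier(LogisticRegression())
clf.fit(embeddings, encoded)

pred = clf.predict([[0.85, 0.75]])
print(binarizer.inverse_transform(pred))
```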
def score(
    self,
    context: Context,
    split: Literal["validation", "test"],
    metric_fn: RetrievalMetricFn,
) -> float:
This needs a ScoringMetricFn, not a RetrievalMetricFn.
predicted_encoded = self.classifier.predict(embeddings)
predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)

return metric_fn(labels, [predicted_labels])
And the brackets around predicted_labels won't be needed here once scoring metrics are used.
self.metadata = VectorDBMetadata(
    batch_size=self.batch_size,
    max_length=self.max_length,
    db_dir=str(self.db_dir),
)
It is worth using a single TypedDict of your own for storing the metadata; there is no need to split it into two, let alone take one of them from a neighboring module.
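A merged metadata container along those lines could be a single TypedDict; the field names below are illustrative guesses, not the project's actual schema:

```python
from typing import Any, TypedDict


class LogRegMetadata(TypedDict):
    # One combined container instead of VectorDBMetadata + ClassifierMetadata.
    # Field names here are hypothetical.
    batch_size: int
    max_length: int
    db_dir: str
    coef: list[list[float]]
    intercept: list[float]
    classes: list[Any]
    params: dict[str, Any]


metadata = LogRegMetadata(
    batch_size=32,
    max_length=128,
    db_dir="db",
    coef=[[0.5, -0.3]],
    intercept=[0.1],
    classes=[0, 1],
    params={"C": 1.0},
)
print(metadata["db_dir"])
```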
self.classifier_metadata = ClassifierMetadata(
    coef_=self.classifier.coef_.tolist(),
    intercept_=self.classifier.intercept_.tolist(),
    classes=self.label_encoder.classes_.tolist(),
    params=self.classifier.get_params(),
)
with (dump_dir / "classifier.json").open("w") as file:
    json.dump(self.classifier_metadata, file, indent=4)
An sklearn classifier is better saved with the joblib library.
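The joblib approach serializes the entire fitted estimator in one call, so there is no need to pull coef_ and intercept_ apart into JSON by hand. A minimal sketch (the dump path here is a throwaway temp file, not the project's layout):

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
y = np.array([1, 0, 0, 1])
clf = LogisticRegression().fit(X, y)

# joblib serializes the whole fitted estimator, numpy arrays included.
dump_path = Path(tempfile.mkdtemp()) / "classifier.joblib"
joblib.dump(clf, dump_path)
restored = joblib.load(dump_path)

print((restored.predict(X) == clf.predict(X)).all())
```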
def predict(self, utterances: list[str]) -> list[int | list[int]]:
    """
    Predict labels for a list of utterances.

    :param utterances: List of utterances for classification.
    :return: A tuple containing:
        - labels: List of predicted labels for each utterance.
        - scores: List of dummy confidence scores.
        - texts: List of the input utterances.
    """
    embeddings = self.vector_index.embedder.embed(utterances)
    predicted_encoded = self.classifier.predict(embeddings)
    predicted_labels = self.label_encoder.inverse_transform(predicted_encoded).tolist()
    predicted_probabilities = self.classifier.predict_proba(embeddings).tolist()

    labels = self.vector_index.get_all_labels()

    texts = [self.vector_index.texts[labels == label][: self.k] for label in predicted_labels]

    return predicted_labels, predicted_probabilities, texts
We discussed that predict is not needed.
The ABC module requires a predict function. Should I change the logic, or what is the best way to handle this?
Probably raise an error, or just pass.
Yes, I made it pass. Should I do the same in RetrieverEmbedder then?
Yes, that works.
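A sketch of what raising instead of a silent pass could look like; the Embedder base class below is a stand-in for the project's actual ABC. Raising NotImplementedError fails loudly at the call site, whereas pass silently returns None and breaks later in a harder-to-trace way.

```python
from abc import ABC, abstractmethod


class Embedder(ABC):
    # Stand-in for the project's base class; the real ABC differs.
    @abstractmethod
    def predict(self, utterances: list[str]) -> list[int]: ...


class LogRegEmbedding(Embedder):
    def predict(self, utterances: list[str]) -> list[int]:
        # This module is only a proxy for embedder selection, so predict
        # is explicitly unsupported rather than silently returning None.
        msg = "LogRegEmbedding is used only for embedder selection; predict is not supported."
        raise NotImplementedError(msg)


try:
    LogRegEmbedding().predict(["hello"])
except NotImplementedError as err:
    print(err)
```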
closes #44
Closes should go in the PR description so that everything closes automatically.
No description provided.