
Feat/logreg retrieval #81

Open · wants to merge 19 commits into dev

Conversation

Darinochka (Collaborator):

No description provided.

self.k = model_data["k"]
self.classifier.coef_ = [model_data["coef"]]
self.classifier.intercept_ = model_data["intercept"]
self.label_encoder = LabelEncoder()
Collaborator:

Maybe it would be better to do this via get_params/set_params?
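The reviewer's suggestion could look roughly like this: a minimal sketch, assuming plain-dict serialization, where `set_params` restores all hyperparameters in one call and only the fitted arrays (`coef_`, `intercept_`, `classes_`) are assigned by hand. The `dump_classifier`/`load_classifier` names are illustrative, not the project's actual API.

```python
# Sketch: round-tripping a fitted sklearn LogisticRegression through a plain dict.
import numpy as np
from sklearn.linear_model import LogisticRegression


def dump_classifier(clf: LogisticRegression) -> dict:
    """Serialize hyperparameters and fitted state into a JSON-friendly dict."""
    return {
        "params": clf.get_params(),
        "coef": clf.coef_.tolist(),
        "intercept": clf.intercept_.tolist(),
        "classes": clf.classes_.tolist(),
    }


def load_classifier(data: dict) -> LogisticRegression:
    """Rebuild the classifier; set_params restores hyperparameters in one call."""
    clf = LogisticRegression()
    clf.set_params(**data["params"])
    # Fitted attributes still have to be numpy arrays for predict() to work.
    clf.coef_ = np.asarray(data["coef"])
    clf.intercept_ = np.asarray(data["intercept"])
    clf.classes_ = np.asarray(data["classes"])
    return clf
```

This avoids hand-listing each hyperparameter while keeping the serialized form JSON-compatible.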

Comment on lines 37 to 40
Module for managing classification operations using logistic regression.

LogRegEmbedding provides methods for indexing, training, and predicting based on embeddings
for classification tasks.
Collaborator:

This is a module for proxy selection of the embedder.

It is not used for prediction.

Comment on lines 139 to 163
def db_dir(self) -> str:
"""
Get the directory for storing data.

:return: Path to the database directory.
"""
if self._db_dir is None:
self._db_dir = str(get_db_dir())
return self._db_dir

def fit(self, utterances: list[str], labels: list[LabelType]) -> None:
"""
Train the logistic regression model using the provided utterances and labels.

:param utterances: List of text data to index.
:param labels: List of corresponding labels for the utterances.
"""
vector_index_client = VectorIndexClient(
self.embedder_device,
self.db_dir,
embedder_batch_size=self.batch_size,
embedder_max_length=self.max_length,
embedder_use_cache=self.embedder_use_cache,
)
self.vector_index = vector_index_client.create_index(self.embedder_name, utterances, labels)
Collaborator:

Obtaining embeddings works differently now: it is enough to simply initialize the embedder, and if caching is enabled and the embeddings were computed earlier, they will be picked up.

For details, look at the code of LinearScorer and/or ask Egor.

Collaborator (Author):

So does VectorIndexClient also need to be initialized in the end? The KNN Scorer just won't work without it.

Comment on lines +166 to +168
self.label_encoder.fit(labels)
encoded_labels = self.label_encoder.transform(labels)
self.classifier.fit(embeddings, encoded_labels)
Collaborator:

It seems the multilabel case isn't handled; look at how LinearScorer does it.
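One possible shape for the fix, as a hedged sketch: branch on whether each sample carries a list of labels, and use `MultiLabelBinarizer` plus a one-vs-rest wrapper for the multilabel case (LinearScorer may handle this differently; the `fit_classifier` helper is hypothetical).

```python
# Sketch: supporting both single-label and multilabel targets during fit.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer


def fit_classifier(embeddings, labels):
    """Fit a single-label or multilabel classifier depending on label shape."""
    if labels and isinstance(labels[0], (list, tuple)):
        # Multilabel: binarize label sets and train one binary model per class.
        encoder = MultiLabelBinarizer()
        target = encoder.fit_transform(labels)
        classifier = OneVsRestClassifier(LogisticRegression())
    else:
        # Single-label: the LabelEncoder path already in this PR.
        encoder = LabelEncoder()
        target = encoder.fit_transform(labels)
        classifier = LogisticRegression()
    classifier.fit(np.asarray(embeddings), target)
    return classifier, encoder
```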

Comment on lines 170 to 175
def score(
self,
context: Context,
split: Literal["validation", "test"],
metric_fn: RetrievalMetricFn,
) -> float:
Collaborator:

This needs ScoringMetricFn, not RetrievalMetricFn.

predicted_encoded = self.classifier.predict(embeddings)
predicted_labels = self.label_encoder.inverse_transform(predicted_encoded)

return metric_fn(labels, [predicted_labels])
Collaborator:

And the brackets won't be needed here if scoring metrics are used.

Comment on lines 218 to 222
self.metadata = VectorDBMetadata(
batch_size=self.batch_size,
max_length=self.max_length,
db_dir=str(self.db_dir),
)
Collaborator:

You should use a single TypedDict of your own for storing the metadata; don't split it into two, and especially don't take one of them from a neighboring module.

Comment on lines 229 to 236
self.classifier_metadata = ClassifierMetadata(
coef_=self.classifier.coef_.tolist(),
intercept_=self.classifier.intercept_.tolist(),
classes=self.label_encoder.classes_.tolist(),
params=self.classifier.get_params(),
)
with (dump_dir / "classifier.json").open("w") as file:
json.dump(self.classifier_metadata, file, indent=4)
Collaborator:

An sklearn classifier is better saved with the joblib library.
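The joblib route replaces the JSON round-trip entirely, since the whole fitted estimator is pickled in one call. A minimal sketch; the `classifier.joblib` file name and helper names are illustrative.

```python
# Sketch: persisting a fitted sklearn estimator with joblib instead of JSON.
from pathlib import Path

import joblib
from sklearn.linear_model import LogisticRegression


def save_classifier(classifier: LogisticRegression, dump_dir: Path) -> Path:
    """Persist the fitted estimator, including its numpy arrays, losslessly."""
    path = dump_dir / "classifier.joblib"
    joblib.dump(classifier, path)
    return path


def restore_classifier(path: Path) -> LogisticRegression:
    """Load the estimator back; no manual coef_/intercept_ reassembly needed."""
    return joblib.load(path)
```

The trade-off: joblib files are Python-version-tied pickles rather than human-readable JSON.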

Comment on lines 268 to 287
def predict(self, utterances: list[str]) -> list[int | list[int]]:
"""
Predict labels for a list of utterances.

:param utterances: List of utterances for classification.
:return: A tuple containing:
- labels: List of predicted labels for each utterance.
- scores: List of dummy confidence scores.
- texts: List of the input utterances.
"""
embeddings = self.vector_index.embedder.embed(utterances)
predicted_encoded = self.classifier.predict(embeddings)
predicted_labels = self.label_encoder.inverse_transform(predicted_encoded).tolist()
predicted_probabilities = self.classifier.predict_proba(embeddings).tolist()

labels = self.vector_index.get_all_labels()

texts = [self.vector_index.texts[labels == label][: self.k] for label in predicted_labels]

return predicted_labels, predicted_probabilities, texts
Collaborator:

We discussed that predict is not needed.

Collaborator (Author):

The ABC module requires a predict function. Should I change the logic, or what is the best way to do this?

Collaborator:

Probably raise an error, or just pass.
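Of the two options, raising is usually the safer one: a silent `pass` returns None and fails much later. A sketch of satisfying the required abstract method; the class here is a stand-in, not the PR's actual inheritance chain.

```python
# Sketch: an intentionally unsupported predict on a class whose ABC requires it.
class LogRegEmbedding:  # stand-in; the real class inherits the project's ABC
    def predict(self, utterances: list[str]) -> list[int]:
        # Raising keeps misuse loud, unlike `pass`, which would return None.
        msg = "predict is not supported for LogRegEmbedding; use score instead"
        raise NotImplementedError(msg)
```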

Collaborator (Author):

Yes, I made it pass. So should RetrieverEmbedder be done the same way?

Collaborator:

Yes, that works.

@voorhs (Collaborator) commented Dec 22, 2024

closes #44

@Samoed (Collaborator) commented Dec 22, 2024

"Closes" should be written in the PR description so that everything gets closed automatically.

3 participants