Merge pull request #72 from AnFreTh/main

beta v0.1.3

AnFreTh authored Aug 7, 2024
2 parents 5c1ffe2 + c9053a3 commit 3e94e5d

Showing 6 changed files with 40 additions and 31 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,3 +1,4 @@
recursive-exclude notebooks *
recursive-include stream/preprocessed_datasets/*
recursive-include stream/pre_embedded_datasets/*
include stream/preprocessor/config/default_preprocessing_steps.json
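
The added `include` line makes the default preprocessing config part of the source distribution; MANIFEST.in governs what goes into the sdist, while the `package_data` entries in setup.py below cover the installed package. A quick way to confirm the file actually ships is to list a built sdist — a minimal sketch, assuming the archive was built into `dist/` under this name and version:

```python
# Sketch only: inspect a built sdist (assumed filename) and confirm the
# preprocessing config added above is packaged.
import tarfile

with tarfile.open("dist/stream_topic-0.1.3.tar.gz", "r:gz") as sdist:
    names = sdist.getnames()

print(any(name.endswith("default_preprocessing_steps.json") for name in names))
```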
10 changes: 5 additions & 5 deletions setup.py
@@ -47,22 +47,22 @@
install_requires=install_reqs,
# extras_require=extras_reqs,
license="MIT", # adapt based on your needs
packages=find_packages(
exclude=["examples", "examples.*", "tests", "tests.*"]),
packages=find_packages(exclude=["examples", "examples.*", "tests", "tests.*"]),
include_package_data=True,
# package_dir={"stream": "stream"},
package_data={
# Use '**' to include all files within subdirectories recursively
"stream_topic": [
"preprocessed_datasets/**/*",
"preprocessor/config/default_preprocessing_steps.json"
"pre_embedded_datasets/**/*",
"preprocessor/config/default_preprocessing_steps.json",
],
},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
project_urls={'Documentation': DOCS},
url=HOMEPAGE
project_urls={"Documentation": DOCS},
url=HOMEPAGE,
)
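
Together with `include_package_data=True`, the `package_data` globs above are what place the bundled datasets and the JSON config inside an installed wheel. A minimal sketch for checking that they are resolvable from an installed copy (assumes Python 3.9+ and that the import name is `stream_topic`):

```python
# Sketch: confirm the data files declared in package_data are reachable
# from the installed package (paths mirror the globs in setup.py).
from importlib import resources

pkg = resources.files("stream_topic")
config = pkg / "preprocessor" / "config" / "default_preprocessing_steps.json"

print("config bundled:", config.is_file())
print("datasets bundled:", (pkg / "preprocessed_datasets").is_dir())
```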
2 changes: 1 addition & 1 deletion stream_topic/__version__.py
@@ -1,4 +1,4 @@
"""Version information."""

# The following line *must* be the last in the module, exactly as formatted:
__version__ = "0.1.2"
__version__ = "0.1.3"
2 changes: 1 addition & 1 deletion stream_topic/models/CEDC.py
@@ -186,7 +186,7 @@ def _clustering(self):

def fit(
self,
dataset: TMDataset = None,
dataset: TMDataset,
n_topics: int = 20,
only_nouns: bool = False,
clean: bool = False,
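Dropping the `= None` default makes `dataset` a required argument, so a forgotten dataset now fails loudly at the call site instead of letting `None` reach preprocessing. An illustration only (constructor call and import path are assumptions):

```python
# Illustration: with the new signature, omitting the dataset raises
# immediately rather than failing later inside fit().
from stream_topic.models import CEDC  # import path assumed

model = CEDC()
model.fit(n_topics=20)
# TypeError: fit() missing 1 required positional argument: 'dataset'
```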
21 changes: 7 additions & 14 deletions stream_topic/models/DCTE.py
@@ -5,9 +5,8 @@
from datasets import Dataset
from loguru import logger
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel,TrainingArguments
from setfit import SetFitModel, TrainingArguments
from setfit import Trainer as SetfitTrainer
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

@@ -124,9 +123,7 @@ def _get_topic_representation(self, predict_df: pd.DataFrame, top_words: int):
)

one_hot_encoder = OneHotEncoder(sparse=False)
predictions_one_hot = one_hot_encoder.fit_transform(
predict_df[["predictions"]]
)
predictions_one_hot = one_hot_encoder.fit_transform(predict_df[["predictions"]])

beta = tfidf
theta = predictions_one_hot
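
One caveat about the surrounding context rather than this commit: `OneHotEncoder(sparse=False)` uses a keyword that scikit-learn deprecated in 1.2 and removed in 1.4 in favor of `sparse_output`. A hedged compatibility sketch that picks whichever keyword the installed version accepts:

```python
# Sketch (not part of this commit): support both the old "sparse" and the
# newer "sparse_output" keyword of OneHotEncoder.
import inspect
from sklearn.preprocessing import OneHotEncoder

params = inspect.signature(OneHotEncoder).parameters
kwarg = "sparse_output" if "sparse_output" in params else "sparse"
one_hot_encoder = OneHotEncoder(**{kwarg: False})
```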
@@ -215,9 +212,8 @@ def fit(

logger.info("--- Training completed successfully. ---")
self._status = TrainingStatus.SUCCEEDED

return self


def predict(self, dataset):
"""
@@ -242,9 +238,9 @@ def predict(self, dataset):

labels = self.model(predict_df["text"])
predict_df["predictions"] = labels

return labels

def get_topics(self, dataset, n_words=10):
"""
Retrieve the top words for each topic.
@@ -269,11 +265,8 @@ def get_topics(self, dataset, n_words=10):

labels = self.model(predict_df["text"])
predict_df["predictions"] = labels

topic_dict, beta, theta = self._get_topic_representation(predict_df, n_words)
if self._status != TrainingStatus.SUCCEEDED:
raise RuntimeError("Model has not been trained yet or failed.")
return [
[word for word, _ in topic_dict[key][:n_words]]
for key in topic_dict
]
return [[word for word, _ in topic_dict[key][:n_words]] for key in topic_dict]
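
With the flattened return expression, `get_topics` hands back a plain list of top-word lists, one per predicted topic label. A usage sketch (model and dataset construction elided, names assumed):

```python
# Usage sketch: one list of top words per predicted topic.
topics = model.get_topics(dataset, n_words=10)
for idx, words in enumerate(topics):
    print(f"Topic {idx}: {', '.join(words)}")
```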
35 changes: 25 additions & 10 deletions stream_topic/utils/dataset.py
@@ -1,16 +1,15 @@
import json
import os
import pickle
import re

import importlib.util
import gensim.downloader as api
import numpy as np
import pandas as pd
from loguru import logger
from sentence_transformers import SentenceTransformer

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.data import Dataset, random_split

from ..commons.load_steps import load_model_preprocessing_steps
from ..preprocessor import TextPreprocessor
@@ -201,9 +200,17 @@ def get_package_dataset_path(self, name):
str
Path to the dataset.
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
my_package_dir = os.path.dirname(script_dir)
dataset_path = os.path.join(my_package_dir, "preprocessed_datasets", name)
# Get the location of the installed package
package_name = "stream_topic"
spec = importlib.util.find_spec(package_name)
if spec is None:
raise ImportError(f"Cannot find the package '{package_name}'")

package_root_dir = os.path.dirname(spec.origin)

# Construct the full path to the dataset
dataset_path = os.path.join(package_root_dir, "preprocessed_datasets", name)

return dataset_path

def has_embeddings(self, embedding_model_name, path=None, file_name=None):
@@ -336,10 +343,18 @@ def get_package_embeddings_path(self, name):
str
Path to the embeddings.
"""
script_dir = os.path.dirname(os.path.abspath(__file__))
my_package_dir = os.path.dirname(script_dir)
dataset_path = os.path.join(my_package_dir, "pre_embedded_datasets", name)
return dataset_path
# Get the location of the installed package
package_name = "stream_topic"
spec = importlib.util.find_spec(package_name)
if spec is None:
raise ImportError(f"Cannot find the package '{package_name}'")

package_root_dir = os.path.dirname(spec.origin)

# Construct the full path to the dataset
embedding_path = os.path.join(package_root_dir, "pre_embedded_datasets", name)

return embedding_path

def create_load_save_dataset(
self,
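Both path helpers now locate the bundled data through the installed package's import spec rather than walking up from this module's `__file__`, so the lookup follows wherever `stream_topic` actually resolves at import time. A minimal standalone sketch of the same pattern (the dataset name is a placeholder):

```python
# Standalone sketch of the pattern used above: find the installed package
# root through its import spec and join the bundled data folder onto it.
import importlib.util
import os

spec = importlib.util.find_spec("stream_topic")
if spec is None or spec.origin is None:
    raise ImportError("Cannot find the package 'stream_topic'")

package_root_dir = os.path.dirname(spec.origin)  # directory containing __init__.py
dataset_path = os.path.join(package_root_dir, "preprocessed_datasets", "<dataset_name>")
```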
