Skip to content

Commit

Permalink
documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jwijffels committed Nov 5, 2021
1 parent 33de256 commit 6df7216
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 67 deletions.
67 changes: 40 additions & 27 deletions R/ETM.R
Original file line number Diff line number Diff line change
Expand Up @@ -482,22 +482,30 @@ split_train_test <- function(x, train_pct = 0.7){



#' @title Predict to which ETM topic a text belongs
#' @description Predict functionality for an \code{ETM} object
#' @title Predict functionality for an ETM object.
#' @description Predict to which ETM topic a text belongs or extract which words are emitted for each topic.
#' @param object an object of class \code{ETM}
#' @param type either 'topics' or 'terms'
#' @param newdata bag of words document term matrix in \code{dgCMatrix} format
#' @param batch_size integer with the size of the batch
#' @param normalize logical indicating to normalize the bag of words data
#' @param top_n integer with number of most relevant words for each topic to extract
#' @param type a character string with either 'topics' or 'terms' indicating to either predict to which
#' topic a document encoded as a set of bag of words belongs to or to extract the most emitted terms for each topic
#' @param newdata bag of words document term matrix in \code{dgCMatrix} format. Only used in case type = 'topics'.
#' @param batch_size integer with the size of the batch in order to do chunkwise predictions in chunks of \code{batch_size} rows. Defaults to the whole dataset provided in \code{newdata}.
#' Only used in case type = 'topics'.
#' @param normalize logical indicating to normalize the bag of words data. Defaults to \code{TRUE} similar as the default when building the \code{ETM} model.
#' Only used in case type = 'topics'.
#' @param top_n integer with the number of most relevant words for each topic to extract. Only used in case type = 'terms'.
#' @param ... not used
#' @seealso \code{\link{ETM}}
#' @return Returns for
#' \itemize{
#' \item{type 'topics': a matrix with topic probabilities of dimension nrow(newdata) x the number of topics}
#' \item{type 'terms': a list of data.frame's where each data.frame has columns term, beta and rank indicating the
#' top_n most emitted terms for that topic. List element 1 corresponds to the top terms emitted by topic 1, element 2 to topic 2 ...}
#' }
#' @export
#' @examples
#' \dontshow{if(require(torch) && torch::torch_is_installed())
#' \{
#' }
#'
#' library(torch)
#' library(topicmodels.etm)
#' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt")
Expand All @@ -513,7 +521,6 @@ split_train_test <- function(x, train_pct = 0.7){
#' dtm <- head(dtm, n = 5)
#' scores <- predict(model, newdata = dtm, type = "topics")
#' scores
#'
#' \dontshow{
#' \}
#' # End of main if statement running only if the torch is properly installed
Expand Down Expand Up @@ -564,18 +571,19 @@ predict.ETM <- function(object, newdata, type = c("topics", "terms"), batch_size


#' @title Get matrices out of an ETM object
#' @description Convenience functions to extract
#' @description Convenience function to extract
#' \itemize{
#' \item{embeddings of the cluster centers}
#' \item{embeddings of the topic centers}
#' \item{embeddings of the words used in the model}
#' \item{words emmitted by each topic (beta), which is the softmax-transformed inner product of word embedding and topic embeddings}
#' }
#' @param x an object of class \code{ETM}
#' @param type character string with the type of information to extract: either 'beta', 'embedding'. Defaults to 'embedding'.
#' @param which a character string with either 'words' or 'topics' to get the specific embedding. Defaults to 'topics'. Only used if type = 'embedding'.
#' @param type character string with the type of information to extract: either 'beta' (words emttied by each topic) or 'embedding' (embeddings of words or topic centers). Defaults to 'embedding'.
#' @param which a character string with either 'words' or 'topics' to get either the embeddings of the words used in the model or the embedding of the topic centers. Defaults to 'topics'. Only used if type = 'embedding'.
#' @param ... not used
#' @seealso \code{\link{ETM}}
#' @return a numeric matrix
#' @return a numeric matrix containing, depending on the value supplied in \code{type}
#' either the embeddings of the topic centers, the embeddings of the words or the words emitted by each topic
#' @export
#' @examples
#' \dontshow{if(require(torch) && torch::torch_is_installed())
Expand Down Expand Up @@ -623,20 +631,24 @@ as.matrix.ETM <- function(x, type = c("embedding", "beta"), which = c("topics",
#' @title Plot functionality for an ETM object
#' @description Convenience function allowing to plot
#' \itemize{
#' \item{the evolution of the loss on the training / test set}
#' \item{a model in 2D dimensional space using a umap projection.
#' The topic plot uses function \code{\link[textplot]{textplot_embedding_2d}} from the textplot R package.}
#' \item{the evolution of the loss on the training / test set in order to inspect training convergence}
#' \item{the \code{ETM} model in 2D dimensional space using a umap projection.
#' This plot uses function \code{\link[textplot]{textplot_embedding_2d}} from the textplot R package and
#' plots the top_n most emitted words of each topic and the topic centers in 2 dimensions}
#' }
#' @param x an object of class \code{ETM}
#' @param type character string with the type of plot, either 'loss' or 'topics'
#' @param which an integer vector of clusters to plot, used in case type = 'topics'. Defaults to all clusters.
#' @param type character string with the type of plot to generate: either 'loss' or 'topics'
#' @param which an integer vector of topics to plot, used in case type = 'topics'. Defaults to all topics. See the example below.
#' @param top_n passed on to \code{summary.ETM} in order to visualise the top_n most relevant words for each topic. Defaults to 4.
#' @param title passed on to textplot_embedding_2d, used in case type = 'topics'
#' @param subtitle passed on to textplot_embedding_2d, used in case type = 'topics'
#' @param encircle passed on to textplot_embedding_2d, used in case type = 'topics'
#' @param points passed on to textplot_embedding_2d, used in case type = 'topics'
#' @param ... arguments passed on to \code{\link{summary.ETM}}
#' @seealso \code{\link{ETM}}, \code{\link{summary.ETM}}, \code{\link[textplot]{textplot_embedding_2d}}
#' @return In case \code{type} is set to 'topics', maps the topic centers and most emitted words for each topic
#' to 2D using \code{\link{summary.ETM}} and returns a ggplot object by calling \code{\link[textplot]{textplot_embedding_2d}}. \cr
#' For type 'loss', makes a base graphics plot and returns invisibly nothing.
#' @export
#' @examples
#' \dontshow{if(require(torch) && torch::torch_is_installed())
Expand Down Expand Up @@ -664,17 +676,17 @@ as.matrix.ETM <- function(x, type = c("embedding", "beta"), which = c("topics",
#' library(ggalt)
#' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt")
#' model <- torch_load(path)
#' plot(model, type = "topics", top_n = 7, which = c(1, 2, 14, 16, 18, 19),
#' metric = "cosine", n_neighbors = 15, fast_sgd = FALSE, n_threads = 2, verbose = TRUE,
#' title = "ETM Topics example")
#'
#'
#' plt <- plot(model, type = "topics", top_n = 7, which = c(1, 2, 14, 16, 18, 19),
#' metric = "cosine", n_neighbors = 15,
#' fast_sgd = FALSE, n_threads = 2, verbose = TRUE,
#' title = "ETM Topics example")
#' plt
#' \dontshow{
#' \}
#' # End of main if statement running only if the torch is properly installed
#' }
plot.ETM <- function(x, type = c("loss", "topics"), which, top_n = 4,
title = "ETM clusters", subtitle = "",
title = "ETM topics", subtitle = "",
encircle = FALSE, points = FALSE, ...){
type <- match.arg(type)
if(type == "loss"){
Expand All @@ -692,6 +704,7 @@ plot.ETM <- function(x, type = c("loss", "topics"), which, top_n = 4,
par(mfrow = c(1, 2))
plot(combined$epoch, combined$loss, xlab = "Epoch", ylab = "loss", main = "Avg batch loss evolution\non 70% training set", col = "steelblue", type = "b", pch = 20, lty = 2)
plot(combined$epoch, combined$loss_test, xlab = "Epoch", ylab = "exp(loss)", main = "Avg batch loss evolution\non 30% test set", col = "purple", type = "b", pch = 20, lty = 2)
invisible()
}else{
requireNamespace("textplot")
manifolded <- summary(x, top_n = top_n, ...)
Expand All @@ -717,8 +730,8 @@ plot.ETM <- function(x, type = c("loss", "topics"), which, top_n = 4,
#' \item{center: a matrix with the embeddings of the topic centers}
#' \item{words: a matrix with the embeddings of the words}
#' \item{embed_2d: a data.frame which contains a lower dimensional presentation in 2D of the topics and the top_n words associated with
#' the topic, containing columns type, term, cluster, rank, beta, x, y, weight; where type is either words or centers, x/y contain the lower dimensional
#' positions in 2D of the word and weight is the emitted beta scaled to the highest beta within a cluster and the cluster center always gets weight 0.8}
#' the topic, containing columns type, term, cluster (the topic number), rank, beta, x, y, weight; where type is either 'words' or 'centers', x/y contain the lower dimensional
#' positions in 2D of the word and weight is the emitted beta scaled to the highest beta within a topic where the topic center always gets weight 0.8}
#' }
#' @export
#' @examples
Expand Down
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ model <- torch_load("example_etm.ckpt")
Example plot shown above was created using the following code

- This uses R package [textplot](https://github.com/bnosac/textplot) >= 0.2.0 which was updated on CRAN on 2021-08-18
- The summary function maps the learned embeddings of the words and cluster centers in 2D using [UMAP](https://github.com/jlmelville/uwot) and textplot_embedding_2d plots the selected clusters of interest in 2D
- The summary function maps the learned embeddings of the words and topic centers in 2D using [UMAP](https://github.com/jlmelville/uwot) and textplot_embedding_2d plots the selected topics of interest in 2D

```
library(textplot)
Expand All @@ -314,15 +314,15 @@ manifolded <- summary(model, type = "umap", n_components = 2, metric = "cosine",
space <- subset(manifolded$embed_2d, type %in% "centers")
textplot_embedding_2d(space)
space <- subset(manifolded$embed_2d, cluster %in% c(12, 14, 9, 7) & rank <= 7)
textplot_embedding_2d(space, title = "ETM clusters", subtitle = "embedded in 2D using UMAP",
textplot_embedding_2d(space, title = "ETM topics", subtitle = "embedded in 2D using UMAP",
encircle = FALSE, points = TRUE)
```

![](tools/example-visualisation-basic.png)

#### z. Or you can brew up your own code to plot things

- Put embeddings of words and cluster centers in 2D using UMAP
- Put embeddings of words and topic centers in 2D using UMAP

```
library(uwot)
Expand All @@ -335,7 +335,7 @@ centers <- umap_transform(X = centers, model = manifold)
words <- manifold$embedding
```

- Plot words in 2D, color by cluster and add cluster centers in 2D
- Plot words in 2D, color by topic and add topic centers in 2D
- This uses R package textplot >= 0.2.0 (https://github.com/bnosac/textplot) which was put on CRAN on 2021-08-18

```
Expand All @@ -346,7 +346,7 @@ df <- list(words = merge(x = terminology,
y = data.frame(x = words[, 1], y = words[, 2], term = rownames(embeddings)),
by = "term"),
centers = data.frame(x = centers[, 1], y = centers[, 2],
term = paste("Cluster-", seq_len(nrow(centers)), sep = ""),
term = paste("Topic-", seq_len(nrow(centers)), sep = ""),
cluster = seq_len(nrow(centers))))
df <- rbindlist(df, use.names = TRUE, fill = TRUE, idcol = "type")
df <- df[, weight := ifelse(is.na(beta), 0.8, beta / max(beta, na.rm = TRUE)), by = list(cluster)]
Expand All @@ -355,26 +355,26 @@ library(textplot)
library(ggrepel)
library(ggalt)
x <- subset(df, type %in% c("words", "centers") & cluster %in% c(1, 3, 4, 8))
textplot_embedding_2d(x, title = "ETM clusters", subtitle = "embedded in 2D using UMAP", encircle = FALSE, points = FALSE)
textplot_embedding_2d(x, title = "ETM clusters", subtitle = "embedded in 2D using UMAP", encircle = TRUE, points = TRUE)
textplot_embedding_2d(x, title = "ETM topics", subtitle = "embedded in 2D using UMAP", encircle = FALSE, points = FALSE)
textplot_embedding_2d(x, title = "ETM topics", subtitle = "embedded in 2D using UMAP", encircle = TRUE, points = TRUE)
```

- Or if you like writing down the full ggplot2 code

```
library(ggplot2)
library(ggrepel)
x$cluster <- factor(x$cluster)
plt <- ggplot(x,
aes(x = x, y = y, label = term, color = cluster, cex = weight, pch = factor(type, levels = c("centers", "words")))) +
x$topic <- factor(x$cluster)
plt <- ggplot(x,
aes(x = x, y = y, label = term, color = topic, cex = weight, pch = factor(type, levels = c("centers", "words")))) +
geom_text_repel(show.legend = FALSE) +
theme_void() +
labs(title = "ETM clusters", subtitle = "embedded in 2D using UMAP")
labs(title = "ETM topics", subtitle = "embedded in 2D using UMAP")
plt + geom_point(show.legend = FALSE)
## encircle if clusters are non-overlapping can provide nice visualisations
## encircle if topics are non-overlapping can provide nice visualisations
library(ggalt)
plt + geom_encircle(aes(group = cluster, fill = cluster), alpha = 0.4, show.legend = FALSE) + geom_point(show.legend = FALSE)
plt + geom_encircle(aes(group = topic, fill = topic), alpha = 0.4, show.legend = FALSE) + geom_point(show.legend = FALSE)
```

> More examples are provided in the help of the ETM function see `?ETM`
Expand Down
11 changes: 6 additions & 5 deletions man/as.matrix.ETM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 17 additions & 11 deletions man/plot.ETM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 18 additions & 9 deletions man/predict.ETM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6df7216

Please sign in to comment.