57
57
"""
    BMI25TransformerResult

Fit result of a `BM25Transformer`.

# Fields
- `vocab::Vector{String}`: the vocabulary handed to `get_result`; it is later passed
  to `build_dtm` when transforming a new corpus.
- `idf_vector::Vector{Float64}`: the IDF vector handed to `get_result`; it is passed to
  `build_bm25!` at transform time.
- `mean_words_in_docs::Float64`: mean of the per-document word counts of the
  document-term matrix seen by `get_result`. It is stored so that `build_bm25!`
  normalises document lengths against this value instead of recomputing a mean
  from whatever corpus is being transformed.

NOTE(review): the type name is spelled `BMI25...` (not `BM25...`); it is kept as is
because other methods dispatch on this exact name.
"""
struct BMI25TransformerResult
    vocab::Vector{String}
    idf_vector::Vector{Float64}
    mean_words_in_docs::Float64
end
61
62
62
"""
    get_result(::BM25Transformer, idf, vocab, doc_term_mat)

Package the outcome of fitting a `BM25Transformer` into a `BMI25TransformerResult`.

# Arguments
- `idf::Vector{F}`: IDF vector, `F <: AbstractFloat`.
- `vocab::Vector{String}`: vocabulary to store in the result.
- `doc_term_mat::SparseMatrixCSC`: document-term counts of the fitting corpus. It is
  reduced with `sum(...; dims=1)`, the same reduction `build_bm25!` uses for its
  `words_in_documents`.

# Returns
A `BMI25TransformerResult` holding `vocab`, `idf`, and the mean of the per-document
word counts.
"""
function get_result(::BM25Transformer,
                    idf::Vector{F},
                    vocab::Vector{String},
                    doc_term_mat::SparseMatrixCSC) where {F<:AbstractFloat}
    # Convert the counts to `F` before averaging, mirroring how `build_bm25!`
    # derives its `words_in_documents`.
    words_in_documents = F.(sum(doc_term_mat; dims=1))
    mean_words_in_docs = mean(words_in_documents)
    return BMI25TransformerResult(vocab, idf, mean_words_in_docs)
end
63
68
64
69
# BM25: Okapi Best Match 25
65
70
# Details at: https://en.wikipedia.org/wiki/Okapi_BM25
66
71
# derived from https://github.com/zgornel/StringAnalysis.jl/blob/master/src/stats.jl
67
72
function build_bm25! (doc_term_mat:: SparseMatrixCSC{T} ,
68
73
bm25:: SparseMatrixCSC{F} ,
69
- idf_vector:: Vector{F} ;
74
+ idf_vector:: Vector{F} ,
75
+ mean_words_in_docs:: Float64 ;
70
76
κ:: Int = 2 ,
71
77
β:: Float64 = 0.75 ) where {T <: Real , F <: AbstractFloat }
72
78
@assert size (doc_term_mat) == size (bm25)
@@ -82,7 +88,7 @@ function build_bm25!(doc_term_mat::SparseMatrixCSC{T},
82
88
83
89
# TF tells us what proportion of a document is defined by a term
84
90
words_in_documents = F .(sum (doc_term_mat; dims= 1 ))
85
- ln = words_in_documents ./ mean (words_in_documents)
91
+ ln = words_in_documents ./ mean_words_in_docs
86
92
oneval = one (F)
87
93
88
94
for i = 1 : n
100
106
function _transform (transformer:: BM25Transformer ,
101
107
result:: BMI25TransformerResult ,
102
108
v:: Corpus )
103
- dtm_matrix = build_dtm (v, result. vocab)
104
- bm25 = similar (dtm_matrix . dtm, eltype (result. idf_vector))
105
- build_bm25! (dtm_matrix . dtm, bm25, result. idf_vector; κ= transformer. κ, β= transformer. β)
109
+ doc_terms = build_dtm (v, result. vocab)
110
+ bm25 = similar (doc_terms . dtm, eltype (result. idf_vector))
111
+ build_bm25! (doc_terms . dtm, bm25, result. idf_vector, result . mean_words_in_docs ; κ= transformer. κ, β= transformer. β)
106
112
107
113
# here we return the `adjoint` of our sparse matrix to conform to
108
114
# the `n x p` dimensions throughout MLJ
113
119
"""
    MMI.fitted_params(::BM25Transformer, fitresult)

Return the fields of `fitresult` as a named tuple with the keys `vocab`,
`idf_vector`, and `mean_words_in_docs`.
"""
function MMI.fitted_params(::BM25Transformer, fitresult)
    return (
        vocab = fitresult.vocab,
        idf_vector = fitresult.idf_vector,
        mean_words_in_docs = fitresult.mean_words_in_docs,
    )
end
118
125
119
126
0 commit comments