Python - sklearn Latent Dirichlet Allocation transform vs fit_transform

I use the sklearn NMF and LDA submodules to analyze unlabeled text. I read the documentation, but I'm not sure whether the transform functions in these modules (NMF and LDA) do the same thing as the posterior function for R topic models (see Predicting LDA topics for new data). Basically, I'm looking for a function that lets me predict topics for a test set using a model trained on a training set. First I predicted topics across the entire dataset. Then I split the data into training and test sets, trained a model on the training set, and transformed the test set with that model. Although I expected the results to differ somewhat, comparing the topics from the two runs did not convince me that transform does the same thing as the R package's function. I would appreciate your answer.

Thank you


Calling transform on a LatentDirichletAllocation model can return an unnormalized document-topic distribution (whether it does depends on the scikit-learn version). To get proper probabilities, you can simply normalize each row of the result. Here is an example:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.datasets import fetch_20newsgroups
import numpy as np

# grab a sample data set
dataset = fetch_20newsgroups(shuffle=True, remove=('headers', 'footers', 'quotes'))
train, test = dataset.data[:100], dataset.data[100:200]

# vectorize the features (LDA models raw term counts, so CountVectorizer
# fits better than TfidfVectorizer here)
tf_vectorizer = CountVectorizer(max_features=25)
X_train = tf_vectorizer.fit_transform(train)

# train the model (older scikit-learn called this parameter n_topics)
lda = LatentDirichletAllocation(n_components=5, random_state=0)
lda.fit(X_train)

# predict topics for test data, reusing the vectorizer fitted on the training set
# doc-topic distribution (may be unnormalized, depending on the scikit-learn version)
X_test = tf_vectorizer.transform(test)
doc_topic_dist_unnormalized = lda.transform(X_test)

# normalize each row to sum to 1 (only needed if you want to work with the probabilities)
doc_topic_dist = doc_topic_dist_unnormalized / doc_topic_dist_unnormalized.sum(axis=1, keepdims=True)
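To check that transform only infers topic weights for new documents with the already fitted model (similar in spirit to the R posterior call mentioned in the question) and does not refit anything, here is a small offline sketch. It assumes a recent scikit-learn (n_components) and uses random Poisson counts as a stand-in for vectorized text, so the topics themselves are meaningless; the point is that components_ stays untouched by transform.

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

# Assumption: random Poisson counts stand in for a real document-term matrix
# (40 training docs, 10 test docs, 30 terms)
rng = np.random.default_rng(0)
X_train = rng.poisson(1.0, size=(40, 30))
X_test = rng.poisson(1.0, size=(10, 30))

lda = LatentDirichletAllocation(n_components=5, random_state=0)
lda.fit(X_train)

# Snapshot of the fitted topic-word weights
components_before = lda.components_.copy()

# Infer topic weights for the unseen documents using the fitted model
doc_topic = lda.transform(X_test)
# Row-normalize (harmless if this scikit-learn version already normalizes)
doc_topic = doc_topic / doc_topic.sum(axis=1, keepdims=True)

# transform does not update the model: topic-word weights are unchanged
unchanged = np.array_equal(components_before, lda.components_)
print(doc_topic.shape, unchanged)  # (10, 5) True
```

Because the model is frozen after fit, each test document gets its own topic distribution without influencing the topics learned from the training set.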

To find a top ranking topic, you can do something like:

doc_topic_dist.argmax(axis=1)
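For illustration, here is the same call on a small made-up matrix (hypothetical values, not real model output), plus an argsort variant if you want more than one topic per document:

```python
import numpy as np

# Hypothetical normalized doc-topic matrix: 3 documents x 4 topics
doc_topic_dist = np.array([[0.10, 0.60, 0.20, 0.10],
                           [0.40, 0.05, 0.05, 0.50],
                           [0.25, 0.15, 0.40, 0.20]])

# Index of the most probable topic for each document
top_topic = doc_topic_dist.argmax(axis=1)
print(top_topic)  # [1 3 2]

# Indices of the two most probable topics per document, best first
top_two = np.argsort(doc_topic_dist, axis=1)[:, ::-1][:, :2]
print(top_two.tolist())  # [[1, 2], [3, 0], [2, 0]]
```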

Source: https://habr.com/ru/post/1660719/

