

Entity Extraction with Scikit-learn Classifiers

What is entity extraction?

Entity extraction is the process of figuring out which fields a query should target, as opposed to always hitting all fields. For example, when the user types Apple iPhone, how can we tell that the intent was to run company:Apple AND product:iPhone?
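Once each entity has a label, building that fielded query is the easy part. Here is a minimal sketch. It assumes a classifier returns labels such as mfr and model (the labels used later in this post) and that they map to the company and product fields from the example. The helper name to_fielded_query and the label-to-field mapping are hypothetical.

```python
def to_fielded_query(tagged_entities, field_for_label):
    """Turn (entity, label) pairs into a Lucene-style fielded boolean query.

    field_for_label maps a classifier label to an index field name; both the
    labels and the field names used below are hypothetical examples.
    """
    clauses = []
    for entity, label in tagged_entities:
        # quote multi-word entities so they are matched as a phrase
        value = f'"{entity}"' if " " in entity else entity
        clauses.append(f"{field_for_label[label]}:{value}")
    return " AND ".join(clauses)


query = to_fielded_query(
    [("Apple", "mfr"), ("iPhone", "model")],
    {"mfr": "company", "model": "product"},
)
print(query)  # company:Apple AND product:iPhone
```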

Is entity extraction a classification problem?

Typically, when you think about entity extraction, you think about context: in Nokia 3310 is an old phone, words like is or an are strong indicators that what comes before them is a subject.

E-commerce queries are a special case: we often have little context. In our “Entity Extraction for Product Searches” presentation at Activate, we argued that if all you have is Nokia 3310, figuring out that Nokia is a manufacturer and 3310 is a model is a classification problem.

In this post, we’ll explore one of the approaches to solve this classification problem: training and using Scikit-learn classification models.

What’s Scikit-learn and how can I get it?

Scikit-learn is a popular machine learning library. It’s written in Python, and the PyPI package is called scikit-learn (the old sklearn package name is deprecated), so to get it, you can just:

pip install scikit-learn
pip install numpy

NumPy is already a Scikit-learn dependency, but we list it explicitly because our code imports it directly to build the training set as a NumPy array.

Feature selection

Before implementing anything, we need to figure out which features are relevant for classification. Feature selection is a continuous process, but we need something to begin with.

In the Activate example, we used three features: term frequency, number of digits and number of spaces. We assume that, typically, manufacturer names will occur more often in our index than model numbers, which are pretty unique. We also expect more digits in model numbers and more spaces in manufacturer names.
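To make the three features concrete, here is a minimal sketch that computes them for a single entity. Unlike the file-based function later in this post, it assumes the frequencies come from a hypothetical in-memory Counter of indexed values; in a real setup you would get them from your search engine. The function name entity_features is hypothetical.

```python
from collections import Counter


def entity_features(entity, index_frequencies):
    """Return [frequency, digits, spaces] for one entity string.

    index_frequencies is a hypothetical Counter of how often each value
    occurs in the index.
    """
    frequency = index_frequencies.get(entity, 0)
    digits = sum(c.isdigit() for c in entity)
    spaces = sum(c.isspace() for c in entity)
    return [frequency, digits, spaces]


# Toy "index": manufacturer names repeat across products, model numbers rarely do
indexed_values = ["nokia", "nokia", "nokia", "3310", "samsung", "samsung", "galaxy s10"]
frequencies = Counter(indexed_values)

print(entity_features("nokia", frequencies))  # [3, 0, 0]
print(entity_features("3310", frequencies))   # [1, 4, 0]
```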

The fundamental question is what would help distinguish one entity from another: in this case, a manufacturer from a model number. You can get creative with features: does the entity match a dictionary of manufacturers or models? How long is the query, and in which position(s) is our entity located? Position matters because there are common constructs in E-commerce, such as manufacturer+model (Nokia 3310) or model+generation (iPhone 3GS, if we stick to old school).
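The dictionary and position ideas can be expressed as extra numeric features. Here is a minimal sketch, assuming you have hypothetical sets of known manufacturers and models. The resulting list could be concatenated with the [frequency, digits, spaces] array before training. The function name context_features and the feature order are illustrative choices, not a fixed recipe.

```python
def context_features(query, entity, manufacturers, models):
    """Hypothetical extra features for an entity found inside a query.

    Returns [in_mfr_dict, in_model_dict, query_length, position], where
    position is the index of the entity's first word in the query (-1 if absent).
    """
    words = query.lower().split()
    entity_words = entity.lower().split()
    position = -1
    # look for the entity as a contiguous run of words in the query
    for i in range(len(words) - len(entity_words) + 1):
        if words[i:i + len(entity_words)] == entity_words:
            position = i
            break
    return [
        int(entity.lower() in manufacturers),
        int(entity.lower() in models),
        len(words),
        position,
    ]


known_mfrs = {"nokia", "apple"}
known_models = {"3310", "iphone"}
print(context_features("Nokia 3310", "Nokia", known_mfrs, known_models))  # [1, 0, 2, 0]
print(context_features("Nokia 3310", "3310", known_mfrs, known_models))   # [0, 1, 2, 1]
```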

Training and test sets

Data cleanup

When it comes to training and testing a model, the old “garbage in, garbage out” saying applies. You’ll want to curate your data as you see fit: lowercasing and stemming are useful in many entity extraction setups, just as they are for regular search 🙂
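Here is a minimal cleanup sketch that only lowercases and collapses whitespace. Stemming is left out because it needs a stemmer (for example NLTK’s PorterStemmer), and whether it helps depends on your catalog. The function name normalize is hypothetical. Whatever you choose, apply the same normalization to training data, test data and incoming queries, otherwise the feature values won’t line up.

```python
import re


def normalize(text):
    """Lowercase, collapse runs of whitespace into one space, and trim the ends."""
    return re.sub(r"\s+", " ", text).strip().lower()


print(normalize("  Apple   iPhone\n"))  # apple iphone
```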

When testing or applying the model, you’ll notice that some “entities” span multiple words. You could take word n-grams to fix this problem. For example, in Apple Mac Book, you’d take apple, mac, book, apple mac and mac book, and expect to get apple as a manufacturer, and mac and mac book as models. From these, you can take the larger gram (mac book) or both (mac + mac book, but rank “mac book” higher), depending on how you’d like to balance precision and recall.
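Generating the candidate n-grams and keeping only the larger gram can be sketched as follows. The helper names word_ngrams and keep_longest are hypothetical. keep_longest implements the precision-oriented option; for the recall-oriented one, you would keep all grams and rank longer ones higher instead.

```python
def word_ngrams(query, max_n=2):
    """Return all word n-grams of the query, from unigrams up to max_n-grams."""
    words = query.lower().split()
    grams = []
    for n in range(1, max_n + 1):
        for i in range(len(words) - n + 1):
            grams.append(" ".join(words[i:i + n]))
    return grams


def keep_longest(grams):
    """Drop any gram that appears, as whole words, inside a longer gram in the list."""
    return [g for g in grams
            if not any(g != other and f" {g} " in f" {other} " for other in grams)]


print(word_ngrams("Apple Mac Book"))
# ['apple', 'mac', 'book', 'apple mac', 'mac book']

# suppose the classifier labeled both "mac" and "mac book" as models
print(keep_longest(["mac", "mac book"]))  # ['mac book']
```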

Parsing entities into feature arrays

When training a model, you don’t feed Scikit-learn the actual words, but the features of those words. You’ll need code that, given the queries (or entities), can generate feature arrays. In our example, for Nokia, you’ll have 0 digits, 0 spaces and its frequency in your index.

In our sample code, we read data from a file. We assume each line contains an entity and we also use the file to judge frequencies: if we encounter an entity N times, we’ll get a frequency of N. In the end, we return a dictionary, where the entity is the key, and the value is the feature array for that entity.

def read_into_feature_dict(file):
    with open(file) as le_file:
        le_dict = {}
        for line in le_file:
            line = line.strip("\n")
            if not line:
                # skip blank lines so they don't turn into an "entity"
                continue
            if line not in le_dict:
                # other features besides frequency
                digits = sum(c.isdigit() for c in line)
                spaces = sum(c.isspace() for c in line)
                # initialize an array of [frequency, digits, spaces]. Frequency is initially 1
                le_dict[line] = [1, digits, spaces]
            else:
                # increment frequency if we met this before
                le_dict[line][0] += 1
        return le_dict

Training a model

To train the model, we’ll need only the list of feature arrays, without the keys. This list of feature arrays is our training set (X), but we’ll also need labels for each entity (y). In our case, labels are manufacturers or models:

import numpy as np

# we have a file with manufacturers and one with models. Read them into dictionaries
mfr_feature_dict = read_into_feature_dict("mfrs")
model_feature_dict = read_into_feature_dict("models")

# from the dictionaries, we get only the feature arrays and add them to one list
training = []
training.extend(mfr_feature_dict.values())
training.extend(model_feature_dict.values())

# make the list a NumPy array (Scikit-learn also accepts plain lists of lists)
X = np.array(training)

# add training labels. We know that we first added manufacturers, then models
y = ["mfr"] * len(mfr_feature_dict) + ["model"] * len(model_feature_dict)
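Relying on insertion order to line up X and y works, but it’s easy to break if the code changes later. Here is a minimal alternative sketch that builds both in one pass, so each feature array stays paired with its label. The helper name build_training_set is hypothetical, and the feature values below are made up.

```python
import numpy as np


def build_training_set(feature_dicts_by_label):
    """Build X and y together so every feature array keeps its label.

    feature_dicts_by_label maps a label (e.g. "mfr") to a dict of
    entity -> feature array, such as the ones read_into_feature_dict returns.
    """
    features, labels = [], []
    for label, feature_dict in feature_dicts_by_label.items():
        for feature_array in feature_dict.values():
            features.append(feature_array)
            labels.append(label)
    return np.array(features), np.array(labels)


X, y = build_training_set({
    "mfr": {"nokia": [3, 0, 0], "hewlett packard": [2, 0, 1]},
    "model": {"3310": [1, 4, 0]},
})
print(X.shape, y.tolist())  # (3, 3) ['mfr', 'mfr', 'model']
```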

At this point, we can select a model and train it. Scikit-learn comes with a variety of classifiers out of the box, from simple linear Support Vector Machines, like the one we’re using in this example, to decision trees and perceptrons (the same sort of algorithms you saw in our OpenNLP tutorial). You’d use them in a similar way, though their parameters differ, of course.
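Because these classifiers share the same fit()/predict() interface, trying several is mostly a one-line change. Here is a sketch on made-up [frequency, digits, spaces] arrays, not real data. The perceptron is wrapped in a StandardScaler pipeline because perceptrons are sensitive to feature scale, and frequencies here are in the hundreds while digit counts are single-digit. The hyperparameters shown are illustrative.

```python
import numpy as np
from sklearn import svm
from sklearn.linear_model import Perceptron
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Made-up [frequency, digits, spaces] arrays, for illustration only
X = np.array([[120, 0, 0], [95, 0, 1], [80, 0, 0],   # manufacturers
              [2, 4, 0], [1, 3, 0], [3, 2, 1]])      # models
y = np.array(["mfr", "mfr", "mfr", "model", "model", "model"])

classifiers = {
    "linear SVC": svm.SVC(kernel="linear", C=1.0),
    "decision tree": DecisionTreeClassifier(max_depth=3, random_state=0),
    # perceptrons are sensitive to feature scale, so standardize first
    "perceptron": make_pipeline(StandardScaler(), Perceptron(random_state=0)),
}

new_entities = np.array([[100, 0, 0], [1, 4, 0]])
predictions = {}
for name, clf in classifiers.items():
    clf.fit(X, y)  # every classifier exposes the same fit/predict interface
    predictions[name] = clf.predict(new_entities)
    print(name, predictions[name])
```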

With our training X and y ready and the algorithm selected, we can train the model. For a linear SVC, the code can be:

from sklearn import svm

# select the algorithm. Here, linear SVC
clf = svm.SVC(kernel='linear', C=1.0)
# train it
clf.fit(X, y)

Here, C is the penalty parameter for the error term. The intuition is that, with a higher C, your model will fit your training set more closely, but it may also overfit. There are other SVC parameters as well, such as max_iter, which caps the number of solver iterations.
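Because a higher C can look better on the training set while generalizing worse, one common way to pick it is cross-validation. Here is a sketch using Scikit-learn’s GridSearchCV on made-up feature arrays. The candidate C values and cv=3 are illustrative; cv=3 is used only because the toy set has three samples per label.

```python
import numpy as np
from sklearn import svm
from sklearn.model_selection import GridSearchCV

# Made-up [frequency, digits, spaces] arrays, for illustration only
X = np.array([[120, 0, 0], [95, 0, 1], [80, 0, 0],   # manufacturers
              [2, 4, 0], [1, 3, 0], [3, 2, 1]])      # models
y = np.array(["mfr", "mfr", "mfr", "model", "model", "model"])

# Try several values of C and keep the one with the best cross-validated accuracy.
# With real data you'd use more samples and usually more folds.
search = GridSearchCV(svm.SVC(kernel="linear"), {"C": [0.01, 0.1, 1.0, 10.0]}, cv=3)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

After fitting, search.best_estimator_ is an SVC refit on all the data with the chosen C, so it can be used for predictions like any other classifier.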

Using the model to predict entities

At this point, we can use our model for entity extraction, or at least test it. To do that, we build a test X from some test samples and use our classifier’s predict() function to get the suggested entities:

def test_from_file(test_file, clf):
    # same function that we used for the training set: read manufacturers/models from a file,
    # then turn them into a dictionary of entities to feature arrays
    test_dict = read_into_feature_dict(test_file)

    # collect the feature arrays into our test X, keeping the entity order
    entities = list(test_dict.keys())
    test_X = [test_dict[entity] for entity in entities]

    # use our trained classifier to predict a label (mfr or model) for each entity
    for entity, label in zip(entities, clf.predict(test_X)):
        print(entity, "->", label)
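Printing predictions is fine for eyeballing results. When you also know the correct label of each test entity, you can measure quality. Here is a sketch using Scikit-learn’s accuracy_score and classification_report. The training and test arrays are made up, and the test entities are hypothetical.

```python
import numpy as np
from sklearn import svm
from sklearn.metrics import accuracy_score, classification_report

# Made-up training data: [frequency, digits, spaces] per entity
X_train = np.array([[120, 0, 0], [95, 0, 1], [80, 0, 0],
                    [2, 4, 0], [1, 3, 0], [3, 2, 1]])
y_train = ["mfr", "mfr", "mfr", "model", "model", "model"]

# Hypothetical held-out entities whose correct labels we know
test_entities = ["nokia", "3310", "samsung", "s10"]
X_test = np.array([[110, 0, 0], [1, 4, 0], [90, 0, 0], [2, 2, 0]])
y_test = ["mfr", "model", "mfr", "model"]

clf = svm.SVC(kernel="linear", C=1.0).fit(X_train, y_train)
predicted = clf.predict(X_test)

for entity, label in zip(test_entities, predicted):
    print(entity, "->", label)
print("accuracy:", accuracy_score(y_test, predicted))
# per-label precision, recall and F1
print(classification_report(y_test, predicted))
```

Precision and recall per label are usually more telling than overall accuracy, especially if one label (say, models) is much more frequent than the other.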

Conclusions and next steps

With well-selected features, classification is a good way to extract entities from E-commerce queries. We showed an example here with Scikit-learn, but of course there are other good options. spaCy is one of them, and we’ll publish another how-to here soon!
