update to address some API deprecations
[gae-samples.git] / search / product_search_python / docs.py
blobd231320c546cb8a59866d2d912b188d7b2469235
1 #!/usr/bin/env python
3 # Copyright 2012 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 """ Contains 'helper' classes for managing search.Documents.
18 BaseDocumentManager provides some common utilities, and the Product subclass
19 adds some Product-document-specific helper methods.
20 """
22 import collections
23 import copy
24 import datetime
25 import logging
26 import re
27 import string
28 import urllib
30 import categories
31 import config
32 import errors
33 import models
35 from google.appengine.api import search
36 from google.appengine.ext import ndb
class BaseDocumentManager(object):
    """Abstract class. Provides helper methods to manage search.Documents.

    An instance wraps a single document and keeps a name -> [fields] map for
    quick field lookup.  The classmethods operate on the search index named by
    the subclass's ``_INDEX_NAME`` attribute.
    """

    # Name of the search index; concrete subclasses override this.
    _INDEX_NAME = None
    # Characters permitted in a doc id: printable ASCII minus whitespace.
    _VISIBLE_PRINTABLE_ASCII = frozenset(
        set(string.printable) - set(string.whitespace))

    def __init__(self, doc):
        """Wrap `doc` and build a dict of its fields keyed by field name.

        A document may hold several fields with the same name, so each dict
        value is a list of fields in document order.
        """
        self.doc = doc
        self.doc_map = {}
        for field in doc.fields:
            self.doc_map.setdefault(field.name, []).append(field)

    def getFirstFieldVal(self, fname):
        """Get the value of the (first) document field with the given name.

        Returns None if the document has no field of that name.
        """
        fieldlist = self.doc_map.get(fname)
        if not fieldlist:
            return None
        return fieldlist[0].value

    def setFirstField(self, new_field):
        """Replace the (first) document field whose name matches `new_field`.

        Returns True if a field was replaced, False if the document has no
        field with that name (in which case nothing is changed).
        """
        for i, field in enumerate(self.doc.fields):
            if field.name == new_field.name:
                self.doc.fields[i] = new_field
                # Keep the lookup map in sync with the document; otherwise
                # getFirstFieldVal would keep returning the replaced field.
                cached = self.doc_map.get(new_field.name)
                if cached:
                    cached[0] = new_field
                else:
                    self.doc_map[new_field.name] = [new_field]
                return True
        return False

    @classmethod
    def isValidDocId(cls, doc_id):
        """Checks if the given id is a visible printable ASCII string not
        starting with '!'. Whitespace characters are excluded.
        """
        if any(char not in cls._VISIBLE_PRINTABLE_ASCII for char in doc_id):
            return False
        return not doc_id.startswith('!')

    @classmethod
    def getIndex(cls):
        """Return the search.Index named by this class's _INDEX_NAME."""
        return search.Index(name=cls._INDEX_NAME)

    @classmethod
    def deleteAllInIndex(cls):
        """Delete all the docs in the given index.

        Search errors are logged rather than propagated.
        """
        docindex = cls.getIndex()

        try:
            while True:
                # until no more documents, get a list of documents,
                # constraining the returned objects to contain only the doc
                # ids, extract the doc ids, and delete the docs.
                document_ids = [
                    document.doc_id
                    for document in docindex.get_range(ids_only=True)]
                if not document_ids:
                    break
                docindex.delete(document_ids)
        except search.Error:
            logging.exception("Error removing documents:")

    @classmethod
    def getDoc(cls, doc_id):
        """Return the document with the given doc id, or None.

        One way to do this is via the get_range method, as shown here.  If
        the doc id is not in the index, the first doc in the index will be
        returned instead, so we need to check for that case.
        """
        if not doc_id:
            return None
        try:
            index = cls.getIndex()
            response = index.get_range(
                start_id=doc_id, limit=1, include_start_object=True)
            if response.results and response.results[0].doc_id == doc_id:
                return response.results[0]
            return None
        except search.InvalidRequest:  # catches ill-formed doc ids
            return None

    @classmethod
    def removeDocById(cls, doc_id):
        """Remove the doc with the given doc id; search errors are logged."""
        try:
            cls.getIndex().delete(doc_id)
        except search.Error:
            logging.exception("Error removing doc id %s.", doc_id)

    @classmethod
    def add(cls, documents):
        """Wrapper for the search index put method; specifies the index name.

        Returns the result of Index.put().  If a search.Error occurs it is
        logged and None is returned, so callers must check for None.
        """
        try:
            return cls.getIndex().put(documents)
        except search.Error:
            logging.exception("Error adding documents.")
            return None
class Store(BaseDocumentManager):
    """Document manager for the store index (config.STORE_INDEX_NAME)."""

    _INDEX_NAME = config.STORE_INDEX_NAME

    # Field names used in store documents.
    STORE_NAME = 'store_name'
    STORE_LOCATION = 'store_location'
    STORE_ADDRESS = 'store_address'
class Product(BaseDocumentManager):
    """Provides helper methods to manage Product documents.

    All Product documents built using these methods will include a core set
    of fields (see the _buildCoreProductFields method).  We use the given
    product id (the Product entity key) as the doc_id.  This is not required
    for the entity/document design-- each explicitly point to each other,
    allowing their ids to be decoupled-- but using the product id as the doc
    id allows a document to be reindexed given its product info, without
    having to fetch the existing document.
    """

    _INDEX_NAME = config.PRODUCT_INDEX_NAME

    # 'core' product document field names
    PID = 'pid'
    DESCRIPTION = 'description'
    CATEGORY = 'category'
    PRODUCT_NAME = 'name'
    PRICE = 'price'
    AVG_RATING = 'ar'  # average rating
    UPDATED = 'modified'

    # Each entry is [sort keyword, menu label, SortExpression].
    _SORT_OPTIONS = [
        [AVG_RATING, 'average rating', search.SortExpression(
            expression=AVG_RATING,
            direction=search.SortExpression.DESCENDING, default_value=0)],
        [PRICE, 'price', search.SortExpression(
            # other examples:
            # expression='max(price, 14.99)'
            # If you access _score in your sort expressions,
            # your SortOptions should include a scorer.
            # e.g. search.SortOptions(match_scorer=search.MatchScorer(),...)
            # Then, you can access the score to build expressions like:
            # expression='price * _score'
            expression=PRICE,
            direction=search.SortExpression.ASCENDING, default_value=9999)],
        [UPDATED, 'modified', search.SortExpression(
            expression=UPDATED,
            direction=search.SortExpression.DESCENDING, default_value=1)],
        [CATEGORY, 'category', search.SortExpression(
            expression=CATEGORY,
            direction=search.SortExpression.ASCENDING, default_value='')],
        [PRODUCT_NAME, 'product name', search.SortExpression(
            expression=PRODUCT_NAME,
            direction=search.SortExpression.ASCENDING, default_value='zzz')],
    ]

    # Lazily-built caches derived from _SORT_OPTIONS.
    _SORT_MENU = None
    _SORT_DICT = None

    @classmethod
    def deleteAllInProductIndex(cls):
        """Delete every document in the product index."""
        cls.deleteAllInIndex()

    @classmethod
    def getSortMenu(cls):
        """Return the list of (keyword, label) sort menu entries."""
        if not cls._SORT_MENU:
            cls._buildSortMenu()
        return cls._SORT_MENU

    @classmethod
    def getSortDict(cls):
        """Return the dict mapping sort keywords to SortExpressions."""
        if not cls._SORT_DICT:
            cls._buildSortDict()
        return cls._SORT_DICT

    @classmethod
    def _buildSortMenu(cls):
        """Build the default set of sort options used for Product search.

        Of these options, all but 'relevance' reference core fields that
        all Products will have.
        """
        res = [(elt[0], elt[1]) for elt in cls._SORT_OPTIONS]
        cls._SORT_MENU = [('relevance', 'relevance')] + res

    @classmethod
    def _buildSortDict(cls):
        """Build a dict that maps sort option keywords to their corresponding
        SortExpressions."""
        cls._SORT_DICT = dict((elt[0], elt[2]) for elt in cls._SORT_OPTIONS)

    @classmethod
    def getDocFromPid(cls, pid):
        """Given a pid, get its doc. We're using the pid as the doc id, so we
        can do this via a direct fetch."""
        return cls.getDoc(pid)

    @classmethod
    def removeProductDocByPid(cls, pid):
        """Given a doc's pid, remove the doc matching it from the product
        index."""
        cls.removeDocById(pid)

    @classmethod
    def updateRatingInDoc(cls, doc_id, avg_rating):
        """Set the average rating on the doc with the given id and return it.

        The returned document is modified in memory only; the caller is
        responsible for reindexing it.

        Raises:
          errors.OperationFailedError: if no doc with that id is found.
        """
        # get the associated doc from the doc id in the product entity
        doc = cls.getDoc(doc_id)
        if not doc:
            raise errors.OperationFailedError(
                'Could not retrieve doc associated with id %s' % (doc_id,))
        cls(doc).setAvgRating(avg_rating)
        # The use of the same id will cause the existing doc to be reindexed.
        return doc

    @classmethod
    def updateRatingsInfo(cls, doc_id, avg_rating):
        """Given a doc id and an average rating, update the associated
        document with that rating and reindex it.

        Returns the result of `add` (None if indexing failed).
        """
        ndoc = cls.updateRatingInDoc(doc_id, avg_rating)
        # reindex the returned updated doc
        return cls.add(ndoc)

    # 'accessor' convenience methods

    def getPID(self):
        """Get the value of the 'pid' field of a Product doc."""
        return self.getFirstFieldVal(self.PID)

    def getName(self):
        """Get the value of the 'name' field of a Product doc."""
        return self.getFirstFieldVal(self.PRODUCT_NAME)

    def getDescription(self):
        """Get the value of the 'description' field of a Product doc."""
        return self.getFirstFieldVal(self.DESCRIPTION)

    def getCategory(self):
        """Get the value of the 'category' field of a Product doc."""
        return self.getFirstFieldVal(self.CATEGORY)

    def setCategory(self, cat):
        """Set the value of the 'category' field of a Product doc."""
        # The category field is created as an AtomField in
        # _buildCoreProductFields, so it must be replaced with the same type.
        return self.setFirstField(
            search.AtomField(name=self.CATEGORY, value=cat))

    def getAvgRating(self):
        """Get the value of the 'ar' (average rating) field of a Product
        doc."""
        return self.getFirstFieldVal(self.AVG_RATING)

    def setAvgRating(self, ar):
        """Set the value of the 'ar' field of a Product doc."""
        return self.setFirstField(
            search.NumberField(name=self.AVG_RATING, value=ar))

    def getPrice(self):
        """Get the value of the 'price' field of a Product doc."""
        return self.getFirstFieldVal(self.PRICE)

    @classmethod
    def generateRatingsBuckets(cls, query_string):
        """Builds a dict of ratings 'buckets' and their counts, based on the
        value of the 'ar' (average rating) field for the documents retrieved
        by the given query.  See the 'generateRatingsLinks' method.  This
        information will be used to generate sidebar links that allow the
        user to drill down in query results based on rating.

        For demonstration purposes only; this will be expensive for large
        data sets.

        Returns None if the search raises a search.Error.
        """
        # do the query on the *full* search results
        # to generate the facet information, imitating what may in future be
        # provided by the FTS API.
        try:
            sq = search.Query(query_string=query_string.strip())
            search_results = cls.getIndex().search(sq)
        except search.Error:
            logging.exception('An error occurred on search.')
            return None

        ratings_buckets = collections.defaultdict(int)
        # populate the buckets; a missing rating counts as 0
        for res in search_results:
            ratings_buckets[int(cls(res).getAvgRating() or 0)] += 1
        return ratings_buckets

    @classmethod
    def generateRatingsLinks(cls, query, phash):
        """Builds a list of (link, text) pairs, to be displayed in the sidebar
        when showing results of a query.  Each is a link that runs the query,
        additionally filtered by the indicated ratings interval.

        Args:
          query: the query string used to compute the ratings buckets.
          phash: mapping of query parameters to encode into each link; it is
            copied, not modified.

        Returns:
          A list of (href, label) tuples, or None if no bucket information is
          available (search error or no results).
        """
        ratings_buckets = cls.generateRatingsBuckets(query)
        if not ratings_buckets:
            return None
        # Work on a copy so the caller's dict does not end up with a stray
        # 'rating' entry.
        link_params = dict(phash)
        rlist = []
        for k in range(config.RATING_MIN, config.RATING_MAX + 1):
            v = ratings_buckets.get(k, 0)
            # build html; the top bucket is a single value, not an interval
            if k < config.RATING_MAX:
                htext = '%s-%s (%s)' % (k, k + 1, v)
            else:
                htext = '%s (%s)' % (k, v)
            link_params['rating'] = k
            hlink = '/psearch?' + urllib.urlencode(link_params)
            rlist.append((hlink, htext))
        return rlist

    @classmethod
    def _buildCoreProductFields(
            cls, pid, name, description, category, category_name, price):
        """Construct a 'core' document field list for the fields common to
        all Products.  The various categories (as defined in the file
        'categories.py'), may add additional specialized fields; these will
        be appended to this core list. (see _buildProductFields)."""
        fields = [
            search.TextField(name=cls.PID, value=pid),
            # The 'updated' field is always set to the current date.
            search.DateField(name=cls.UPDATED,
                             value=datetime.datetime.now().date()),
            search.TextField(name=cls.PRODUCT_NAME, value=name),
            # strip the markup from the description value, which can
            # potentially come from user input.  We do this so that
            # we don't need to sanitize the description in the
            # templates, showing off the Search API's ability to mark up
            # query terms in generated snippets.  This is done only for
            # demonstration purposes; in an actual app,
            # it would be preferable to use a library like Beautiful Soup
            # instead.
            # We'll let the templating library escape all other rendered
            # values for us, so this is the only field we do this for.
            # A missing description is treated as empty text.
            search.TextField(
                name=cls.DESCRIPTION,
                value=re.sub(r'<[^>]*?>', '', description or '')),
            search.AtomField(name=cls.CATEGORY, value=category),
            search.NumberField(name=cls.AVG_RATING, value=0.0),
            search.NumberField(name=cls.PRICE, value=price),
        ]
        return fields

    @classmethod
    def _buildProductFields(cls, pid=None, category=None, name=None,
                            description=None, category_name=None, price=None,
                            **params):
        """Build the core fields plus all the additional non-core fields for
        a document of the given product type (category), using the given
        params dict.  All such additional category-specific fields are
        treated as required.

        Raises:
          errors.OperationFailedError: if a required category-specific field
            is missing from params, or a numeric field has a bad value.
        """
        fields = cls._buildCoreProductFields(
            pid, name, description, category, category_name, price)
        # get the specification of additional (non-'core') fields for this
        # category
        pdict = categories.product_dict.get(category_name)
        if not pdict:
            # no field specification is registered for this category name.
            logging.warning(
                'product field information not found for category name %s',
                category_name)
            return fields

        for k, field_type in pdict.iteritems():
            # every category-specific field must have a value in params
            if k not in params:
                error_message = (
                    'value not given for field "%s" of field type "%s"'
                    % (k, field_type))
                logging.warning(error_message)
                raise errors.OperationFailedError(error_message)
            v = params[k]
            if field_type == search.NumberField:
                try:
                    val = float(v)
                except (TypeError, ValueError):
                    error_message = (
                        'bad value %s for field %s of type %s'
                        % (v, k, field_type))
                    logging.error(error_message)
                    raise errors.OperationFailedError(error_message)
                fields.append(search.NumberField(name=k, value=val))
            elif field_type == search.TextField:
                fields.append(search.TextField(name=k, value=str(v)))
            else:
                # you may want to add handling of other field types for
                # generality.  Not needed for our current sample data.
                logging.warning(
                    'not processed: %s, %s, of type %s', k, v, field_type)
        return fields

    @classmethod
    def _createDocument(
            cls, pid=None, category=None, name=None, description=None,
            category_name=None, price=None, **params):
        """Create a Document object from given params.

        Raises:
          errors.OperationFailedError: if pid, category or name is missing,
            or if the pid is not a legal doc id.
        """
        # check for the fields that are always required.
        if not (pid and category and name):
            raise errors.OperationFailedError('Missing parameter.')
        # Check that the given pid has only visible ascii characters,
        # and does not contain whitespace.  The pid will be used as the
        # doc_id, which has these requirements.
        if not cls.isValidDocId(pid):
            raise errors.OperationFailedError("Illegal pid %s" % pid)
        # construct the document fields from the params
        resfields = cls._buildProductFields(
            pid=pid, category=category, name=name,
            description=description,
            category_name=category_name, price=price, **params)
        # build the document.  Use the pid (product id) as the doc id.
        # (If we did not do this, and left the doc_id unspecified, an id
        # would be auto-generated.)
        return search.Document(doc_id=pid, fields=resfields)

    @classmethod
    def _normalizeParams(cls, params):
        """Normalize the submitted params for building a product.

        Works on a deep copy of `params`: strips 'pid' and 'name', mirrors
        'category' into 'category_name', and converts 'price' to a float.

        Raises:
          errors.OperationFailedError: if a required key is missing or the
            price cannot be converted to a float.
        """
        params = copy.deepcopy(params)
        try:
            params['pid'] = params['pid'].strip()
            params['name'] = params['name'].strip()
            params['category_name'] = params['category']
            try:
                params['price'] = float(params['price'])
            except (TypeError, ValueError):
                error_message = 'bad price value: %s' % params['price']
                logging.error(error_message)
                raise errors.OperationFailedError(error_message)
            return params
        except KeyError as e1:
            logging.exception("key error")
            raise errors.OperationFailedError(e1)
        except errors.Error as e2:
            logging.debug(
                'Problem with params: %s: %s', params, e2.error_message)
            raise errors.OperationFailedError(e2.error_message)

    @classmethod
    def buildProductBatch(cls, rows):
        """Build product documents and their related datastore entities, in
        batch, given a list of params dicts.  Should be used for new products,
        as does not handle updates of existing product entities.  This method
        does not require that the doc ids be tied to the product ids, and
        obtains the doc ids from the results of the document add.

        Rows that fail validation are logged and skipped.

        Raises:
          errors.OperationFailedError: if the number of indexing results does
            not match the number of documents submitted.
        """
        docs = []
        dbps = []
        for row in rows:
            try:
                params = cls._normalizeParams(row)
                doc = cls._createDocument(**params)
                docs.append(doc)
                # create product entity, sans doc_id
                dbp = models.Product(
                    id=params['pid'], price=params['price'],
                    category=params['category'])
                dbps.append(dbp)
            except errors.OperationFailedError:
                logging.error('error creating document from data: %s', row)
        if not docs:
            # nothing valid to index or persist
            return
        try:
            add_results = cls.add(docs)
        except search.Error:
            logging.exception('Add failed')
            return
        if add_results is None:
            # add() logs search errors itself and signals failure with None.
            logging.error('Add failed')
            return
        if len(add_results) != len(dbps):
            # this case should not be reached; if there was an issue,
            # a search error should have been reported, above.
            raise errors.OperationFailedError(
                'Error: wrong number of results returned from indexing '
                'operation')
        # now set the entities with the doc ids, the list of which are
        # returned in the same order as the list of docs given to the indexers
        for dbp, result in zip(dbps, add_results):
            dbp.doc_id = result.id
        # persist the entities
        ndb.put_multi(dbps)

    @classmethod
    def buildProduct(cls, params):
        """Create/update a product document and its related datastore entity.
        The product id and the field values are taken from the params dict.

        Returns:
          The created or updated models.Product entity.

        Raises:
          errors.OperationFailedError: if the params are invalid or the
            document could not be indexed.
        """
        params = cls._normalizeParams(params)
        # check to see if doc already exists.  We do this because we need to
        # retain some information from the existing doc.  We could skip the
        # fetch if this were not the case.
        curr_doc = cls.getDocFromPid(params['pid'])
        d = cls._createDocument(**params)
        if curr_doc:  # retain ratings info from existing doc
            avg_rating = cls(curr_doc).getAvgRating()
            # an existing doc without a rating keeps the new doc's default
            if avg_rating is not None:
                cls(d).setAvgRating(avg_rating)

        # This will reindex if a doc with that doc id already exists
        doc_ids = cls.add(d)
        # add() returns None when indexing raised a search error; an empty
        # result likewise means nothing was indexed.
        if not doc_ids:
            raise errors.OperationFailedError('could not index document')
        doc_id = doc_ids[0].id
        logging.debug(
            'got new doc id %s for product: %s', doc_id, params['pid'])

        # now update the entity
        def _tx():
            # Check whether the product entity exists.  If so, we want to
            # update from the params, but preserve its ratings-related info.
            prod = models.Product.get_by_id(params['pid'])
            if prod:  # update
                prod.update_core(params, doc_id)
            else:  # create new entity
                prod = models.Product.create(params, doc_id)
            prod.put()
            return prod

        prod = ndb.transaction(_tx)
        logging.debug('prod: %s', prod)
        return prod