extended Plane class.

pull/1/head
Yusuke Shinyama 2011-05-14 14:16:40 +09:00
parent fcf0d74ecc
commit b8d516fc52
1 changed files with 25 additions and 12 deletions

View File

@ -196,7 +196,7 @@ class ObjIdRange(object):
## Plane ## Plane
## ##
## A data structure for objects placed on a plane. ## A set-like data structure for objects placed on a plane.
## Can efficiently find objects in a certain rectangular area. ## Can efficiently find objects in a certain rectangular area.
## It maintains two parallel lists of objects, each of ## It maintains two parallel lists of objects, each of
## which is sorted by its x or y coordinate. ## which is sorted by its x or y coordinate.
@ -204,7 +204,8 @@ class ObjIdRange(object):
class Plane(object): class Plane(object):
def __init__(self, objs=None, gridsize=50): def __init__(self, objs=None, gridsize=50):
self._objs = {} self._objs = set()
self._grid = {}
self.gridsize = gridsize self.gridsize = gridsize
if objs is not None: if objs is not None:
for obj in objs: for obj in objs:
@ -214,6 +215,15 @@ class Plane(object):
def __repr__(self): def __repr__(self):
return ('<Plane objs=%r>' % list(self)) return ('<Plane objs=%r>' % list(self))
def __iter__(self):
return iter(self._objs)
def __len__(self):
return len(self._objs)
def __contains__(self, obj):
return obj in self._objs
def _getrange(self, (x0,y0,x1,y1)): def _getrange(self, (x0,y0,x1,y1)):
for y in drange(y0, y1, self.gridsize): for y in drange(y0, y1, self.gridsize):
for x in drange(x0, x1, self.gridsize): for x in drange(x0, x1, self.gridsize):
@ -223,35 +233,38 @@ class Plane(object):
# add(obj): place an object. # add(obj): place an object.
def add(self, obj): def add(self, obj):
for k in self._getrange((obj.x0, obj.y0, obj.x1, obj.y1)): for k in self._getrange((obj.x0, obj.y0, obj.x1, obj.y1)):
if k not in self._objs: if k not in self._grid:
r = [] r = []
self._objs[k] = r self._grid[k] = r
else: else:
r = self._objs[k] r = self._grid[k]
r.append(obj) r.append(obj)
self._objs.add(obj)
return return
# remove(obj): displace an object. # remove(obj): displace an object.
def remove(self, obj): def remove(self, obj):
for k in self._getrange((obj.x0, obj.y0, obj.x1, obj.y1)): for k in self._getrange((obj.x0, obj.y0, obj.x1, obj.y1)):
try: try:
self._objs[k].remove(obj) self._grid[k].remove(obj)
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
self._objs.remove(obj)
return return
# find(): finds objects that are in a certain area. # find(): finds objects that are in a certain area.
def find(self, (x0,y0,x1,y1)): def find(self, (x0,y0,x1,y1)):
r = set() r = set()
done = set()
for k in self._getrange((x0,y0,x1,y1)): for k in self._getrange((x0,y0,x1,y1)):
if k not in self._objs: continue if k not in self._grid: continue
for obj in self._objs[k]: for obj in self._grid[k]:
if obj in r: continue if obj in done: continue
r.add(obj) done.add(obj)
if (obj.x1 <= x0 or x1 <= obj.x0 or if (obj.x1 <= x0 or x1 <= obj.x0 or
obj.y1 <= y0 or y1 <= obj.y0): continue obj.y1 <= y0 or y1 <= obj.y0): continue
yield obj r.add(obj)
return return r
# create_bmp # create_bmp