191 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			191 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
'''
Extend cPickle storage to include modules, and builtin functions/methods.

To use, just import this module.
'''
|  | import copy_reg | ||
|  | 
 | ||
|  | ### OBJECTS WHICH ARE RESTORED THROUGH IMPORTS | ||
|  | # MODULES | ||
def pickle_module(module):
    '''
    Reduce a module to an import statement for a pickling stream.

    The module is not stored itself: only the text "import <name>" and
    the name to look up afterwards are recorded, so the module must be
    available for reimport during unpickling.

    Returns the (callable, args) pair expected by copy_reg.
    '''
    name = module.__name__
    import_statement = 'import %s' % name
    return unpickle_imported_code, (import_statement, name)
|  | 
 | ||
|  | # FUNCTIONS, METHODS (BUILTIN) | ||
def pickle_imported_code(funcmeth):
    '''
    Store a reference to an imported element (such as a function or
    builtin function).  It must be available for reimport during
    unpickling.

    funcmeth -- object with a __name__ that can be found as an attribute
                of some module in sys.modules.

    Returns the (callable, args) pair expected by copy_reg, where args is
    ("from <module> import <name>", "<name>").
    '''
    # _whichmodule returns the module's *name* (a string), not the module
    # object, so it is interpolated directly.
    module_name = _whichmodule(funcmeth)
    name = funcmeth.__name__
    return unpickle_imported_code, ('from %s import %s' % (module_name, name), name)
|  | 
 | ||
import types
import re  # replaces the obsolete "regex" module, which newer interpreters no longer ship

# Only bare "import a.b" / "from a.b import c" statements pass this filter;
# note the limitations on whitespace (exactly one space between words).
import_filter = re.compile(r'(from [A-Za-z0-9_.]+ )?import [A-Za-z0-9_.]+')
# note we allow you to use x.y.z here
getattr_filter = re.compile(r'[A-Za-z0-9_.]+')


def _matches_whole(pattern, text):
    '''
    Return a true value when pattern matches all of text, not just a prefix.
    '''
    found = pattern.match(text)
    return found is not None and found.end() == len(text)


# MODULES, AND FUNCTIONS
def unpickle_imported_code(impstr, impname):
    '''
    Attempt to load a reference to a module or other imported code
    (such as functions/builtin functions).

    impstr  -- import statement to execute, e.g. "from os import getcwd"
    impname -- dotted name to evaluate once the import has run

    Returns the imported object.  Returns None (after writing a message
    to stderr) when either argument fails the whitelist filters or when
    the import/lookup raises.
    '''
    import sys
    if not (_matches_whole(import_filter, impstr)
            and _matches_whole(getattr_filter, impname)):
        sys.stderr.write('''Possible attempt to smuggle arbitrary code into pickle file (see module cpickle_extend).\nPassed code was %s\n%s\n'''%(impstr,impname))
        return None
    ns = {}
    try:
        # SECURITY: even with the filters above this executes an import,
        # which runs arbitrary module-level code.  It could raise all sorts
        # of errors and is still dangerous when you have no control over
        # the modules on your system.  Do not unpickle untrusted data!
        exec(impstr, ns)
        return eval(impname, ns)
    except Exception:
        sys.stderr.write('''Error unpickling module %s\n None returned, will likely raise errors.'''%impstr)
        return None


# Modules
copy_reg.pickle(types.ModuleType, pickle_module, unpickle_imported_code)
# builtin functions/methods
copy_reg.pickle(types.BuiltinFunctionType, pickle_imported_code, unpickle_imported_code)

del re  # to keep the namespace neat as possible
|  | 
 | ||
|  | ### INSTANCE METHODS | ||
'''
The problem with instance methods is that they are almost always
stored inside a class somewhere.  We really need a new type: a reference
that lets us just say "y.this".

We also need something that can reliably find buried functions :( not
likely to be easy or clean...

Then filter for "x is part of the set".
'''
 | ||
|  | import new | ||
|  | 
 | ||
def pickle_instance_method(imeth):
    '''
    Reduce a method object to the pieces needed to rebuild it.

    The method's own attributes supply everything: the class it was
    looked up on, the instance it is bound to, and (via _imp_meth) an
    import description of the underlying function.  A more general
    "get attribute 'x' of this object" scheme might be better, but
    does not exist yet.

    Returns the (callable, args) pair expected by copy_reg.
    '''
    owner_class = imeth.im_class
    import_spec = _imp_meth(imeth)
    # None when the method is unbound (UnboundMethodType)
    bound_to = imeth.im_self
    return unpickle_instance_method, (import_spec, bound_to, owner_class)
def unpickle_instance_method(funcimp, self, klass):
    '''
    Attempt to restore a reference to an instance method.

    funcimp -- (import statement, dotted name) pair, passed on to
               unpickle_imported_code to find the function again
    self    -- the instance the method was bound to (already recreated
               by the unpickler), or None for an unbound method
    klass   -- the class the method belongs to

    Returns the method object built by new.instancemethod.
    '''
    func = unpickle_imported_code(*funcimp)
    # A dotted name such as "Klass.meth" evaluates to a method object, not
    # the raw function; unwrap it so the rebuilt method wraps the function.
    func = getattr(func, 'im_func', func)
    return new.instancemethod(func, self, klass)
|  | 
 | ||
# Register the same reducer/constructor pair for bound and unbound methods.
for _method_type in (types.MethodType, types.UnboundMethodType):
    copy_reg.pickle(_method_type, pickle_instance_method, unpickle_instance_method)
del _method_type  # loop variable only; keep the module namespace unchanged
|  | 
 | ||
|  | ### Arrays | ||
try:
    import array
    import sys

    # True when this interpreter stores multi-byte items least significant
    # byte first.
    LittleEndian = sys.byteorder == 'little'

    def _array_dump(somearray):
        '''
        Return the raw machine representation of somearray as a string.
        '''
        # tostring() is called tobytes() on newer interpreters.
        if hasattr(somearray, 'tobytes'):
            return somearray.tobytes()
        return somearray.tostring()

    def _array_load(newarray, stringrep):
        '''
        Append the items encoded in stringrep to newarray.
        '''
        # fromstring() is called frombytes() on newer interpreters.
        if hasattr(newarray, 'frombytes'):
            newarray.frombytes(stringrep)
        else:
            newarray.fromstring(stringrep)

    def pickle_array(somearray):
        '''
        Store a standard array object, inefficient because of copying to string.

        Returns the (callable, args) pair expected by copy_reg; args holds
        the typecode, the raw item data and the byte order of this machine.
        '''
        return unpickle_array, (somearray.typecode, _array_dump(somearray), LittleEndian)

    def unpickle_array(typecode, stringrep, origendian):
        '''
        Restore a standard array object.

        typecode   -- array typecode of the stored items
        stringrep  -- raw item data as produced by pickle_array
        origendian -- true if the data was written on a little-endian machine
        '''
        newarray = array.array(typecode)
        _array_load(newarray, stringrep)
        # Arrays hold native machine values, floats included, so every
        # multi-byte item type needs swapping when the byte order differs;
        # single byte elements don't need swapping.
        if origendian != LittleEndian and newarray.itemsize > 1:
            newarray.byteswap()
        return newarray

    copy_reg.pickle(array.ArrayType, pickle_array, unpickle_array)
except ImportError: # no arrays
    pass
|  | 
 | ||
|  | ### NUMPY Arrays | ||
try:
    import Numeric
    LittleEndian = Numeric.array([1],'i').tostring()[0] == '\001'

    def pickle_numpyarray(somearray):
        '''
        Store a numpy array, inefficient, but should work with cPickle.

        Returns the (callable, args) pair expected by copy_reg; args holds
        the typecode, the shape, the raw item data and the byte order of
        this machine.
        '''
        return unpickle_numpyarray, (somearray.typecode(), somearray.shape, somearray.tostring(), LittleEndian)

    def unpickle_numpyarray(typecode, shape, stringval, origendian):
        '''
        Restore a numpy array.

        typecode   -- typecode of the stored items
        shape      -- shape of the original array
        stringval  -- raw item data as produced by pickle_numpyarray
        origendian -- true if the data was written on a little-endian machine
        '''
        newarray = Numeric.fromstring(stringval, typecode)
        if origendian != LittleEndian and typecode in ('I','i','h','H'):
            # this doesn't seem to work correctly, what's byteswapped doing???
            # NOTE(review): the typecode list mirrors the one used for the
            # standard array module -- confirm it covers Numeric's typecodes.
            newarray = newarray.byteswapped()
        # reshape returns the reshaped array rather than changing its
        # argument, so the result has to be the value handed back.
        return Numeric.reshape(newarray, shape)

    copy_reg.pickle(Numeric.ArrayType, pickle_numpyarray, unpickle_numpyarray)
except ImportError:
    pass
|  | 
 | ||
|  | ### UTILITY FUNCTIONS | ||
|  | classmap = {} | ||
|  | def _whichmodule(cls): | ||
|  | 	"""Figure out the module in which an imported_code object occurs.
 | ||
|  | 	Search sys.modules for the module. | ||
|  | 	Cache in classmap. | ||
|  | 	Return a module name. | ||
|  | 	If the class cannot be found, return __main__. | ||
|  | 	Copied here from the standard pickle distribution | ||
|  | 	to prevent another import | ||
|  | 	"""
 | ||
|  | 	if classmap.has_key(cls): | ||
|  | 		return classmap[cls] | ||
|  | 	clsname = cls.__name__ | ||
|  | 	for name, module in sys.modules.items(): | ||
|  | 		if name != '__main__' and \ | ||
|  | 		   hasattr(module, clsname) and \ | ||
|  | 		   getattr(module, clsname) is cls: | ||
|  | 			break | ||
|  | 	else: | ||
|  | 		name = '__main__' | ||
|  | 	classmap[cls] = name | ||
|  | 	return name | ||
|  | 
 | ||
|  | import os, string, sys | ||
|  | 
 | ||
def _imp_meth(im):
    '''
    Work out how to re-import the function behind the method im.

    The base name (no directory, no extension) of the file the function
    was compiled from is compared with the last dotted component of every
    key in sys.modules; the matching modules are then searched by
    _search_modules.  The search is only one level deep, i.e. methods are
    found only if their class is at the top level of a module.

    Returns whatever _search_modules returns for the candidates.
    '''
    source_path = im.im_func.func_code.co_filename
    file_name = os.path.split(source_path)[1]
    tail = os.path.splitext(file_name)[0]

    # keep the loaded modules whose final name component is the file's base name
    possibles = [
        module_name
        for module_name in sys.modules.keys()
        if module_name.split('.')[-1] == tail
    ]

    # now, search the candidates to find the correct class/function
    return _search_modules(possibles, im.im_func)
|  | 	 | ||
|  | def _search_modules(possibles, im_func): | ||
|  | 	for our_mod_name in possibles: | ||
|  | 		our_mod = sys.modules[our_mod_name] | ||
|  | 		if hasattr(our_mod, im_func.__name__) and getattr(our_mod, im_func.__name__).im_func is im_func: | ||
|  | 			return 'from %s import %s'%(our_mod.__name__, im_func.__name__), im_func.__name__ | ||
|  | 		for key,val in our_mod.__dict__.items(): | ||
|  | 			if hasattr(val, im_func.__name__) and getattr(val, im_func.__name__).im_func is im_func: | ||
|  | 				return 'from %s import %s'%(our_mod.__name__,key), '%s.%s'%(key,im_func.__name__) | ||
|  | 	raise '''No import string calculable for %s'''%im_func | ||
|  | 
 | ||
|  | 
 |