webpy/experimental/migration.py
"""Migration script to run web.py 0.23 programs using 0.3.
Import this module at the beginning of your program.
"""
import web
import sys
def setup_database():
if web.config.get('db_parameters'):
db = web.database(**web.config.db_parameters)
web.insert = db.insert
web.select = db.select
web.update = db.update
web.delete = db.delete
web.query = db.query
def transact():
t = db.transaction()
web.ctx.setdefault('transaction_stack', []).append(t)
def rollback():
stack = web.ctx.get('transaction_stack')
t = stack and stack.pop()
t and t.rollback()
def commit():
stack = web.ctx.get('transaction_stack')
t = stack and stack.pop()
t and t.commit()
web.transact = transact
web.rollback = rollback
web.commit = commit
web.loadhooks = web.webapi.loadhooks = {}
web._loadhooks = web.webapi._loadhooks = {}
web.unloadhooks = web.webapi.unloadhooks = {}
def load():
setup_database()
web.load = load
def run(urls, fvars, *middleware):
setup_database()
def stdout_processor(handler):
handler()
return web.ctx.get('output', '')
def hook_processor(handler):
for h in web.loadhooks.values() + web._loadhooks.values(): h()
output = handler()
for h in web.unloadhooks.values(): h()
return output
app = web.application(urls, fvars)
app.add_processor(stdout_processor)
app.add_processor(hook_processor)
app.run(*middleware)
class _outputter:
"""Wraps `sys.stdout` so that print statements go into the response."""
def __init__(self, file): self.file = file
def write(self, string_):
if hasattr(web.ctx, 'output'):
return output(string_)
else:
self.file.write(string_)
def __getattr__(self, attr): return getattr(self.file, attr)
def __getitem__(self, item): return self.file[item]
def output(string_):
"""Appends `string_` to the response."""
string_ = web.safestr(string_)
if web.ctx.get('flush'):
web.ctx._write(string_)
else:
web.ctx.output += str(string_)
def _capturedstdout():
sysstd = sys.stdout
while hasattr(sysstd, 'file'):
if isinstance(sys.stdout, _outputter): return True
sysstd = sysstd.file
if isinstance(sys.stdout, _outputter): return True
return False
if not _capturedstdout():
sys.stdout = _outputter(sys.stdout)
web.run = run
class Stowage(web.storage):
def __str__(self):
return self._str
web.template.Stowage = web.template.stowage = Stowage
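
# A minimal usage sketch (module and database names are illustrative, not part of
# this file): an old 0.23-style program keeps its `web.run(urls, globals())` entry
# point and simply imports this shim before anything else, e.g.:
#
#   import migration          # this module; patches web.run, web.select, web.insert, ...
#   import web
#
#   web.config.db_parameters = dict(dbn='sqlite', db='app.db')
#   urls = ('/', 'index')
#
#   class index:
#       def GET(self):
#           print "hello from a 0.23-style handler"   # captured through sys.stdout
#
#   web.run(urls, globals())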

webpy/experimental/untwisted.py
import random
from twisted.internet import reactor, defer
from twisted.web import http
import simplejson
import web
class Request(http.Request):
def process(self):
self.content.seek(0, 0)
env = {
'REMOTE_ADDR': self.client.host,
'REQUEST_METHOD': self.method,
'PATH_INFO': self.path,
'CONTENT_LENGTH': web.intget(self.getHeader('content-length'), 0),
'wsgi.input': self.content
}
if '?' in self.uri:
env['QUERY_STRING'] = self.uri.split('?', 1)[1]
for k, v in self.received_headers.iteritems():
env['HTTP_' + k.upper()] = v
if self.path.startswith('/static/'):
f = web.lstrips(self.path, '/static/')
assert '/' not in f
#@@@ big security hole
self.write(file('static/' + f).read())
return self.finish()
web.webapi._load(env)
web.ctx.trequest = self
result = self.actualfunc()
self.setResponseCode(int(web.ctx.status.split()[0]))
for (h, v) in web.ctx.headers:
self.setHeader(h, v)
self.write(web.ctx.output)
if not web.ctx.get('persist'):
self.finish()
class Server(http.HTTPFactory):
def __init__(self, func):
self.func = func
def buildProtocol(self, addr):
"""Generate a channel attached to this site.
"""
channel = http.HTTPFactory.buildProtocol(self, addr)
class MyRequest(Request):
actualfunc = staticmethod(self.func)
channel.requestFactory = MyRequest
channel.site = self
return channel
def runtwisted(func):
reactor.listenTCP(8086, Server(func))
reactor.run()
def newrun(inp, fvars):
print "Running on http://0.0.0.0:8086/"
runtwisted(web.webpyfunc(inp, fvars, False))
def iframe(url):
# the iframe markup was stripped in transit; a minimal placeholder frame is assumed here
return """<iframe src="%s"></iframe>""" % url #("http://%s.ajaxpush.lh.theinfo.org:8086%s" % (random.random(), url))
class Feed:
def __init__(self):
self.sessions = []
def subscribe(self):
request = web.ctx.trequest
self.sessions.append(request)
request.connectionLost = lambda reason: self.sessions.remove(request)
web.ctx.persist = True
def publish(self, text):
for x in self.sessions:
x.write(text)
class JSFeed(Feed):
def __init__(self, callback="callback"):
Feed.__init__(self)
self.callback = callback
def publish(self, obj):
web.debug("publishing")
# the script markup below was stripped in transit; a plain <script> wrapper is assumed
Feed.publish(self,
'<script>%s(%s);</script>' % (self.callback, simplejson.dumps(obj) +
" " * 2048))
if __name__ == "__main__":
mfeed = JSFeed()
urls = (
'/', 'view',
'/js', 'js',
'/send', 'send'
)
class view:
def GET(self):
print """
Today's News
Contribute
"""
class js:
def GET(self):
mfeed.subscribe()
class send:
def POST(self):
# the HTML wrapper around the published text was stripped in transit
mfeed.publish('%s' % web.input().text + (" " * 2048))
web.seeother('/')
newrun(urls, globals())

webpy/experimental/background.py
"""Helper functions to run long-running tasks."""
import threading

from web import utils
from web import webapi as web
from web.http import changequery
def background(func):
"""A function decorator to run a long-running function as a background thread."""
def internal(*a, **kw):
web.data() # cache it
tmpctx = web._context[threading.currentThread()]
web._context[threading.currentThread()] = utils.storage(web.ctx.copy())
def newfunc():
web._context[threading.currentThread()] = tmpctx
func(*a, **kw)
myctx = web._context[threading.currentThread()]
for k in myctx.keys():
if k not in ['status', 'headers', 'output']:
try: del myctx[k]
except KeyError: pass
t = threading.Thread(target=newfunc)
background.threaddb[id(t)] = t
t.start()
web.ctx.headers = []
return web.seeother(changequery(_t=id(t)))
return internal
background.threaddb = {}
def backgrounder(func):
def internal(*a, **kw):
i = web.input(_method='get')
if '_t' in i:
try:
t = background.threaddb[int(i._t)]
except KeyError:
return web.notfound()
web._context[threading.currentThread()] = web._context[t]
return
else:
return func(*a, **kw)
return internal
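
# A hedged usage sketch (handler and helper names are illustrative): `background`
# spawns the decorated callable in a thread, copies web.ctx into it, and redirects
# to the same URL with `_t=<thread id>`; `backgrounder` wraps the view so that a
# request carrying `_t` re-attaches that thread's ctx instead of running the view.
#
#   @background
#   def long_task():
#       web.ctx.output = do_expensive_work()   # hypothetical helper
#
#   class report:
#       @backgrounder
#       def GET(self):
#           return long_task()                 # first call redirects; polls reuse the thread's ctx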

webpy/experimental/pwt.py
import web
import simplejson, sudo
urls = (
'/sudo', 'sudoku',
'/length', 'length',
)
class pwt(object):
_inFunc = False
updated = {}
page = """
"""
def GET(self):
web.header('Content-Type', 'text/html')
print self.page % self.form()
def POST(self):
i = web.input()
if '_' in i: del i['_']
#for k, v in i.iteritems(): setattr(self, k, v)
self._inFunc = True
self.work(**i)
self._inFunc = False
web.header('Content-Type', 'text/javascript')
print 'receive('+simplejson.dumps(self.updated)+');'
def __setattr__(self, k, v):
if self._inFunc and k != '_inFunc':
self.updated[k] = v
object.__setattr__(self, k, v)
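
# A sketch of the pwt contract (the `sudoku` and `length` classes below are the real
# examples; `adder` here is hypothetical): a subclass provides form(), which renders
# the initial widgets, and work(**kw), which receives the posted field values. Every
# attribute assigned inside work() is recorded in self.updated (see __setattr__ above)
# and is sent back to the page as JSON for the receive() JavaScript callback.
#
#   class adder(pwt):
#       def form(self):
#           return '<input name="x"> <input name="y">'
#       def work(self, x=0, y=0):
#           self.total = int(x) + int(y)   # `total` is pushed back to the browser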
class sudoku(pwt):
def form(self):
import sudo
out = ''
n = 0
for i in range(9):
for j in range(9):
# input markup was stripped in transit; an input named after the square is assumed
out += '<input name="%s">' % (sudo.squares[n])
n += 1
out += ' '
return out
def work(self, **kw):
values = dict((s, sudo.digits) for s in sudo.squares)
for k, v in kw.iteritems():
if v:
sudo.assign(values, k, v)
for k, v in values.iteritems():
if len(v) == 1:
setattr(self, k, v)
return values
class length(pwt):
def form(self):
# the form markup was stripped in transit; an input field for `n` is assumed
return '<input name="n">'
def work(self):
self.output = ('a' * web.intget(self.n, 0) or ' ')
if __name__ == "__main__":
web.run(urls, globals(), web.reloader)

webpy/README.md
web.py is a web framework for Python that is as simple as it is powerful.
Visit http://webpy.org/ for more information.
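
A minimal application, mirroring the `web.application(urls, globals())` pattern used
throughout this tree (the handler and URL names are illustrative):

```python
import web

urls = ('/(.*)', 'hello')
app = web.application(urls, globals())

class hello:
    def GET(self, name):
        return 'Hello, ' + (name or 'world') + '!'

if __name__ == "__main__":
    app.run()
```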

webpy/web/db.py
"""
Database API
(part of web.py)
"""
__all__ = [
"UnknownParamstyle", "UnknownDB", "TransactionError",
"sqllist", "sqlors", "reparam", "sqlquote",
"SQLQuery", "SQLParam", "sqlparam",
"SQLLiteral", "sqlliteral",
"database", 'DB',
]
import time
try:
import datetime
except ImportError:
datetime = None
try: set
except NameError:
from sets import Set as set
from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode
try:
# db module can work independently of web.py
from webapi import debug, config
except:
import sys
debug = sys.stderr
config = storage()
class UnknownDB(Exception):
"""raised for unsupported dbms"""
pass
class _ItplError(ValueError):
def __init__(self, text, pos):
ValueError.__init__(self)
self.text = text
self.pos = pos
def __str__(self):
return "unfinished expression in %s at char %d" % (
repr(self.text), self.pos)
class TransactionError(Exception): pass
class UnknownParamstyle(Exception):
"""
raised for unsupported db paramstyles
(currently supported: qmark, numeric, format, pyformat)
"""
pass
class SQLParam(object):
"""
Parameter in SQLQuery.
>>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
>>> q
<sql: "SELECT * FROM test WHERE name='joe'">
>>> q.query()
'SELECT * FROM test WHERE name=%s'
>>> q.values()
['joe']
"""
__slots__ = ["value"]
def __init__(self, value):
self.value = value
def get_marker(self, paramstyle='pyformat'):
if paramstyle == 'qmark':
return '?'
elif paramstyle == 'numeric':
return ':1'
elif paramstyle is None or paramstyle in ['format', 'pyformat']:
return '%s'
raise UnknownParamstyle, paramstyle
def sqlquery(self):
return SQLQuery([self])
def __add__(self, other):
return self.sqlquery() + other
def __radd__(self, other):
return other + self.sqlquery()
def __str__(self):
return str(self.value)
def __repr__(self):
return '<param: %s>' % repr(self.value)
sqlparam = SQLParam
class SQLQuery(object):
"""
You can pass this sort of thing as a clause in any db function.
Otherwise, you can pass a dictionary to the keyword argument `vars`
and the function will call reparam for you.
Internally, consists of `items`, which is a list of strings and
SQLParams, which get concatenated to produce the actual query.
"""
__slots__ = ["items"]
# tested in sqlquote's docstring
def __init__(self, items=None):
r"""Creates a new SQLQuery.
>>> SQLQuery("x")
>>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
>>> q
>>> q.query(), q.values()
('SELECT * FROM test WHERE x=%s', [1])
>>> SQLQuery(SQLParam(1))
"""
if items is None:
self.items = []
elif isinstance(items, list):
self.items = items
elif isinstance(items, SQLParam):
self.items = [items]
elif isinstance(items, SQLQuery):
self.items = list(items.items)
else:
self.items = [items]
# Take care of SQLLiterals
for i, item in enumerate(self.items):
if isinstance(item, SQLParam) and isinstance(item.value, SQLLiteral):
self.items[i] = item.value.v
def append(self, value):
self.items.append(value)
def __add__(self, other):
if isinstance(other, basestring):
items = [other]
elif isinstance(other, SQLQuery):
items = other.items
else:
return NotImplemented
return SQLQuery(self.items + items)
def __radd__(self, other):
if isinstance(other, basestring):
items = [other]
else:
return NotImplemented
return SQLQuery(items + self.items)
def __iadd__(self, other):
if isinstance(other, (basestring, SQLParam)):
self.items.append(other)
elif isinstance(other, SQLQuery):
self.items.extend(other.items)
else:
return NotImplemented
return self
def __len__(self):
return len(self.query())
def query(self, paramstyle=None):
"""
Returns the query part of the sql query.
>>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
>>> q.query()
'SELECT * FROM test WHERE name=%s'
>>> q.query(paramstyle='qmark')
'SELECT * FROM test WHERE name=?'
"""
s = []
for x in self.items:
if isinstance(x, SQLParam):
x = x.get_marker(paramstyle)
s.append(safestr(x))
else:
x = safestr(x)
# automatically escape % characters in the query
# For backward compatibility, ignore escaping when the query already looks escaped
if paramstyle in ['format', 'pyformat']:
if '%' in x and '%%' not in x:
x = x.replace('%', '%%')
s.append(x)
return "".join(s)
def values(self):
"""
Returns the values of the parameters used in the sql query.
>>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
>>> q.values()
['joe']
"""
return [i.value for i in self.items if isinstance(i, SQLParam)]
def join(items, sep=' ', prefix=None, suffix=None, target=None):
"""
Joins multiple queries.
>>> SQLQuery.join(['a', 'b'], ', ')
<sql: 'a, b'>
Optionally, prefix and suffix arguments can be provided.
>>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
<sql: '(a, b)'>
If the target argument is provided, the items are appended to target instead of creating a new SQLQuery.
"""
if target is None:
target = SQLQuery()
target_items = target.items
if prefix:
target_items.append(prefix)
for i, item in enumerate(items):
if i != 0:
target_items.append(sep)
if isinstance(item, SQLQuery):
target_items.extend(item.items)
else:
target_items.append(item)
if suffix:
target_items.append(suffix)
return target
join = staticmethod(join)
def _str(self):
try:
return self.query() % tuple([sqlify(x) for x in self.values()])
except (ValueError, TypeError):
return self.query()
def __str__(self):
return safestr(self._str())
def __unicode__(self):
return safeunicode(self._str())
def __repr__(self):
return '<sql: %s>' % repr(str(self))
class SQLLiteral:
"""
Protects a string from `sqlquote`.
>>> sqlquote('NOW()')
<sql: "'NOW()'">
>>> sqlquote(SQLLiteral('NOW()'))
<sql: 'NOW()'>
"""
def __init__(self, v):
self.v = v
def __repr__(self):
return self.v
sqlliteral = SQLLiteral
def _sqllist(values):
"""
>>> _sqllist([1, 2, 3])
<sql: '(1, 2, 3)'>
"""
items = []
items.append('(')
for i, v in enumerate(values):
if i != 0:
items.append(', ')
items.append(sqlparam(v))
items.append(')')
return SQLQuery(items)
def reparam(string_, dictionary):
"""
Takes a string and a dictionary and interpolates the string
using values from the dictionary. Returns an `SQLQuery` for the result.
>>> reparam("s = $s", dict(s=True))
>>> reparam("s IN $s", dict(s=[1, 2]))
"""
dictionary = dictionary.copy() # eval mucks with it
vals = []
result = []
for live, chunk in _interpolate(string_):
if live:
v = eval(chunk, dictionary)
result.append(sqlquote(v))
else:
result.append(chunk)
return SQLQuery.join(result, '')
def sqlify(obj):
"""
converts `obj` to its proper SQL version
>>> sqlify(None)
'NULL'
>>> sqlify(True)
"'t'"
>>> sqlify(3)
'3'
"""
# because `1 == True and hash(1) == hash(True)`
# we have to do this the hard way...
if obj is None:
return 'NULL'
elif obj is True:
return "'t'"
elif obj is False:
return "'f'"
elif datetime and isinstance(obj, datetime.datetime):
return repr(obj.isoformat())
else:
if isinstance(obj, unicode): obj = obj.encode('utf8')
return repr(obj)
def sqllist(lst):
"""
Converts the arguments for use in something like a WHERE clause.
>>> sqllist(['a', 'b'])
'a, b'
>>> sqllist('a')
'a'
>>> sqllist(u'abc')
u'abc'
"""
if isinstance(lst, basestring):
return lst
else:
return ', '.join(lst)
def sqlors(left, lst):
"""
`left` is a SQL clause like `tablename.arg = `
and `lst` is a list of values. Returns a reparam-style
pair featuring the SQL that ORs together the clause
for each item in the lst.
>>> sqlors('foo = ', [])
<sql: '1=2'>
>>> sqlors('foo = ', [1])
<sql: 'foo = 1'>
>>> sqlors('foo = ', 1)
<sql: 'foo = 1'>
>>> sqlors('foo = ', [1,2,3])
<sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
"""
if isinstance(lst, iters):
lst = list(lst)
ln = len(lst)
if ln == 0:
return SQLQuery("1=2")
if ln == 1:
lst = lst[0]
if isinstance(lst, iters):
return SQLQuery(['('] +
sum([[left, sqlparam(x), ' OR '] for x in lst], []) +
['1=2)']
)
else:
return left + sqlparam(lst)
def sqlwhere(dictionary, grouping=' AND '):
"""
Converts a `dictionary` to an SQL WHERE clause `SQLQuery`.
>>> sqlwhere({'cust_id': 2, 'order_id':3})
<sql: 'order_id = 3 AND cust_id = 2'>
>>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
<sql: 'order_id = 3, cust_id = 2'>
>>> sqlwhere({'a': 'a', 'b': 'b'}).query()
'a = %s AND b = %s'
"""
return SQLQuery.join([k + ' = ' + sqlparam(v) for k, v in dictionary.items()], grouping)
def sqlquote(a):
"""
Ensures `a` is quoted properly for use in a SQL query.
>>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
<sql: "WHERE x = 't' AND y = 3">
>>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
<sql: "WHERE x = 't' AND y IN (2, 3)">
"""
if isinstance(a, list):
return _sqllist(a)
else:
return sqlparam(a).sqlquery()
class Transaction:
"""Database transaction."""
def __init__(self, ctx):
self.ctx = ctx
self.transaction_count = transaction_count = len(ctx.transactions)
class transaction_engine:
"""Transaction Engine used in top level transactions."""
def do_transact(self):
ctx.commit(unload=False)
def do_commit(self):
ctx.commit()
def do_rollback(self):
ctx.rollback()
class subtransaction_engine:
"""Transaction Engine used in sub transactions."""
def query(self, q):
db_cursor = ctx.db.cursor()
ctx.db_execute(db_cursor, SQLQuery(q % transaction_count))
def do_transact(self):
self.query('SAVEPOINT webpy_sp_%s')
def do_commit(self):
self.query('RELEASE SAVEPOINT webpy_sp_%s')
def do_rollback(self):
self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')
class dummy_engine:
"""Transaction Engine used instead of subtransaction_engine
when sub transactions are not supported."""
do_transact = do_commit = do_rollback = lambda self: None
if self.transaction_count:
# nested transactions are not supported in some databases
if self.ctx.get('ignore_nested_transactions'):
self.engine = dummy_engine()
else:
self.engine = subtransaction_engine()
else:
self.engine = transaction_engine()
self.engine.do_transact()
self.ctx.transactions.append(self)
def __enter__(self):
return self
def __exit__(self, exctype, excvalue, traceback):
if exctype is not None:
self.rollback()
else:
self.commit()
def commit(self):
if len(self.ctx.transactions) > self.transaction_count:
self.engine.do_commit()
self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
def rollback(self):
if len(self.ctx.transactions) > self.transaction_count:
self.engine.do_rollback()
self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
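
# A usage sketch (table and column names are illustrative): Transaction objects are
# normally obtained from DB.transaction() and work as context managers, committing on
# success and rolling back if the block raises; nesting uses SAVEPOINTs where supported.
#
#   db = database(dbn='sqlite', db='test.db')
#   with db.transaction():
#       db.insert('person', name='alice')
#       with db.transaction():             # nested -> SAVEPOINT webpy_sp_1
#           db.insert('person', name='bob')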
class DB:
"""Database"""
def __init__(self, db_module, keywords):
"""Creates a database.
"""
# some DB implementations take an optional parameter `driver` to use a specific driver module,
# but it should not be passed to connect
keywords.pop('driver', None)
self.db_module = db_module
self.keywords = keywords
self._ctx = threadeddict()
# flag to enable/disable printing queries
self.printing = config.get('debug_sql', config.get('debug', False))
self.supports_multiple_insert = False
try:
import DBUtils
# enable pooling if DBUtils module is available.
self.has_pooling = True
except ImportError:
self.has_pooling = False
# Pooling can be disabled by passing pooling=False in the keywords.
self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling
def _getctx(self):
if not self._ctx.get('db'):
self._load_context(self._ctx)
return self._ctx
ctx = property(_getctx)
def _load_context(self, ctx):
ctx.dbq_count = 0
ctx.transactions = [] # stack of transactions
if self.has_pooling:
ctx.db = self._connect_with_pooling(self.keywords)
else:
ctx.db = self._connect(self.keywords)
ctx.db_execute = self._db_execute
if not hasattr(ctx.db, 'commit'):
ctx.db.commit = lambda: None
if not hasattr(ctx.db, 'rollback'):
ctx.db.rollback = lambda: None
def commit(unload=True):
# do db commit and release the connection if pooling is enabled.
ctx.db.commit()
if unload and self.has_pooling:
self._unload_context(self._ctx)
def rollback():
# do db rollback and release the connection if pooling is enabled.
ctx.db.rollback()
if self.has_pooling:
self._unload_context(self._ctx)
ctx.commit = commit
ctx.rollback = rollback
def _unload_context(self, ctx):
del ctx.db
def _connect(self, keywords):
return self.db_module.connect(**keywords)
def _connect_with_pooling(self, keywords):
def get_pooled_db():
from DBUtils import PooledDB
# In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
# see Bug#122112
if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
else:
return PooledDB.PooledDB(creator=self.db_module, **keywords)
if getattr(self, '_pooleddb', None) is None:
self._pooleddb = get_pooled_db()
return self._pooleddb.connection()
def _db_cursor(self):
return self.ctx.db.cursor()
def _param_marker(self):
"""Returns parameter marker based on paramstyle attribute if this database."""
style = getattr(self, 'paramstyle', 'pyformat')
if style == 'qmark':
return '?'
elif style == 'numeric':
return ':1'
elif style in ['format', 'pyformat']:
return '%s'
raise UnknownParamstyle, style
def _db_execute(self, cur, sql_query):
"""executes an sql query"""
self.ctx.dbq_count += 1
try:
a = time.time()
query, params = self._process_query(sql_query)
out = cur.execute(query, params)
b = time.time()
except:
if self.printing:
print >> debug, 'ERR:', str(sql_query)
if self.ctx.transactions:
self.ctx.transactions[-1].rollback()
else:
self.ctx.rollback()
raise
if self.printing:
print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
return out
def _process_query(self, sql_query):
"""Takes the SQLQuery object and returns query string and parameters.
"""
paramstyle = getattr(self, 'paramstyle', 'pyformat')
query = sql_query.query(paramstyle)
params = sql_query.values()
return query, params
def _where(self, where, vars):
if isinstance(where, (int, long)):
where = "id = " + sqlparam(where)
#@@@ for backward-compatibility
elif isinstance(where, (list, tuple)) and len(where) == 2:
where = SQLQuery(where[0], where[1])
elif isinstance(where, SQLQuery):
pass
else:
where = reparam(where, vars)
return where
def query(self, sql_query, vars=None, processed=False, _test=False):
"""
Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
If `processed=True`, `vars` is a `reparam`-style list to use
instead of interpolating.
>>> db = DB(None, {})
>>> db.query("SELECT * FROM foo", _test=True)
>>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
>>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
"""
if vars is None: vars = {}
if not processed and not isinstance(sql_query, SQLQuery):
sql_query = reparam(sql_query, vars)
if _test: return sql_query
db_cursor = self._db_cursor()
self._db_execute(db_cursor, sql_query)
if db_cursor.description:
names = [x[0] for x in db_cursor.description]
def iterwrapper():
row = db_cursor.fetchone()
while row:
yield storage(dict(zip(names, row)))
row = db_cursor.fetchone()
out = iterbetter(iterwrapper())
out.__len__ = lambda: int(db_cursor.rowcount)
out.list = lambda: [storage(dict(zip(names, x))) \
for x in db_cursor.fetchall()]
else:
out = db_cursor.rowcount
if not self.ctx.transactions:
self.ctx.commit()
return out
def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
limit=None, offset=None, _test=False):
"""
Selects `what` from `tables` with clauses `where`, `order`,
`group`, `limit`, and `offset`. Uses vars to interpolate.
Otherwise, each clause can be a SQLQuery.
>>> db = DB(None, {})
>>> db.select('foo', _test=True)
<sql: 'SELECT * FROM foo'>
>>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
<sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
"""
if vars is None: vars = {}
sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
qout = SQLQuery.join(clauses)
if _test: return qout
return self.query(qout, processed=True)
def where(self, table, what='*', order=None, group=None, limit=None,
offset=None, _test=False, **kwargs):
"""
Selects from `table` where keys are equal to values in `kwargs`.
>>> db = DB(None, {})
>>> db.where('foo', bar_id=3, _test=True)
<sql: 'SELECT * FROM foo WHERE bar_id = 3'>
>>> db.where('foo', source=2, crust='dewey', _test=True)
<sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
>>> db.where('foo', _test=True)
<sql: 'SELECT * FROM foo'>
"""
where_clauses = []
for k, v in kwargs.iteritems():
where_clauses.append(k + ' = ' + sqlquote(v))
if where_clauses:
where = SQLQuery.join(where_clauses, " AND ")
else:
where = None
return self.select(table, what=what, order=order,
group=group, limit=limit, offset=offset, _test=_test,
where=where)
def sql_clauses(self, what, tables, where, group, order, limit, offset):
return (
('SELECT', what),
('FROM', sqllist(tables)),
('WHERE', where),
('GROUP BY', group),
('ORDER BY', order),
('LIMIT', limit),
('OFFSET', offset))
def gen_clause(self, sql, val, vars):
if isinstance(val, (int, long)):
if sql == 'WHERE':
nout = 'id = ' + sqlquote(val)
else:
nout = SQLQuery(val)
#@@@
elif isinstance(val, (list, tuple)) and len(val) == 2:
nout = SQLQuery(val[0], val[1]) # backwards-compatibility
elif isinstance(val, SQLQuery):
nout = val
else:
nout = reparam(val, vars)
def xjoin(a, b):
if a and b: return a + ' ' + b
else: return a or b
return xjoin(sql, nout)
def insert(self, tablename, seqname=None, _test=False, **values):
"""
Inserts `values` into `tablename`. Returns current sequence ID.
Set `seqname` to the ID if it's not the default, or to `False`
if there isn't one.
>>> db = DB(None, {})
>>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
>>> q
<sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
>>> q.query()
'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
>>> q.values()
[2, 'bob']
"""
def q(x): return "(" + x + ")"
if values:
_keys = SQLQuery.join(values.keys(), ', ')
_values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
else:
sql_query = SQLQuery(self._get_insert_default_values_query(tablename))
if _test: return sql_query
db_cursor = self._db_cursor()
if seqname is not False:
sql_query = self._process_insert_query(sql_query, tablename, seqname)
if isinstance(sql_query, tuple):
# for some databases, a separate query has to be made to find
# the id of the inserted row.
q1, q2 = sql_query
self._db_execute(db_cursor, q1)
self._db_execute(db_cursor, q2)
else:
self._db_execute(db_cursor, sql_query)
try:
out = db_cursor.fetchone()[0]
except Exception:
out = None
if not self.ctx.transactions:
self.ctx.commit()
return out
def _get_insert_default_values_query(self, table):
return "INSERT INTO %s DEFAULT VALUES" % table
def multiple_insert(self, tablename, values, seqname=None, _test=False):
"""
Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries,
one for each row to be inserted, each with the same set of keys.
Returns the list of ids of the inserted rows.
Set `seqname` to the ID if it's not the default, or to `False`
if there isn't one.
>>> db = DB(None, {})
>>> db.supports_multiple_insert = True
>>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
>>> db.multiple_insert('person', values=values, _test=True)
<sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
"""
if not values:
return []
if not self.supports_multiple_insert:
out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
if seqname is False:
return None
else:
return out
keys = values[0].keys()
#@@ make sure all keys are valid
# make sure all rows have same keys.
for v in values:
if v.keys() != keys:
raise ValueError, 'Bad data'
sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))
for i, row in enumerate(values):
if i != 0:
sql_query.append(", ")
SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")
if _test: return sql_query
db_cursor = self._db_cursor()
if seqname is not False:
sql_query = self._process_insert_query(sql_query, tablename, seqname)
if isinstance(sql_query, tuple):
# for some databases, a separate query has to be made to find
# the id of the inserted row.
q1, q2 = sql_query
self._db_execute(db_cursor, q1)
self._db_execute(db_cursor, q2)
else:
self._db_execute(db_cursor, sql_query)
try:
out = db_cursor.fetchone()[0]
out = range(out-len(values)+1, out+1)
except Exception:
out = None
if not self.ctx.transactions:
self.ctx.commit()
return out
def update(self, tables, where, vars=None, _test=False, **values):
"""
Update `tables` with clause `where` (interpolated using `vars`)
and setting `values`.
>>> db = DB(None, {})
>>> name = 'Joseph'
>>> q = db.update('foo', where='name = $name', name='bob', age=2,
... created=SQLLiteral('NOW()'), vars=locals(), _test=True)
>>> q
<sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
>>> q.query()
'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
>>> q.values()
[2, 'bob', 'Joseph']
"""
if vars is None: vars = {}
where = self._where(where, vars)
query = (
"UPDATE " + sqllist(tables) +
" SET " + sqlwhere(values, ', ') +
" WHERE " + where)
if _test: return query
db_cursor = self._db_cursor()
self._db_execute(db_cursor, query)
if not self.ctx.transactions:
self.ctx.commit()
return db_cursor.rowcount
def delete(self, table, where, using=None, vars=None, _test=False):
"""
Deletes from `table` with clauses `where` and `using`.
>>> db = DB(None, {})
>>> name = 'Joe'
>>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
<sql: "DELETE FROM foo WHERE name = 'Joe'">
"""
if vars is None: vars = {}
where = self._where(where, vars)
q = 'DELETE FROM ' + table
if using: q += ' USING ' + sqllist(using)
if where: q += ' WHERE ' + where
if _test: return q
db_cursor = self._db_cursor()
self._db_execute(db_cursor, q)
if not self.ctx.transactions:
self.ctx.commit()
return db_cursor.rowcount
def _process_insert_query(self, query, tablename, seqname):
return query
def transaction(self):
"""Start a transaction."""
return Transaction(self.ctx)
class PostgresDB(DB):
"""Postgres driver."""
def __init__(self, **keywords):
if 'pw' in keywords:
keywords['password'] = keywords.pop('pw')
db_module = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=keywords.pop('driver', None))
if db_module.__name__ == "psycopg2":
import psycopg2.extensions
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
# if db is not provided postgres driver will take it from PGDATABASE environment variable
if 'db' in keywords:
keywords['database'] = keywords.pop('db')
self.dbname = "postgres"
self.paramstyle = db_module.paramstyle
DB.__init__(self, db_module, keywords)
self.supports_multiple_insert = True
self._sequences = None
def _process_insert_query(self, query, tablename, seqname):
if seqname is None:
# when seqname is not provided guess the seqname and make sure it exists
seqname = tablename + "_id_seq"
if seqname not in self._get_all_sequences():
seqname = None
if seqname:
query += "; SELECT currval('%s')" % seqname
return query
def _get_all_sequences(self):
"""Query postgres to find names of all sequences used in this database."""
if self._sequences is None:
q = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
self._sequences = set([c.relname for c in self.query(q)])
return self._sequences
def _connect(self, keywords):
conn = DB._connect(self, keywords)
try:
conn.set_client_encoding('UTF8')
except AttributeError:
# fallback for pgdb driver
conn.cursor().execute("set client_encoding to 'UTF-8'")
return conn
def _connect_with_pooling(self, keywords):
conn = DB._connect_with_pooling(self, keywords)
conn._con._con.set_client_encoding('UTF8')
return conn
class MySQLDB(DB):
def __init__(self, **keywords):
import MySQLdb as db
if 'pw' in keywords:
keywords['passwd'] = keywords['pw']
del keywords['pw']
if 'charset' not in keywords:
keywords['charset'] = 'utf8'
elif keywords['charset'] is None:
del keywords['charset']
self.paramstyle = db.paramstyle = 'pyformat' # it's both, like psycopg
self.dbname = "mysql"
DB.__init__(self, db, keywords)
self.supports_multiple_insert = True
def _process_insert_query(self, query, tablename, seqname):
return query, SQLQuery('SELECT last_insert_id();')
def _get_insert_default_values_query(self, table):
return "INSERT INTO %s () VALUES()" % table
def import_driver(drivers, preferred=None):
"""Import the first available driver or preferred driver.
"""
if preferred:
drivers = [preferred]
for d in drivers:
try:
return __import__(d, None, None, ['x'])
except ImportError:
pass
raise ImportError("Unable to import " + " or ".join(drivers))
class SqliteDB(DB):
def __init__(self, **keywords):
db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=keywords.pop('driver', None))
if db.__name__ in ["sqlite3", "pysqlite2.dbapi2"]:
db.paramstyle = 'qmark'
# sqlite driver doesn't create datetime objects for timestamp columns unless the `detect_types` option is passed.
# It seems to be supported in the sqlite3 and pysqlite2 drivers, not sure about sqlite.
keywords.setdefault('detect_types', db.PARSE_DECLTYPES)
self.paramstyle = db.paramstyle
keywords['database'] = keywords.pop('db')
keywords['pooling'] = False # sqlite doesn't allow connections to be shared between threads
self.dbname = "sqlite"
DB.__init__(self, db, keywords)
def _process_insert_query(self, query, tablename, seqname):
return query, SQLQuery('SELECT last_insert_rowid();')
def query(self, *a, **kw):
out = DB.query(self, *a, **kw)
if isinstance(out, iterbetter):
del out.__len__
return out
class FirebirdDB(DB):
"""Firebird Database.
"""
def __init__(self, **keywords):
try:
import kinterbasdb as db
except Exception:
db = None
pass
if 'pw' in keywords:
keywords['passwd'] = keywords['pw']
del keywords['pw']
keywords['database'] = keywords['db']
del keywords['db']
DB.__init__(self, db, keywords)
def delete(self, table, where=None, using=None, vars=None, _test=False):
# firebird doesn't support using clause
using=None
return DB.delete(self, table, where, using, vars, _test)
def sql_clauses(self, what, tables, where, group, order, limit, offset):
return (
('SELECT', ''),
('FIRST', limit),
('SKIP', offset),
('', what),
('FROM', sqllist(tables)),
('WHERE', where),
('GROUP BY', group),
('ORDER BY', order)
)
class MSSQLDB(DB):
def __init__(self, **keywords):
import pymssql as db
if 'pw' in keywords:
keywords['password'] = keywords.pop('pw')
keywords['database'] = keywords.pop('db')
self.dbname = "mssql"
DB.__init__(self, db, keywords)
def _process_query(self, sql_query):
"""Takes the SQLQuery object and returns query string and parameters.
"""
# MSSQLDB expects params to be a tuple.
# Overwriting the default implementation to convert params to tuple.
paramstyle = getattr(self, 'paramstyle', 'pyformat')
query = sql_query.query(paramstyle)
params = sql_query.values()
return query, tuple(params)
def sql_clauses(self, what, tables, where, group, order, limit, offset):
return (
('SELECT', what),
('TOP', limit),
('FROM', sqllist(tables)),
('WHERE', where),
('GROUP BY', group),
('ORDER BY', order),
('OFFSET', offset))
def _test(self):
"""Test LIMIT.
Fake presence of pymssql module for running tests.
>>> import sys
>>> sys.modules['pymssql'] = sys.modules['sys']
MSSQL has TOP clause instead of LIMIT clause.
>>> db = MSSQLDB(db='test', user='joe', pw='secret')
>>> db.select('foo', limit=4, _test=True)
<sql: 'SELECT * TOP 4 FROM foo'>
"""
pass
class OracleDB(DB):
def __init__(self, **keywords):
import cx_Oracle as db
if 'pw' in keywords:
keywords['password'] = keywords.pop('pw')
#@@ TODO: use db.makedsn if host, port is specified
keywords['dsn'] = keywords.pop('db')
self.dbname = 'oracle'
db.paramstyle = 'numeric'
self.paramstyle = db.paramstyle
# oracle doesn't support pooling
keywords.pop('pooling', None)
DB.__init__(self, db, keywords)
def _process_insert_query(self, query, tablename, seqname):
if seqname is None:
# It is not possible to get seq name from table name in Oracle
return query
else:
return query + "; SELECT %s.currval FROM dual" % seqname
_databases = {}
def database(dburl=None, **params):
"""Creates appropriate database using params.
Pooling will be enabled if DBUtils module is available.
Pooling can be disabled by passing pooling=False in params.
"""
dbn = params.pop('dbn')
if dbn in _databases:
return _databases[dbn](**params)
else:
raise UnknownDB, dbn
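
# A usage sketch for the factory above (connection parameters are illustrative);
# `dbn` selects one of the classes registered below via register_database():
#
#   db = database(dbn='postgres', db='mydata', user='james', pw='secret')
#   db = database(dbn='sqlite', db='test.db')   # pooling is forced off for sqlite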
def register_database(name, clazz):
"""
Register a database.
>>> class LegacyDB(DB):
... def __init__(self, **params):
... pass
...
>>> register_database('legacy', LegacyDB)
>>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
"""
_databases[name] = clazz
register_database('mysql', MySQLDB)
register_database('postgres', PostgresDB)
register_database('sqlite', SqliteDB)
register_database('firebird', FirebirdDB)
register_database('mssql', MSSQLDB)
register_database('oracle', OracleDB)
def _interpolate(format):
"""
Takes a format string and returns a list of 2-tuples of the form
(boolean, string) where boolean says whether string should be evaled
or not.
from (public domain, Ka-Ping Yee)
"""
from tokenize import tokenprog
def matchorfail(text, pos):
match = tokenprog.match(text, pos)
if match is None:
raise _ItplError(text, pos)
return match, match.end()
namechars = "abcdefghijklmnopqrstuvwxyz" \
"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
chunks = []
pos = 0
while 1:
dollar = format.find("$", pos)
if dollar < 0:
break
nextchar = format[dollar + 1]
if nextchar == "{":
chunks.append((0, format[pos:dollar]))
pos, level = dollar + 2, 1
while level:
match, pos = matchorfail(format, pos)
tstart, tend = match.regs[3]
token = format[tstart:tend]
if token == "{":
level = level + 1
elif token == "}":
level = level - 1
chunks.append((1, format[dollar + 2:pos - 1]))
elif nextchar in namechars:
chunks.append((0, format[pos:dollar]))
match, pos = matchorfail(format, dollar + 1)
while pos < len(format):
if format[pos] == "." and \
pos + 1 < len(format) and format[pos + 1] in namechars:
match, pos = matchorfail(format, pos + 1)
elif format[pos] in "([":
pos, level = pos + 1, 1
while level:
match, pos = matchorfail(format, pos)
tstart, tend = match.regs[3]
token = format[tstart:tend]
if token[0] in "([":
level = level + 1
elif token[0] in ")]":
level = level - 1
else:
break
chunks.append((1, format[dollar + 1:pos]))
else:
chunks.append((0, format[pos:dollar + 1]))
pos = dollar + 1 + (nextchar == "$")
if pos < len(format):
chunks.append((0, format[pos:]))
return chunks
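
# A sketch of what _interpolate returns (traced from the parsing logic above, not an
# official doctest): literal chunks are tagged 0, expressions to be eval'd are tagged 1.
#
#   _interpolate("WHERE x = $x AND y IN $y")
#   # -> [(0, 'WHERE x = '), (1, 'x'), (0, ' AND y IN '), (1, 'y')]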
if __name__ == "__main__":
import doctest
doctest.testmod()

webpy/web/python23.py
"""Python 2.3 compatibility"""
import threading
class threadlocal(object):
"""Implementation of threading.local for python2.3.
"""
def __getattribute__(self, name):
if name == "__dict__":
return threadlocal._getd(self)
else:
try:
return object.__getattribute__(self, name)
except AttributeError:
try:
return self.__dict__[name]
except KeyError:
raise AttributeError, name
def __setattr__(self, name, value):
self.__dict__[name] = value
def __delattr__(self, name):
try:
del self.__dict__[name]
except KeyError:
raise AttributeError, name
def _getd(self):
t = threading.currentThread()
if not hasattr(t, '_d'):
# using __dict__ of thread as thread local storage
t._d = {}
_id = id(self)
# there could be multiple instances of threadlocal.
# use id(self) as key
if _id not in t._d:
t._d[_id] = {}
return t._d[_id]
if __name__ == '__main__':
d = threadlocal()
d.x = 1
print d.__dict__
print d.x

webpy/web/session.py
"""
Session Management
(from web.py)
"""
import os, time, datetime, random, base64
import os.path
from copy import deepcopy
try:
import cPickle as pickle
except ImportError:
import pickle
try:
import hashlib
sha1 = hashlib.sha1
except ImportError:
import sha
sha1 = sha.new
import utils
import webapi as web
__all__ = [
'Session', 'SessionExpired',
'Store', 'DiskStore', 'DBStore',
]
web.config.session_parameters = utils.storage({
'cookie_name': 'webpy_session_id',
'cookie_domain': None,
'cookie_path' : None,
'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds
'ignore_expiry': True,
'ignore_change_ip': True,
'secret_key': 'fLjUfxqXtfNoIldA0A0J',
'expired_message': 'Session expired',
'httponly': True,
'secure': False
})
class SessionExpired(web.HTTPError):
def __init__(self, message):
web.HTTPError.__init__(self, '200 OK', {}, data=message)
class Session(object):
"""Session management for web.py
"""
__slots__ = [
"store", "_initializer", "_last_cleanup_time", "_config", "_data",
"__getitem__", "__setitem__", "__delitem__"
]
def __init__(self, app, store, initializer=None):
self.store = store
self._initializer = initializer
self._last_cleanup_time = 0
self._config = utils.storage(web.config.session_parameters)
self._data = utils.threadeddict()
self.__getitem__ = self._data.__getitem__
self.__setitem__ = self._data.__setitem__
self.__delitem__ = self._data.__delitem__
if app:
app.add_processor(self._processor)
def __contains__(self, name):
return name in self._data
def __getattr__(self, name):
return getattr(self._data, name)
def __setattr__(self, name, value):
if name in self.__slots__:
object.__setattr__(self, name, value)
else:
setattr(self._data, name, value)
def __delattr__(self, name):
delattr(self._data, name)
def _processor(self, handler):
"""Application processor to setup session for every request"""
self._cleanup()
self._load()
try:
return handler()
finally:
self._save()
def _load(self):
"""Load the session from the store, by the id from cookie"""
cookie_name = self._config.cookie_name
cookie_domain = self._config.cookie_domain
cookie_path = self._config.cookie_path
httponly = self._config.httponly
self.session_id = web.cookies().get(cookie_name)
# protection against session_id tampering
if self.session_id and not self._valid_session_id(self.session_id):
self.session_id = None
self._check_expiry()
if self.session_id:
d = self.store[self.session_id]
self.update(d)
self._validate_ip()
if not self.session_id:
self.session_id = self._generate_session_id()
if self._initializer:
if isinstance(self._initializer, dict):
self.update(deepcopy(self._initializer))
elif hasattr(self._initializer, '__call__'):
self._initializer()
self.ip = web.ctx.ip
def _check_expiry(self):
# check for expiry
if self.session_id and self.session_id not in self.store:
if self._config.ignore_expiry:
self.session_id = None
else:
return self.expired()
def _validate_ip(self):
# check for change of IP
if self.session_id and self.get('ip', None) != web.ctx.ip:
if not self._config.ignore_change_ip:
return self.expired()
def _save(self):
if not self.get('_killed'):
self._setcookie(self.session_id)
self.store[self.session_id] = dict(self._data)
else:
self._setcookie(self.session_id, expires=-1)
def _setcookie(self, session_id, expires='', **kw):
cookie_name = self._config.cookie_name
cookie_domain = self._config.cookie_domain
cookie_path = self._config.cookie_path
httponly = self._config.httponly
secure = self._config.secure
web.setcookie(cookie_name, session_id, expires=expires, domain=cookie_domain, httponly=httponly, secure=secure, path=cookie_path)
def _generate_session_id(self):
"""Generate a random id for session"""
while True:
rand = os.urandom(16)
now = time.time()
secret_key = self._config.secret_key
session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key))
session_id = session_id.hexdigest()
if session_id not in self.store:
break
return session_id
def _valid_session_id(self, session_id):
rx = utils.re_compile('^[0-9a-fA-F]+$')
return rx.match(session_id)
def _cleanup(self):
"""Cleanup the stored sessions"""
current_time = time.time()
timeout = self._config.timeout
if current_time - self._last_cleanup_time > timeout:
self.store.cleanup(timeout)
self._last_cleanup_time = current_time
def expired(self):
"""Called when an expired session is atime"""
self._killed = True
self._save()
raise SessionExpired(self._config.expired_message)
def kill(self):
"""Kill the session, make it no longer available"""
del self.store[self.session_id]
self._killed = True
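
# A wiring sketch (the store path, app object and counter are illustrative): the
# Session registers itself as an application processor and exposes its data as
# attributes on the session object.
#
#   app = web.application(urls, globals())
#   session = Session(app, DiskStore('sessions'), initializer={'count': 0})
#
#   class count:
#       def GET(self):
#           session.count += 1
#           return str(session.count)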
class Store:
"""Base class for session stores"""
def __contains__(self, key):
raise NotImplementedError
def __getitem__(self, key):
raise NotImplementedError
def __setitem__(self, key, value):
raise NotImplementedError
def cleanup(self, timeout):
"""removes all the expired sessions"""
raise NotImplementedError
def encode(self, session_dict):
"""encodes session dict as a string"""
pickled = pickle.dumps(session_dict)
return base64.encodestring(pickled)
def decode(self, session_data):
"""decodes the data to get back the session dict """
pickled = base64.decodestring(session_data)
return pickle.loads(pickled)
class DiskStore(Store):
"""
Store for saving a session on disk.
>>> import tempfile
>>> root = tempfile.mkdtemp()
>>> s = DiskStore(root)
>>> s['a'] = 'foo'
>>> s['a']
'foo'
>>> time.sleep(0.01)
>>> s.cleanup(0.01)
>>> s['a']
Traceback (most recent call last):
...
KeyError: 'a'
"""
def __init__(self, root):
# if the storage root doesn't exist, create it.
if not os.path.exists(root):
os.makedirs(
os.path.abspath(root)
)
self.root = root
def _get_path(self, key):
if os.path.sep in key:
raise ValueError, "Bad key: %s" % repr(key)
return os.path.join(self.root, key)
def __contains__(self, key):
path = self._get_path(key)
return os.path.exists(path)
def __getitem__(self, key):
path = self._get_path(key)
if os.path.exists(path):
pickled = open(path).read()
return self.decode(pickled)
else:
raise KeyError, key
def __setitem__(self, key, value):
path = self._get_path(key)
pickled = self.encode(value)
try:
f = open(path, 'w')
try:
f.write(pickled)
finally:
f.close()
except IOError:
pass
def __delitem__(self, key):
path = self._get_path(key)
if os.path.exists(path):
os.remove(path)
def cleanup(self, timeout):
now = time.time()
for f in os.listdir(self.root):
path = self._get_path(f)
atime = os.stat(path).st_atime
if now - atime > timeout :
os.remove(path)
class DBStore(Store):
"""Store for saving a session in database
Needs a table with the following columns:
session_id CHAR(128) UNIQUE NOT NULL,
atime DATETIME NOT NULL default current_timestamp,
data TEXT
"""
def __init__(self, db, table_name):
self.db = db
self.table = table_name
def __contains__(self, key):
data = self.db.select(self.table, where="session_id=$key", vars=locals())
return bool(list(data))
def __getitem__(self, key):
now = datetime.datetime.now()
try:
s = self.db.select(self.table, where="session_id=$key", vars=locals())[0]
self.db.update(self.table, where="session_id=$key", atime=now, vars=locals())
except IndexError:
raise KeyError
else:
return self.decode(s.data)
def __setitem__(self, key, value):
pickled = self.encode(value)
now = datetime.datetime.now()
if key in self:
self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals())
else:
self.db.insert(self.table, False, session_id=key, data=pickled )
def __delitem__(self, key):
self.db.delete(self.table, where="session_id=$key", vars=locals())
def cleanup(self, timeout):
timeout = datetime.timedelta(timeout/(24.0*60*60)) #timedelta takes numdays as arg
last_allowed_time = datetime.datetime.now() - timeout
self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals())
class ShelfStore:
"""Store for saving session using `shelve` module.
import shelve
store = ShelfStore(shelve.open('session.shelf'))
XXX: is shelve thread-safe?
"""
def __init__(self, shelf):
self.shelf = shelf
def __contains__(self, key):
return key in self.shelf
def __getitem__(self, key):
atime, v = self.shelf[key]
self[key] = v # update atime
return v
def __setitem__(self, key, value):
self.shelf[key] = time.time(), value
def __delitem__(self, key):
try:
del self.shelf[key]
except KeyError:
pass
def cleanup(self, timeout):
now = time.time()
for k in self.shelf.keys():
atime, v = self.shelf[k]
if now - atime > timeout :
del self[k]
if __name__ == '__main__' :
import doctest
doctest.testmod()

webpy/web/__init__.py
#!/usr/bin/env python
"""web.py: makes web apps (http://webpy.org)"""
from __future__ import generators
__version__ = "0.37"
__author__ = [
"Aaron Swartz ",
"Anand Chitipothu "
]
__license__ = "public domain"
__contributors__ = "see http://webpy.org/changes"
import utils, db, net, wsgi, http, webapi, httpserver, debugerror
import template, form
import session
from utils import *
from db import *
from net import *
from wsgi import *
from http import *
from webapi import *
from httpserver import *
from debugerror import *
from application import *
from browser import *
try:
import webopenid as openid
except ImportError:
pass # requires openid module

webpy/web/net.py
"""
Network Utilities
(from web.py)
"""
__all__ = [
"validipaddr", "validipport", "validip", "validaddr",
"urlquote",
"httpdate", "parsehttpdate",
"htmlquote", "htmlunquote", "websafe",
]
import urllib, time
try: import datetime
except ImportError: pass
def validipaddr(address):
"""
Returns True if `address` is a valid IPv4 address.
>>> validipaddr('192.168.1.1')
True
>>> validipaddr('192.168.1.800')
False
>>> validipaddr('192.168.1')
False
"""
try:
octets = address.split('.')
if len(octets) != 4:
return False
for x in octets:
if not (0 <= int(x) <= 255):
return False
except ValueError:
return False
return True
def validipport(port):
"""
Returns True if `port` is a valid IPv4 port.
>>> validipport('9000')
True
>>> validipport('foo')
False
>>> validipport('1000000')
False
"""
try:
if not (0 <= int(port) <= 65535):
return False
except ValueError:
return False
return True
def validip(ip, defaultaddr="0.0.0.0", defaultport=8080):
"""Returns `(ip_address, port)` from string `ip_addr_port`"""
addr = defaultaddr
port = defaultport
ip = ip.split(":", 1)
if len(ip) == 1:
if not ip[0]:
pass
elif validipaddr(ip[0]):
addr = ip[0]
elif validipport(ip[0]):
port = int(ip[0])
else:
raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
elif len(ip) == 2:
addr, port = ip
if not validipaddr(addr) and validipport(port):
raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
port = int(port)
else:
raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
return (addr, port)
def validaddr(string_):
"""
Returns either (ip_address, port) or "/path/to/socket" from string_
>>> validaddr('/path/to/socket')
'/path/to/socket'
>>> validaddr('8000')
('0.0.0.0', 8000)
>>> validaddr('127.0.0.1')
('127.0.0.1', 8080)
>>> validaddr('127.0.0.1:8000')
('127.0.0.1', 8000)
>>> validaddr('fff')
Traceback (most recent call last):
...
ValueError: fff is not a valid IP address/port
"""
if '/' in string_:
return string_
else:
return validip(string_)
def urlquote(val):
"""
Quotes a string for use in a URL.
>>> urlquote('://?f=1&j=1')
'%3A//%3Ff%3D1%26j%3D1'
>>> urlquote(None)
''
>>> urlquote(u'\u203d')
'%E2%80%BD'
"""
if val is None: return ''
if not isinstance(val, unicode): val = str(val)
else: val = val.encode('utf-8')
return urllib.quote(val)
def httpdate(date_obj):
"""
Formats a datetime object for use in HTTP headers.
>>> import datetime
>>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1))
'Thu, 01 Jan 1970 01:01:01 GMT'
"""
return date_obj.strftime("%a, %d %b %Y %H:%M:%S GMT")
def parsehttpdate(string_):
"""
Parses an HTTP date into a datetime object.
>>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT')
datetime.datetime(1970, 1, 1, 1, 1, 1)
"""
try:
t = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z")
except ValueError:
return None
return datetime.datetime(*t[:6])
def htmlquote(text):
r"""
Encodes `text` for raw use in HTML.
>>> htmlquote(u"<'&\">")
u'<'&">'
"""
text = text.replace(u"&", u"&") # Must be done first!
text = text.replace(u"<", u"<")
text = text.replace(u">", u">")
text = text.replace(u"'", u"'")
text = text.replace(u'"', u""")
return text
def htmlunquote(text):
r"""
Decodes `text` that's HTML quoted.
>>> htmlunquote(u'&lt;&#39;&amp;&quot;&gt;')
u'<\'&">'
"""
text = text.replace(u"&quot;", u'"')
text = text.replace(u"&#39;", u"'")
text = text.replace(u"&gt;", u">")
text = text.replace(u"&lt;", u"<")
text = text.replace(u"&amp;", u"&") # Must be done last!
return text
def websafe(val):
r"""Converts `val` so that it is safe for use in Unicode HTML.
>>> websafe("<'&\">")
u'&lt;&#39;&amp;&quot;&gt;'
>>> websafe(None)
u''
>>> websafe(u'\u203d')
u'\u203d'
>>> websafe('\xe2\x80\xbd')
u'\u203d'
"""
if val is None:
return u''
elif isinstance(val, str):
val = val.decode('utf-8')
elif not isinstance(val, unicode):
val = unicode(val)
return htmlquote(val)
if __name__ == "__main__":
import doctest
doctest.testmod()

webpy/web/application.py
"""
Web application
(from web.py)
"""
import webapi as web
import webapi, wsgi, utils
import debugerror
import httpserver
from utils import lstrips, safeunicode
import sys
import urllib
import traceback
import itertools
import os
import types
from exceptions import SystemExit
try:
import wsgiref.handlers
except ImportError:
pass # don't break people with old Pythons
__all__ = [
"application", "auto_application",
"subdir_application", "subdomain_application",
"loadhook", "unloadhook",
"autodelegate"
]
class application:
"""
Application to delegate requests based on path.
>>> urls = ("/hello", "hello")
>>> app = application(urls, globals())
>>> class hello:
... def GET(self): return "hello"
>>>
>>> app.request("/hello").data
'hello'
"""
def __init__(self, mapping=(), fvars={}, autoreload=None):
if autoreload is None:
autoreload = web.config.get('debug', False)
self.init_mapping(mapping)
self.fvars = fvars
self.processors = []
self.add_processor(loadhook(self._load))
self.add_processor(unloadhook(self._unload))
if autoreload:
def main_module_name():
mod = sys.modules['__main__']
file = getattr(mod, '__file__', None) # make sure this works even from python interpreter
return file and os.path.splitext(os.path.basename(file))[0]
def modname(fvars):
"""find name of the module name from fvars."""
file, name = fvars.get('__file__'), fvars.get('__name__')
if file is None or name is None:
return None
if name == '__main__':
# Since the __main__ module can't be reloaded, the module has
# to be imported using its file name.
name = main_module_name()
return name
mapping_name = utils.dictfind(fvars, mapping)
module_name = modname(fvars)
def reload_mapping():
"""loadhook to reload mapping and fvars."""
mod = __import__(module_name, None, None, [''])
mapping = getattr(mod, mapping_name, None)
if mapping:
self.fvars = mod.__dict__
self.init_mapping(mapping)
self.add_processor(loadhook(Reloader()))
if mapping_name and module_name:
self.add_processor(loadhook(reload_mapping))
# load __main__ module using its filename, so that it can be reloaded.
if main_module_name() and '__main__' in sys.argv:
try:
__import__(main_module_name())
except ImportError:
pass
def _load(self):
web.ctx.app_stack.append(self)
def _unload(self):
web.ctx.app_stack = web.ctx.app_stack[:-1]
if web.ctx.app_stack:
# this is a sub-application, revert ctx to earlier state.
oldctx = web.ctx.get('_oldctx')
if oldctx:
web.ctx.home = oldctx.home
web.ctx.homepath = oldctx.homepath
web.ctx.path = oldctx.path
web.ctx.fullpath = oldctx.fullpath
def _cleanup(self):
# Threads can be recycled by WSGI servers.
# Clear all thread-local state to avoid interfering with subsequent requests.
utils.ThreadedDict.clear_all()
def init_mapping(self, mapping):
self.mapping = list(utils.group(mapping, 2))
def add_mapping(self, pattern, classname):
self.mapping.append((pattern, classname))
def add_processor(self, processor):
"""
Adds a processor to the application.
>>> urls = ("/(.*)", "echo")
>>> app = application(urls, globals())
>>> class echo:
... def GET(self, name): return name
...
>>>
>>> def hello(handler): return "hello, " + handler()
...
>>> app.add_processor(hello)
>>> app.request("/web.py").data
'hello, web.py'
"""
self.processors.append(processor)
def request(self, localpart='/', method='GET', data=None,
host="0.0.0.0:8080", headers=None, https=False, **kw):
"""Makes request to this application for the specified path and method.
Response will be a storage object with data, status and headers.
>>> urls = ("/hello", "hello")
>>> app = application(urls, globals())
>>> class hello:
... def GET(self):
... web.header('Content-Type', 'text/plain')
... return "hello"
...
>>> response = app.request("/hello")
>>> response.data
'hello'
>>> response.status
'200 OK'
>>> response.headers['Content-Type']
'text/plain'
To use https, use https=True.
>>> urls = ("/redirect", "redirect")
>>> app = application(urls, globals())
>>> class redirect:
... def GET(self): raise web.seeother("/foo")
...
>>> response = app.request("/redirect")
>>> response.headers['Location']
'http://0.0.0.0:8080/foo'
>>> response = app.request("/redirect", https=True)
>>> response.headers['Location']
'https://0.0.0.0:8080/foo'
The headers argument specifies HTTP headers as a mapping object
such as a dict.
>>> urls = ('/ua', 'uaprinter')
>>> class uaprinter:
... def GET(self):
... return 'your user-agent is ' + web.ctx.env['HTTP_USER_AGENT']
...
>>> app = application(urls, globals())
>>> app.request('/ua', headers = {
... 'User-Agent': 'a small jumping bean/1.0 (compatible)'
... }).data
'your user-agent is a small jumping bean/1.0 (compatible)'
"""
path, maybe_query = urllib.splitquery(localpart)
query = maybe_query or ""
if 'env' in kw:
env = kw['env']
else:
env = {}
env = dict(env, HTTP_HOST=host, REQUEST_METHOD=method, PATH_INFO=path, QUERY_STRING=query, HTTPS=str(https))
headers = headers or {}
for k, v in headers.items():
env['HTTP_' + k.upper().replace('-', '_')] = v
if 'HTTP_CONTENT_LENGTH' in env:
env['CONTENT_LENGTH'] = env.pop('HTTP_CONTENT_LENGTH')
if 'HTTP_CONTENT_TYPE' in env:
env['CONTENT_TYPE'] = env.pop('HTTP_CONTENT_TYPE')
if method not in ["HEAD", "GET"]:
data = data or ''
import StringIO
if isinstance(data, dict):
q = urllib.urlencode(data)
else:
q = data
env['wsgi.input'] = StringIO.StringIO(q)
if not env.get('CONTENT_TYPE', '').lower().startswith('multipart/') and 'CONTENT_LENGTH' not in env:
env['CONTENT_LENGTH'] = len(q)
response = web.storage()
def start_response(status, headers):
response.status = status
response.headers = dict(headers)
response.header_items = headers
response.data = "".join(self.wsgifunc()(env, start_response))
return response
def browser(self):
import browser
return browser.AppBrowser(self)
def handle(self):
fn, args = self._match(self.mapping, web.ctx.path)
return self._delegate(fn, self.fvars, args)
def handle_with_processors(self):
def process(processors):
try:
if processors:
p, processors = processors[0], processors[1:]
return p(lambda: process(processors))
else:
return self.handle()
except web.HTTPError:
raise
except (KeyboardInterrupt, SystemExit):
raise
except:
print >> web.debug, traceback.format_exc()
raise self.internalerror()
        # processors must be applied in the reverse order. (??)
return process(self.processors)
def wsgifunc(self, *middleware):
"""Returns a WSGI-compatible function for this application."""
def peep(iterator):
"""Peeps into an iterator by doing an iteration
and returns an equivalent iterator.
"""
# wsgi requires the headers first
# so we need to do an iteration
# and save the result for later
try:
firstchunk = iterator.next()
except StopIteration:
firstchunk = ''
return itertools.chain([firstchunk], iterator)
def is_generator(x): return x and hasattr(x, 'next')
def wsgi(env, start_resp):
            # clear threadlocal to avoid interference from previous requests
self._cleanup()
self.load(env)
try:
# allow uppercase methods only
if web.ctx.method.upper() != web.ctx.method:
raise web.nomethod()
result = self.handle_with_processors()
if is_generator(result):
result = peep(result)
else:
result = [result]
except web.HTTPError, e:
result = [e.data]
result = web.safestr(iter(result))
status, headers = web.ctx.status, web.ctx.headers
start_resp(status, headers)
def cleanup():
self._cleanup()
yield '' # force this function to be a generator
return itertools.chain(result, cleanup())
for m in middleware:
wsgi = m(wsgi)
return wsgi
def run(self, *middleware):
"""
Starts handling requests. If called in a CGI or FastCGI context, it will follow
that protocol. If called from the command line, it will start an HTTP
server on the port named in the first command line argument, or, if there
is no argument, on port 8080.
`middleware` is a list of WSGI middleware which is applied to the resulting WSGI
function.
"""
return wsgi.runwsgi(self.wsgifunc(*middleware))
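    # Illustrative sketch (not part of the original module): a minimal program
    # using run(), assuming a module-level `urls` mapping and an `index` handler
    # class defined alongside it.
    #
    #     import web
    #
    #     urls = ("/", "index")
    #     class index:
    #         def GET(self):
    #             return "hello, world"
    #
    #     if __name__ == "__main__":
    #         app = web.application(urls, globals())
    #         app.run()   # serves on port 8080, or on the port given on the command line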
def stop(self):
"""Stops the http server started by run.
"""
if httpserver.server:
httpserver.server.stop()
httpserver.server = None
def cgirun(self, *middleware):
"""
Return a CGI handler. This is mostly useful with Google App Engine.
There you can just do:
main = app.cgirun()
"""
wsgiapp = self.wsgifunc(*middleware)
try:
from google.appengine.ext.webapp.util import run_wsgi_app
return run_wsgi_app(wsgiapp)
except ImportError:
# we're not running from within Google App Engine
return wsgiref.handlers.CGIHandler().run(wsgiapp)
def load(self, env):
"""Initializes ctx using env."""
ctx = web.ctx
ctx.clear()
ctx.status = '200 OK'
ctx.headers = []
ctx.output = ''
ctx.environ = ctx.env = env
ctx.host = env.get('HTTP_HOST')
if env.get('wsgi.url_scheme') in ['http', 'https']:
ctx.protocol = env['wsgi.url_scheme']
elif env.get('HTTPS', '').lower() in ['on', 'true', '1']:
ctx.protocol = 'https'
else:
ctx.protocol = 'http'
ctx.homedomain = ctx.protocol + '://' + env.get('HTTP_HOST', '[unknown]')
ctx.homepath = os.environ.get('REAL_SCRIPT_NAME', env.get('SCRIPT_NAME', ''))
ctx.home = ctx.homedomain + ctx.homepath
        #@@ home is changed when the request is delegated to a sub-application.
#@@ but the real home is required for doing absolute redirects.
ctx.realhome = ctx.home
ctx.ip = env.get('REMOTE_ADDR')
ctx.method = env.get('REQUEST_METHOD')
ctx.path = env.get('PATH_INFO')
# http://trac.lighttpd.net/trac/ticket/406 requires:
if env.get('SERVER_SOFTWARE', '').startswith('lighttpd/'):
ctx.path = lstrips(env.get('REQUEST_URI').split('?')[0], ctx.homepath)
# Apache and CherryPy webservers unquote the url but lighttpd doesn't.
# unquote explicitly for lighttpd to make ctx.path uniform across all servers.
ctx.path = urllib.unquote(ctx.path)
if env.get('QUERY_STRING'):
ctx.query = '?' + env.get('QUERY_STRING', '')
else:
ctx.query = ''
ctx.fullpath = ctx.path + ctx.query
for k, v in ctx.iteritems():
# convert all string values to unicode values and replace
# malformed data with a suitable replacement marker.
if isinstance(v, str):
ctx[k] = v.decode('utf-8', 'replace')
# status must always be str
ctx.status = '200 OK'
ctx.app_stack = []
def _delegate(self, f, fvars, args=[]):
def handle_class(cls):
meth = web.ctx.method
if meth == 'HEAD' and not hasattr(cls, meth):
meth = 'GET'
if not hasattr(cls, meth):
raise web.nomethod(cls)
tocall = getattr(cls(), meth)
return tocall(*args)
def is_class(o): return isinstance(o, (types.ClassType, type))
if f is None:
raise web.notfound()
elif isinstance(f, application):
return f.handle_with_processors()
elif is_class(f):
return handle_class(f)
elif isinstance(f, basestring):
if f.startswith('redirect '):
url = f.split(' ', 1)[1]
if web.ctx.method == "GET":
x = web.ctx.env.get('QUERY_STRING', '')
if x:
url += '?' + x
raise web.redirect(url)
elif '.' in f:
mod, cls = f.rsplit('.', 1)
mod = __import__(mod, None, None, [''])
cls = getattr(mod, cls)
else:
cls = fvars[f]
return handle_class(cls)
elif hasattr(f, '__call__'):
return f()
else:
return web.notfound()
def _match(self, mapping, value):
for pat, what in mapping:
if isinstance(what, application):
if value.startswith(pat):
f = lambda: self._delegate_sub_application(pat, what)
return f, None
else:
continue
elif isinstance(what, basestring):
what, result = utils.re_subm('^' + pat + '$', what, value)
else:
result = utils.re_compile('^' + pat + '$').match(value)
if result: # it's a match
return what, [x for x in result.groups()]
return None, None
def _delegate_sub_application(self, dir, app):
"""Deletes request to sub application `app` rooted at the directory `dir`.
The home, homepath, path and fullpath values in web.ctx are updated to mimic request
to the subapp and are restored after it is handled.
@@Any issues with when used with yield?
"""
web.ctx._oldctx = web.storage(web.ctx)
web.ctx.home += dir
web.ctx.homepath += dir
web.ctx.path = web.ctx.path[len(dir):]
web.ctx.fullpath = web.ctx.fullpath[len(dir):]
return app.handle_with_processors()
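    # Illustrative sketch (not part of the original module): mounting a
    # sub-application. When the mapped value is an `application` instance,
    # _match() delegates to _delegate_sub_application(), which rebases
    # web.ctx.home/homepath/path for the duration of the request.
    #
    #     blog_urls = ("/recent", "recent")
    #     blog = web.application(blog_urls, locals())
    #
    #     urls = (
    #         "/blog", blog,        # /blog/recent is handled by the blog app
    #         "/(.*)", "index",
    #     )
    #     app = web.application(urls, globals())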
def get_parent_app(self):
if self in web.ctx.app_stack:
index = web.ctx.app_stack.index(self)
if index > 0:
return web.ctx.app_stack[index-1]
def notfound(self):
"""Returns HTTPError with '404 not found' message"""
parent = self.get_parent_app()
if parent:
return parent.notfound()
else:
return web._NotFound()
def internalerror(self):
"""Returns HTTPError with '500 internal error' message"""
parent = self.get_parent_app()
if parent:
return parent.internalerror()
elif web.config.get('debug'):
import debugerror
return debugerror.debugerror()
else:
return web._InternalError()
class auto_application(application):
"""Application similar to `application` but urls are constructed
    automatically using a metaclass.
>>> app = auto_application()
>>> class hello(app.page):
... def GET(self): return "hello, world"
...
>>> class foo(app.page):
... path = '/foo/.*'
... def GET(self): return "foo"
>>> app.request("/hello").data
'hello, world'
>>> app.request('/foo/bar').data
'foo'
"""
def __init__(self):
application.__init__(self)
class metapage(type):
def __init__(klass, name, bases, attrs):
type.__init__(klass, name, bases, attrs)
path = attrs.get('path', '/' + name)
# path can be specified as None to ignore that class
                # typically required to create an abstract base class.
if path is not None:
self.add_mapping(path, klass)
class page:
path = None
__metaclass__ = metapage
self.page = page
# The application class already has the required functionality of subdir_application
subdir_application = application
class subdomain_application(application):
"""
Application to delegate requests based on the host.
>>> urls = ("/hello", "hello")
>>> app = application(urls, globals())
>>> class hello:
... def GET(self): return "hello"
>>>
>>> mapping = (r"hello\.example\.com", app)
>>> app2 = subdomain_application(mapping)
>>> app2.request("/hello", host="hello.example.com").data
'hello'
>>> response = app2.request("/hello", host="something.example.com")
>>> response.status
'404 Not Found'
>>> response.data
'not found'
"""
def handle(self):
host = web.ctx.host.split(':')[0] #strip port
fn, args = self._match(self.mapping, host)
return self._delegate(fn, self.fvars, args)
def _match(self, mapping, value):
for pat, what in mapping:
if isinstance(what, basestring):
what, result = utils.re_subm('^' + pat + '$', what, value)
else:
result = utils.re_compile('^' + pat + '$').match(value)
if result: # it's a match
return what, [x for x in result.groups()]
return None, None
def loadhook(h):
"""
Converts a load hook into an application processor.
>>> app = auto_application()
>>> def f(): "something done before handling request"
...
>>> app.add_processor(loadhook(f))
"""
def processor(handler):
h()
return handler()
return processor
def unloadhook(h):
"""
Converts an unload hook into an application processor.
>>> app = auto_application()
>>> def f(): "something done after handling request"
...
>>> app.add_processor(unloadhook(f))
"""
def processor(handler):
try:
result = handler()
is_generator = result and hasattr(result, 'next')
except:
# run the hook even when handler raises some exception
h()
raise
if is_generator:
return wrap(result)
else:
h()
return result
def wrap(result):
def next():
try:
return result.next()
except:
                # call the hook at the end of the iterator
h()
raise
result = iter(result)
while True:
yield next()
return processor
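# Illustrative sketch (not part of the original module): pairing loadhook and
# unloadhook, e.g. to manage a per-request database transaction. The names
# `db` and the handlers below are hypothetical.
#
#     def before_request():
#         web.ctx.t = db.transaction()
#
#     def after_request():
#         web.ctx.t.commit()
#
#     app.add_processor(loadhook(before_request))
#     app.add_processor(unloadhook(after_request))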
def autodelegate(prefix=''):
"""
Returns a method that takes one argument and calls the method named prefix+arg,
calling `notfound()` if there isn't one. Example:
urls = ('/prefs/(.*)', 'prefs')
class prefs:
GET = autodelegate('GET_')
def GET_password(self): pass
def GET_privacy(self): pass
    `GET_password` would get called for `/prefs/password` while `GET_privacy` would
    get called for `/prefs/privacy`.
If a user visits `/prefs/password/change` then `GET_password(self, '/change')`
is called.
"""
def internal(self, arg):
if '/' in arg:
first, rest = arg.split('/', 1)
func = prefix + first
args = ['/' + rest]
else:
func = prefix + arg
args = []
if hasattr(self, func):
try:
return getattr(self, func)(*args)
except TypeError:
raise web.notfound()
else:
raise web.notfound()
return internal
class Reloader:
"""Checks to see if any loaded modules have changed on disk and,
if so, reloads them.
"""
"""File suffix of compiled modules."""
if sys.platform.startswith('java'):
SUFFIX = '$py.class'
else:
SUFFIX = '.pyc'
def __init__(self):
self.mtimes = {}
def __call__(self):
for mod in sys.modules.values():
self.check(mod)
def check(self, mod):
# jython registers java packages as modules but they either
# don't have a __file__ attribute or its value is None
if not (mod and hasattr(mod, '__file__') and mod.__file__):
return
try:
mtime = os.stat(mod.__file__).st_mtime
except (OSError, IOError):
return
if mod.__file__.endswith(self.__class__.SUFFIX) and os.path.exists(mod.__file__[:-1]):
mtime = max(os.stat(mod.__file__[:-1]).st_mtime, mtime)
if mod not in self.mtimes:
self.mtimes[mod] = mtime
elif self.mtimes[mod] < mtime:
try:
reload(mod)
self.mtimes[mod] = mtime
except ImportError:
pass
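# Illustrative sketch (not part of the original module): the Reloader is normally
# installed for you as a loadhook when autoreload is enabled (see the
# add_processor(loadhook(Reloader())) call earlier in this file), which under the
# usual configuration happens when web.config.debug is true.
#
#     web.config.debug = True
#     app = web.application(urls, globals())   # reloads changed modules on each request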
if __name__ == "__main__":
import doctest
doctest.testmod()
webpy/web/webapi.py 0000644 0001750 0001750 00000037720 11773133455 012760 0 ustar wmb wmb """
Web API (wrapper around WSGI)
(from web.py)
"""
__all__ = [
"config",
"header", "debug",
"input", "data",
"setcookie", "cookies",
"ctx",
"HTTPError",
# 200, 201, 202
"OK", "Created", "Accepted",
"ok", "created", "accepted",
# 301, 302, 303, 304, 307
"Redirect", "Found", "SeeOther", "NotModified", "TempRedirect",
"redirect", "found", "seeother", "notmodified", "tempredirect",
# 400, 401, 403, 404, 405, 406, 409, 410, 412, 415
"BadRequest", "Unauthorized", "Forbidden", "NotFound", "NoMethod", "NotAcceptable", "Conflict", "Gone", "PreconditionFailed", "UnsupportedMediaType",
"badrequest", "unauthorized", "forbidden", "notfound", "nomethod", "notacceptable", "conflict", "gone", "preconditionfailed", "unsupportedmediatype",
# 500
"InternalError",
"internalerror",
]
import sys, cgi, Cookie, pprint, urlparse, urllib
from utils import storage, storify, threadeddict, dictadd, intget, safestr
config = storage()
config.__doc__ = """
A configuration object for various aspects of web.py.
`debug`
   : when True, enables reloading, disables template caching and sets internalerror to debugerror.
"""
class HTTPError(Exception):
def __init__(self, status, headers={}, data=""):
ctx.status = status
for k, v in headers.items():
header(k, v)
self.data = data
Exception.__init__(self, status)
def _status_code(status, data=None, classname=None, docstring=None):
if data is None:
data = status.split(" ", 1)[1]
classname = status.split(" ", 1)[1].replace(' ', '') # 304 Not Modified -> NotModified
docstring = docstring or '`%s` status' % status
def __init__(self, data=data, headers={}):
HTTPError.__init__(self, status, headers, data)
# trick to create class dynamically with dynamic docstring.
return type(classname, (HTTPError, object), {
'__doc__': docstring,
'__init__': __init__
})
ok = OK = _status_code("200 OK", data="")
created = Created = _status_code("201 Created")
accepted = Accepted = _status_code("202 Accepted")
class Redirect(HTTPError):
"""A `301 Moved Permanently` redirect."""
def __init__(self, url, status='301 Moved Permanently', absolute=False):
"""
Returns a `status` redirect to the new URL.
`url` is joined with the base URL so that things like
`redirect("about") will work properly.
"""
newloc = urlparse.urljoin(ctx.path, url)
if newloc.startswith('/'):
if absolute:
home = ctx.realhome
else:
home = ctx.home
newloc = home + newloc
headers = {
'Content-Type': 'text/html',
'Location': newloc
}
HTTPError.__init__(self, status, headers, "")
redirect = Redirect
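# Illustrative sketch (not part of the original module): because the target is
# joined against ctx.path and then prefixed with ctx.home, relative and absolute
# forms both work inside a handler.
#
#     class save:
#         def POST(self):
#             ...
#             raise web.seeother("/saved")        # Location: <ctx.home>/saved
#             # or: raise web.redirect("done")    # resolved relative to the current path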
class Found(Redirect):
"""A `302 Found` redirect."""
def __init__(self, url, absolute=False):
Redirect.__init__(self, url, '302 Found', absolute=absolute)
found = Found
class SeeOther(Redirect):
"""A `303 See Other` redirect."""
def __init__(self, url, absolute=False):
Redirect.__init__(self, url, '303 See Other', absolute=absolute)
seeother = SeeOther
class NotModified(HTTPError):
"""A `304 Not Modified` status."""
def __init__(self):
HTTPError.__init__(self, "304 Not Modified")
notmodified = NotModified
class TempRedirect(Redirect):
"""A `307 Temporary Redirect` redirect."""
def __init__(self, url, absolute=False):
Redirect.__init__(self, url, '307 Temporary Redirect', absolute=absolute)
tempredirect = TempRedirect
class BadRequest(HTTPError):
"""`400 Bad Request` error."""
message = "bad request"
def __init__(self, message=None):
status = "400 Bad Request"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, message or self.message)
badrequest = BadRequest
class Unauthorized(HTTPError):
"""`401 Unauthorized` error."""
message = "unauthorized"
def __init__(self):
status = "401 Unauthorized"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
unauthorized = Unauthorized
class Forbidden(HTTPError):
"""`403 Forbidden` error."""
message = "forbidden"
def __init__(self):
status = "403 Forbidden"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
forbidden = Forbidden
class _NotFound(HTTPError):
"""`404 Not Found` error."""
message = "not found"
def __init__(self, message=None):
status = '404 Not Found'
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, message or self.message)
def NotFound(message=None):
"""Returns HTTPError with '404 Not Found' error from the active application.
"""
if message:
return _NotFound(message)
elif ctx.get('app_stack'):
return ctx.app_stack[-1].notfound()
else:
return _NotFound()
notfound = NotFound
class NoMethod(HTTPError):
"""A `405 Method Not Allowed` error."""
def __init__(self, cls=None):
status = '405 Method Not Allowed'
headers = {}
headers['Content-Type'] = 'text/html'
methods = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE']
if cls:
methods = [method for method in methods if hasattr(cls, method)]
headers['Allow'] = ', '.join(methods)
data = None
HTTPError.__init__(self, status, headers, data)
nomethod = NoMethod
class NotAcceptable(HTTPError):
"""`406 Not Acceptable` error."""
message = "not acceptable"
def __init__(self):
status = "406 Not Acceptable"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
notacceptable = NotAcceptable
class Conflict(HTTPError):
"""`409 Conflict` error."""
message = "conflict"
def __init__(self):
status = "409 Conflict"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
conflict = Conflict
class Gone(HTTPError):
"""`410 Gone` error."""
message = "gone"
def __init__(self):
status = '410 Gone'
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
gone = Gone
class PreconditionFailed(HTTPError):
"""`412 Precondition Failed` error."""
message = "precondition failed"
def __init__(self):
status = "412 Precondition Failed"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
preconditionfailed = PreconditionFailed
class UnsupportedMediaType(HTTPError):
"""`415 Unsupported Media Type` error."""
message = "unsupported media type"
def __init__(self):
status = "415 Unsupported Media Type"
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, self.message)
unsupportedmediatype = UnsupportedMediaType
class _InternalError(HTTPError):
"""500 Internal Server Error`."""
message = "internal server error"
def __init__(self, message=None):
status = '500 Internal Server Error'
headers = {'Content-Type': 'text/html'}
HTTPError.__init__(self, status, headers, message or self.message)
def InternalError(message=None):
"""Returns HTTPError with '500 internal error' error from the active application.
"""
if message:
return _InternalError(message)
elif ctx.get('app_stack'):
return ctx.app_stack[-1].internalerror()
else:
return _InternalError()
internalerror = InternalError
def header(hdr, value, unique=False):
"""
    Adds the header `hdr: value` to the response.
If `unique` is True and a header with that name already exists,
it doesn't add a new one.
"""
hdr, value = safestr(hdr), safestr(value)
# protection against HTTP response splitting attack
if '\n' in hdr or '\r' in hdr or '\n' in value or '\r' in value:
raise ValueError, 'invalid characters in header'
if unique is True:
for h, v in ctx.headers:
if h.lower() == hdr.lower(): return
ctx.headers.append((hdr, value))
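# Illustrative sketch (not part of the original module): typical use inside a
# handler; with unique=True a second Content-Type would be ignored.
# `render_feed` is a hypothetical helper.
#
#     class feed:
#         def GET(self):
#             web.header('Content-Type', 'application/atom+xml', unique=True)
#             web.header('Cache-Control', 'max-age=60')
#             return render_feed()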
def rawinput(method=None):
"""Returns storage object with GET or POST arguments.
"""
method = method or "both"
from cStringIO import StringIO
def dictify(fs):
        # hack to make web.input work with enctype='text/plain'.
if fs.list is None:
fs.list = []
return dict([(k, fs[k]) for k in fs.keys()])
e = ctx.env.copy()
a = b = {}
if method.lower() in ['both', 'post', 'put']:
if e['REQUEST_METHOD'] in ['POST', 'PUT']:
if e.get('CONTENT_TYPE', '').lower().startswith('multipart/'):
# since wsgi.input is directly passed to cgi.FieldStorage,
# it can not be called multiple times. Saving the FieldStorage
# object in ctx to allow calling web.input multiple times.
a = ctx.get('_fieldstorage')
if not a:
fp = e['wsgi.input']
a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
ctx._fieldstorage = a
else:
fp = StringIO(data())
a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
a = dictify(a)
if method.lower() in ['both', 'get']:
e['REQUEST_METHOD'] = 'GET'
b = dictify(cgi.FieldStorage(environ=e, keep_blank_values=1))
def process_fieldstorage(fs):
if isinstance(fs, list):
return [process_fieldstorage(x) for x in fs]
elif fs.filename is None:
return fs.value
else:
return fs
return storage([(k, process_fieldstorage(v)) for k, v in dictadd(b, a).items()])
def input(*requireds, **defaults):
"""
Returns a `storage` object with the GET and POST arguments.
See `storify` for how `requireds` and `defaults` work.
"""
_method = defaults.pop('_method', 'both')
out = rawinput(_method)
try:
defaults.setdefault('_unicode', True) # force unicode conversion by default.
return storify(out, *requireds, **defaults)
except KeyError:
raise badrequest()
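# Illustrative sketch (not part of the original module): reading query/POST
# parameters with defaults; a missing required argument raises badrequest().
#
#     i = web.input(page=1, q="")   # /search?q=web&page=2 -> i.q == u'web', i.page == u'2'
#     i = web.input("id")           # responds with 400 Bad Request if `id` was not supplied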
def data():
"""Returns the data sent with the request."""
if 'data' not in ctx:
cl = intget(ctx.env.get('CONTENT_LENGTH'), 0)
ctx.data = ctx.env['wsgi.input'].read(cl)
return ctx.data
def setcookie(name, value, expires='', domain=None,
secure=False, httponly=False, path=None):
"""Sets a cookie."""
morsel = Cookie.Morsel()
name, value = safestr(name), safestr(value)
morsel.set(name, value, urllib.quote(value))
if expires < 0:
expires = -1000000000
morsel['expires'] = expires
morsel['path'] = path or ctx.homepath+'/'
if domain:
morsel['domain'] = domain
if secure:
morsel['secure'] = secure
value = morsel.OutputString()
if httponly:
value += '; httponly'
header('Set-Cookie', value)
def decode_cookie(value):
r"""Safely decodes a cookie value to unicode.
    Tries us-ascii, utf-8 and iso8859 encodings, in that order.
>>> decode_cookie('')
u''
>>> decode_cookie('asdf')
u'asdf'
>>> decode_cookie('foo \xC3\xA9 bar')
u'foo \xe9 bar'
>>> decode_cookie('foo \xE9 bar')
u'foo \xe9 bar'
"""
try:
# First try plain ASCII encoding
return unicode(value, 'us-ascii')
except UnicodeError:
# Then try UTF-8, and if that fails, ISO8859
try:
return unicode(value, 'utf-8')
except UnicodeError:
return unicode(value, 'iso8859', 'ignore')
def parse_cookies(http_cookie):
r"""Parse a HTTP_COOKIE header and return dict of cookie names and decoded values.
>>> sorted(parse_cookies('').items())
[]
>>> sorted(parse_cookies('a=1').items())
[('a', '1')]
>>> sorted(parse_cookies('a=1%202').items())
[('a', '1 2')]
>>> sorted(parse_cookies('a=Z%C3%A9Z').items())
[('a', 'Z\xc3\xa9Z')]
>>> sorted(parse_cookies('a=1; b=2; c=3').items())
[('a', '1'), ('b', '2'), ('c', '3')]
>>> sorted(parse_cookies('a=1; b=w("x")|y=z; c=3').items())
[('a', '1'), ('b', 'w('), ('c', '3')]
>>> sorted(parse_cookies('a=1; b=w(%22x%22)|y=z; c=3').items())
[('a', '1'), ('b', 'w("x")|y=z'), ('c', '3')]
>>> sorted(parse_cookies('keebler=E=mc2').items())
[('keebler', 'E=mc2')]
>>> sorted(parse_cookies(r'keebler="E=mc2; L=\"Loves\"; fudge=\012;"').items())
[('keebler', 'E=mc2; L="Loves"; fudge=\n;')]
"""
#print "parse_cookies"
if '"' in http_cookie:
# HTTP_COOKIE has quotes in it, use slow but correct cookie parsing
cookie = Cookie.SimpleCookie()
try:
cookie.load(http_cookie)
except Cookie.CookieError:
# If HTTP_COOKIE header is malformed, try at least to load the cookies we can by
# first splitting on ';' and loading each attr=value pair separately
cookie = Cookie.SimpleCookie()
for attr_value in http_cookie.split(';'):
try:
cookie.load(attr_value)
except Cookie.CookieError:
pass
cookies = dict((k, urllib.unquote(v.value)) for k, v in cookie.iteritems())
else:
# HTTP_COOKIE doesn't have quotes, use fast cookie parsing
cookies = {}
for key_value in http_cookie.split(';'):
key_value = key_value.split('=', 1)
if len(key_value) == 2:
key, value = key_value
cookies[key.strip()] = urllib.unquote(value.strip())
return cookies
def cookies(*requireds, **defaults):
r"""Returns a `storage` object with all the request cookies in it.
See `storify` for how `requireds` and `defaults` work.
    This is forgiving on bad HTTP_COOKIE input; it tries to parse at least
the cookies it can.
The values are converted to unicode if _unicode=True is passed.
"""
# If _unicode=True is specified, use decode_cookie to convert cookie value to unicode
if defaults.get("_unicode") is True:
defaults['_unicode'] = decode_cookie
# parse cookie string and cache the result for next time.
if '_parsed_cookies' not in ctx:
http_cookie = ctx.env.get("HTTP_COOKIE", "")
ctx._parsed_cookies = parse_cookies(http_cookie)
try:
return storify(ctx._parsed_cookies, *requireds, **defaults)
except KeyError:
badrequest()
raise StopIteration
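# Illustrative sketch (not part of the original module): reading a cookie with a
# default so a missing cookie does not trigger a bad-request response.
# `new_session_id` is a hypothetical helper.
#
#     session_id = web.cookies(session_id=None).session_id
#     if session_id is None:
#         web.setcookie('session_id', new_session_id())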
def debug(*args):
"""
Prints a prettyprinted version of `args` to stderr.
"""
try:
out = ctx.environ['wsgi.errors']
except:
out = sys.stderr
for arg in args:
print >> out, pprint.pformat(arg)
return ''
def _debugwrite(x):
try:
out = ctx.environ['wsgi.errors']
except:
out = sys.stderr
out.write(x)
debug.write = _debugwrite
ctx = context = threadeddict()
ctx.__doc__ = """
A `storage` object containing various information about the request:
`environ` (aka `env`)
: A dictionary containing the standard WSGI environment variables.
`host`
: The domain (`Host` header) requested by the user.
`home`
: The base path for the application.
`ip`
: The IP address of the requester.
`method`
: The HTTP method used.
`path`
   : The path requested.
`query`
: If there are no query arguments, the empty string. Otherwise, a `?` followed
by the query string.
`fullpath`
: The full path requested, including query arguments (`== path + query`).
### Response Data
`status` (default: "200 OK")
: The status code to be used in the response.
`headers`
: A list of 2-tuples to be used in the response.
`output`
: A string to be used as the response.
"""
if __name__ == "__main__":
import doctest
doctest.testmod() webpy/web/form.py 0000644 0001750 0001750 00000032200 11773133455 012440 0 ustar wmb wmb """
HTML forms
(part of web.py)
"""
import copy, re
import webapi as web
import utils, net
def attrget(obj, attr, value=None):
try:
if hasattr(obj, 'has_key') and obj.has_key(attr):
return obj[attr]
except TypeError:
        # Handle the case where has_key takes a different number of arguments.
# This is the case with Model objects on appengine. See #134
pass
if hasattr(obj, attr):
return getattr(obj, attr)
return value
class Form(object):
r"""
HTML form.
>>> f = Form(Textbox("x"))
>>> f.render()
    u'<table>\n    <tr><th><label for="x">x</label></th><td><input type="text" id="x" name="x"/></td></tr>\n</table>'
"""
def __init__(self, *inputs, **kw):
self.inputs = inputs
self.valid = True
self.note = None
self.validators = kw.pop('validators', [])
def __call__(self, x=None):
o = copy.deepcopy(self)
if x: o.validates(x)
return o
def render(self):
out = ''
out += self.rendernote(self.note)
        out += '<table>\n'
for i in self.inputs:
html = utils.safeunicode(i.pre) + i.render() + self.rendernote(i.note) + utils.safeunicode(i.post)
if i.is_hidden():
                out += '    <tr style="display: none;"><th></th><td>%s</td></tr>\n' % (html)
else:
                out += '    <tr><th><label for="%s">%s</label></th><td>%s</td></tr>\n' % (i.id, net.websafe(i.description), html)
out += "
"
return out
def render_css(self):
out = []
out.append(self.rendernote(self.note))
for i in self.inputs:
if not i.is_hidden():
                out.append('<label for="%s">%s</label>' % (i.id, net.websafe(i.description)))
out.append(i.pre)
out.append(i.render())
out.append(self.rendernote(i.note))
out.append(i.post)
out.append('\n')
return ''.join(out)
def rendernote(self, note):
        if note: return '<strong class="wrong">%s</strong>' % net.websafe(note)
else: return ""
def validates(self, source=None, _validate=True, **kw):
source = source or kw or web.input()
out = True
for i in self.inputs:
v = attrget(source, i.name)
if _validate:
out = i.validate(v) and out
else:
i.set_value(v)
if _validate:
out = out and self._validate(source)
self.valid = out
return out
def _validate(self, value):
self.value = value
for v in self.validators:
if not v.valid(value):
self.note = v.msg
return False
return True
def fill(self, source=None, **kw):
return self.validates(source, _validate=False, **kw)
def __getitem__(self, i):
for x in self.inputs:
if x.name == i: return x
raise KeyError, i
def __getattr__(self, name):
# don't interfere with deepcopy
inputs = self.__dict__.get('inputs') or []
for x in inputs:
if x.name == name: return x
raise AttributeError, name
def get(self, i, default=None):
try:
return self[i]
except KeyError:
return default
def _get_d(self): #@@ should really be form.attr, no?
return utils.storage([(i.name, i.get_value()) for i in self.inputs])
d = property(_get_d)
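# Illustrative sketch (not part of the original module): declaring a form and
# validating it inside a handler. `render.login` is a hypothetical template.
#
#     vpass = regexp(r'.{3,20}$', 'must be between 3 and 20 characters')
#     login = Form(
#         Textbox('username', notnull, description='User name'),
#         Password('password', notnull, vpass),
#     )
#
#     f = login()                 # work on a copy; Form.__call__ deep-copies
#     if not f.validates():       # reads web.input() and runs the validators
#         return render.login(f)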
class Input(object):
def __init__(self, name, *validators, **attrs):
self.name = name
self.validators = validators
self.attrs = attrs = AttributeList(attrs)
self.description = attrs.pop('description', name)
self.value = attrs.pop('value', None)
self.pre = attrs.pop('pre', "")
self.post = attrs.pop('post', "")
self.note = None
self.id = attrs.setdefault('id', self.get_default_id())
if 'class_' in attrs:
attrs['class'] = attrs['class_']
del attrs['class_']
def is_hidden(self):
return False
def get_type(self):
raise NotImplementedError
def get_default_id(self):
return self.name
def validate(self, value):
self.set_value(value)
for v in self.validators:
if not v.valid(value):
self.note = v.msg
return False
return True
def set_value(self, value):
self.value = value
def get_value(self):
return self.value
def render(self):
attrs = self.attrs.copy()
attrs['type'] = self.get_type()
if self.value is not None:
attrs['value'] = self.value
attrs['name'] = self.name
        return '<input %s/>' % attrs
def rendernote(self, note):
        if note: return '<strong class="wrong">%s</strong>' % net.websafe(note)
else: return ""
def addatts(self):
# add leading space for backward-compatibility
return " " + str(self.attrs)
class AttributeList(dict):
"""List of atributes of input.
>>> a = AttributeList(type='text', name='x', value=20)
>>> a
"""
def copy(self):
return AttributeList(self)
def __str__(self):
return " ".join(['%s="%s"' % (k, net.websafe(v)) for k, v in self.items()])
def __repr__(self):
        return '<attrs: %s>' % repr(str(self))
class Textbox(Input):
"""Textbox input.
>>> Textbox(name='foo', value='bar').render()
u''
>>> Textbox(name='foo', value=0).render()
u''
"""
def get_type(self):
return 'text'
class Password(Input):
"""Password input.
>>> Password(name='password', value='secret').render()
u''
"""
def get_type(self):
return 'password'
class Textarea(Input):
"""Textarea input.
>>> Textarea(name='foo', value='bar').render()
u''
"""
def render(self):
attrs = self.attrs.copy()
attrs['name'] = self.name
value = net.websafe(self.value or '')
        return '<textarea %s>%s</textarea>' % (attrs, value)
class Dropdown(Input):
r"""Dropdown/select input.
>>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render()
u'\n'
>>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render()
u'\n'
"""
def __init__(self, name, args, *validators, **attrs):
self.args = args
super(Dropdown, self).__init__(name, *validators, **attrs)
def render(self):
attrs = self.attrs.copy()
attrs['name'] = self.name
        x = '<select %s>\n' % attrs
        for arg in self.args:
            x += self._render_option(arg)
        x += '</select>\n'
        return x
def _render_option(self, arg, indent=' '):
if isinstance(arg, (tuple, list)):
value, desc= arg
else:
value, desc = arg, arg
if self.value == value or (isinstance(self.value, list) and value in self.value):
select_p = ' selected="selected"'
else:
select_p = ''
        return indent + '<option%s value="%s">%s</option>\n' % (select_p, net.websafe(value), net.websafe(desc))
class GroupedDropdown(Dropdown):
r"""Grouped Dropdown/select input.
>>> GroupedDropdown(name='car_type', args=(('Swedish Cars', ('Volvo', 'Saab')), ('German Cars', ('Mercedes', 'Audi'))), value='Audi').render()
u'\n'
>>> GroupedDropdown(name='car_type', args=(('Swedish Cars', (('v', 'Volvo'), ('s', 'Saab'))), ('German Cars', (('m', 'Mercedes'), ('a', 'Audi')))), value='a').render()
u'\n'
"""
def __init__(self, name, args, *validators, **attrs):
self.args = args
super(Dropdown, self).__init__(name, *validators, **attrs)
def render(self):
attrs = self.attrs.copy()
attrs['name'] = self.name
        x = '<select %s>\n' % attrs
        for label, options in self.args:
            x += '  <optgroup label="%s">\n' % net.websafe(label)
            for arg in options:
                x += self._render_option(arg, indent='    ')
            x += '  </optgroup>\n'
        x += '</select>\n'
        return x
class Radio(Input):
def __init__(self, name, args, *validators, **attrs):
self.args = args
super(Radio, self).__init__(name, *validators, **attrs)
def render(self):
        x = '<span>'
for arg in self.args:
if isinstance(arg, (tuple, list)):
value, desc= arg
else:
value, desc = arg, arg
attrs = self.attrs.copy()
attrs['name'] = self.name
attrs['type'] = 'radio'
attrs['value'] = value
if self.value == value:
attrs['checked'] = 'checked'
            x += '<input %s/> %s' % (attrs, net.websafe(desc))
        x += '</span>'
return x
class Checkbox(Input):
"""Checkbox input.
>>> Checkbox('foo', value='bar', checked=True).render()
u''
>>> Checkbox('foo', value='bar').render()
u''
>>> c = Checkbox('foo', value='bar')
>>> c.validate('on')
True
>>> c.render()
u''
"""
def __init__(self, name, *validators, **attrs):
self.checked = attrs.pop('checked', False)
Input.__init__(self, name, *validators, **attrs)
def get_default_id(self):
value = utils.safestr(self.value or "")
return self.name + '_' + value.replace(' ', '_')
def render(self):
attrs = self.attrs.copy()
attrs['type'] = 'checkbox'
attrs['name'] = self.name
attrs['value'] = self.value
if self.checked:
attrs['checked'] = 'checked'
        return '<input %s/>' % attrs
def set_value(self, value):
self.checked = bool(value)
def get_value(self):
return self.checked
class Button(Input):
"""HTML Button.
>>> Button("save").render()
u''
>>> Button("action", value="save", html="Save Changes").render()
u''
"""
def __init__(self, name, *validators, **attrs):
super(Button, self).__init__(name, *validators, **attrs)
self.description = ""
def render(self):
attrs = self.attrs.copy()
attrs['name'] = self.name
if self.value is not None:
attrs['value'] = self.value
html = attrs.pop('html', None) or net.websafe(self.name)
        return '<button %s>%s</button>' % (attrs, html)
class Hidden(Input):
"""Hidden Input.
>>> Hidden(name='foo', value='bar').render()
u''
"""
def is_hidden(self):
return True
def get_type(self):
return 'hidden'
class File(Input):
"""File input.
>>> File(name='f').render()
u''
"""
def get_type(self):
return 'file'
class Validator:
def __deepcopy__(self, memo): return copy.copy(self)
def __init__(self, msg, test, jstest=None): utils.autoassign(self, locals())
def valid(self, value):
try: return self.test(value)
except: return False
notnull = Validator("Required", bool)
class regexp(Validator):
def __init__(self, rexp, msg):
self.rexp = re.compile(rexp)
self.msg = msg
def valid(self, value):
return bool(self.rexp.match(value))
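# Illustrative sketch (not part of the original module): a custom Validator can
# wrap any callable test, while regexp covers the common pattern case.
#
#     email = regexp(r'.+@.+', 'must be a valid e-mail address')
#     adult = Validator('must be 18 or older', lambda v: int(v) >= 18)
#     age_box = Textbox('age', notnull, adult)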
if __name__ == "__main__":
import doctest
doctest.testmod()
webpy/web/http.py 0000644 0001750 0001750 00000010577 11773133455 012471 0 ustar wmb wmb """
HTTP Utilities
(from web.py)
"""
__all__ = [
"expires", "lastmodified",
"prefixurl", "modified",
"changequery", "url",
"profiler",
]
import sys, os, threading, urllib, urlparse
try: import datetime
except ImportError: pass
import net, utils, webapi as web
def prefixurl(base=''):
"""
Sorry, this function is really difficult to explain.
Maybe some other time.
"""
url = web.ctx.path.lstrip('/')
for i in xrange(url.count('/')):
base += '../'
if not base:
base = './'
return base
def expires(delta):
"""
Outputs an `Expires` header for `delta` from now.
`delta` is a `timedelta` object or a number of seconds.
"""
if isinstance(delta, (int, long)):
delta = datetime.timedelta(seconds=delta)
date_obj = datetime.datetime.utcnow() + delta
web.header('Expires', net.httpdate(date_obj))
def lastmodified(date_obj):
"""Outputs a `Last-Modified` header for `datetime`."""
web.header('Last-Modified', net.httpdate(date_obj))
def modified(date=None, etag=None):
"""
Checks to see if the page has been modified since the version in the
requester's cache.
When you publish pages, you can include `Last-Modified` and `ETag`
with the date the page was last modified and an opaque token for
the particular version, respectively. When readers reload the page,
the browser sends along the modification date and etag value for
the version it has in its cache. If the page hasn't changed,
the server can just return `304 Not Modified` and not have to
send the whole page again.
This function takes the last-modified date `date` and the ETag `etag`
and checks the headers to see if they match. If they do, it returns
`True`, or otherwise it raises NotModified error. It also sets
`Last-Modified` and `ETag` output headers.
"""
try:
from __builtin__ import set
except ImportError:
# for python 2.3
from sets import Set as set
n = set([x.strip('" ') for x in web.ctx.env.get('HTTP_IF_NONE_MATCH', '').split(',')])
m = net.parsehttpdate(web.ctx.env.get('HTTP_IF_MODIFIED_SINCE', '').split(';')[0])
validate = False
if etag:
if '*' in n or etag in n:
validate = True
if date and m:
# we subtract a second because
# HTTP dates don't have sub-second precision
if date-datetime.timedelta(seconds=1) <= m:
validate = True
if date: lastmodified(date)
if etag: web.header('ETag', '"' + etag + '"')
if validate:
raise web.notmodified()
else:
return True
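# Illustrative sketch (not part of the original module): a handler that sends the
# body only when the client's cached copy is stale; `load_page` and its result
# are hypothetical.
#
#     class article:
#         def GET(self, name):
#             page = load_page(name)
#             web.modified(date=page.updated_at, etag=str(page.revision))
#             return page.body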
def urlencode(query, doseq=0):
"""
Same as urllib.urlencode, but supports unicode strings.
>>> urlencode({'text':'foo bar'})
'text=foo+bar'
>>> urlencode({'x': [1, 2]}, doseq=True)
'x=1&x=2'
"""
def convert(value, doseq=False):
if doseq and isinstance(value, list):
return [convert(v) for v in value]
else:
return utils.safestr(value)
query = dict([(k, convert(v, doseq)) for k, v in query.items()])
return urllib.urlencode(query, doseq=doseq)
def changequery(query=None, **kw):
"""
Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return
`/foo?a=3&b=2` -- the same URL but with the arguments you requested
changed.
"""
if query is None:
query = web.rawinput(method='get')
for k, v in kw.iteritems():
if v is None:
query.pop(k, None)
else:
query[k] = v
out = web.ctx.path
if query:
out += '?' + urlencode(query, doseq=True)
return out
def url(path=None, doseq=False, **kw):
"""
Makes url by concatenating web.ctx.homepath and path and the
query string created using the arguments.
"""
if path is None:
path = web.ctx.path
if path.startswith("/"):
out = web.ctx.homepath + path
else:
out = path
if kw:
out += '?' + urlencode(kw, doseq=doseq)
return out
def profiler(app):
"""Outputs basic profiling information at the bottom of each response."""
from utils import profile
def profile_internal(e, o):
out, result = profile(app)(e, o)
        return list(out) + ['<pre>' + net.websafe(result) + '</pre>']
return profile_internal
if __name__ == "__main__":
import doctest
doctest.testmod()
webpy/web/wsgiserver/ 0002755 0001750 0001750 00000000000 11773133455 013330 5 ustar wmb wmb webpy/web/wsgiserver/__init__.py 0000644 0001750 0001750 00000246442 11773133455 015453 0 ustar wmb wmb """A high-speed, production ready, thread pooled, generic HTTP server.
Simplest example on how to use this module directly
(without using CherryPy's application machinery)::
from cherrypy import wsgiserver
def my_crazy_app(environ, start_response):
status = '200 OK'
response_headers = [('Content-type','text/plain')]
start_response(status, response_headers)
return ['Hello world!']
server = wsgiserver.CherryPyWSGIServer(
('0.0.0.0', 8070), my_crazy_app,
server_name='www.cherrypy.example')
server.start()
The CherryPy WSGI server can serve as many WSGI applications
as you want in one instance by using a WSGIPathInfoDispatcher::
d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app})
server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d)
Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance.
This won't call the CherryPy engine (application side) at all, only the
HTTP server, which is independent from the rest of CherryPy. Don't
let the name "CherryPyWSGIServer" throw you; the name merely reflects
its origin, not its coupling.
For those of you wanting to understand internals of this module, here's the
basic call flow. The server's listening thread runs a very tight loop,
sticking incoming connections onto a Queue::
server = CherryPyWSGIServer(...)
server.start()
while True:
tick()
# This blocks until a request comes in:
child = socket.accept()
conn = HTTPConnection(child, ...)
server.requests.put(conn)
Worker threads are kept in a pool and poll the Queue, popping off and then
handling each connection in turn. Each connection can consist of an arbitrary
number of requests and their responses, so we run a nested loop::
while True:
conn = server.requests.get()
conn.communicate()
-> while True:
req = HTTPRequest(...)
req.parse_request()
-> # Read the Request-Line, e.g. "GET /page HTTP/1.1"
req.rfile.readline()
read_headers(req.rfile, req.inheaders)
req.respond()
-> response = app(...)
try:
for chunk in response:
if chunk:
req.write(chunk)
finally:
if hasattr(response, "close"):
response.close()
if req.close_connection:
return
"""
CRLF = '\r\n'
import os
import Queue
import re
quoted_slash = re.compile("(?i)%2F")
import rfc822
import socket
import sys
if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'):
socket.IPPROTO_IPV6 = 41
try:
import cStringIO as StringIO
except ImportError:
import StringIO
DEFAULT_BUFFER_SIZE = -1
_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring)
import threading
import time
import traceback
def format_exc(limit=None):
"""Like print_exc() but return a string. Backport for Python 2.3."""
try:
etype, value, tb = sys.exc_info()
return ''.join(traceback.format_exception(etype, value, tb, limit))
finally:
etype = value = tb = None
from urllib import unquote
from urlparse import urlparse
import warnings
import errno
def plat_specific_errors(*errnames):
"""Return error numbers for all errors in errnames on this platform.
The 'errno' module contains different global constants depending on
the specific platform (OS). This function will return the list of
numeric values for a given list of potential names.
"""
errno_names = dir(errno)
nums = [getattr(errno, k) for k in errnames if k in errno_names]
# de-dupe the list
return dict.fromkeys(nums).keys()
socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR")
socket_errors_to_ignore = plat_specific_errors(
"EPIPE",
"EBADF", "WSAEBADF",
"ENOTSOCK", "WSAENOTSOCK",
"ETIMEDOUT", "WSAETIMEDOUT",
"ECONNREFUSED", "WSAECONNREFUSED",
"ECONNRESET", "WSAECONNRESET",
"ECONNABORTED", "WSAECONNABORTED",
"ENETRESET", "WSAENETRESET",
"EHOSTDOWN", "EHOSTUNREACH",
)
socket_errors_to_ignore.append("timed out")
socket_errors_to_ignore.append("The read operation timed out")
socket_errors_nonblocking = plat_specific_errors(
'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding',
'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
'WWW-Authenticate']
import logging
if not hasattr(logging, 'statistics'): logging.statistics = {}
def read_headers(rfile, hdict=None):
"""Read headers from the given stream into the given header dict.
If hdict is None, a new header dict is created. Returns the populated
header dict.
Headers which are repeated are folded together using a comma if their
specification so dictates.
This function raises ValueError when the read bytes violate the HTTP spec.
You should probably return "400 Bad Request" if this happens.
"""
if hdict is None:
hdict = {}
while True:
line = rfile.readline()
if not line:
# No more data--illegal end of headers
raise ValueError("Illegal end of headers.")
if line == CRLF:
# Normal end of headers
break
if not line.endswith(CRLF):
raise ValueError("HTTP requires CRLF terminators")
if line[0] in ' \t':
# It's a continuation line.
v = line.strip()
else:
try:
k, v = line.split(":", 1)
except ValueError:
raise ValueError("Illegal header line.")
# TODO: what about TE and WWW-Authenticate?
k = k.strip().title()
v = v.strip()
hname = k
if k in comma_separated_headers:
existing = hdict.get(hname)
if existing:
v = ", ".join((existing, v))
hdict[hname] = v
return hdict
class MaxSizeExceeded(Exception):
pass
class SizeCheckWrapper(object):
"""Wraps a file-like object, raising MaxSizeExceeded if too large."""
def __init__(self, rfile, maxlen):
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0
def _check_length(self):
if self.maxlen and self.bytes_read > self.maxlen:
raise MaxSizeExceeded()
def read(self, size=None):
data = self.rfile.read(size)
self.bytes_read += len(data)
self._check_length()
return data
def readline(self, size=None):
if size is not None:
data = self.rfile.readline(size)
self.bytes_read += len(data)
self._check_length()
return data
# User didn't specify a size ...
# We read the line in chunks to make sure it's not a 100MB line !
res = []
while True:
data = self.rfile.readline(256)
self.bytes_read += len(data)
self._check_length()
res.append(data)
# See http://www.cherrypy.org/ticket/421
if len(data) < 256 or data[-1:] == "\n":
return ''.join(res)
def readlines(self, sizehint=0):
# Shamelessly stolen from StringIO
total = 0
lines = []
line = self.readline()
while line:
lines.append(line)
total += len(line)
if 0 < sizehint <= total:
break
line = self.readline()
return lines
def close(self):
self.rfile.close()
def __iter__(self):
return self
def next(self):
data = self.rfile.next()
self.bytes_read += len(data)
self._check_length()
return data
class KnownLengthRFile(object):
"""Wraps a file-like object, returning an empty string when exhausted."""
def __init__(self, rfile, content_length):
self.rfile = rfile
self.remaining = content_length
def read(self, size=None):
if self.remaining == 0:
return ''
if size is None:
size = self.remaining
else:
size = min(size, self.remaining)
data = self.rfile.read(size)
self.remaining -= len(data)
return data
def readline(self, size=None):
if self.remaining == 0:
return ''
if size is None:
size = self.remaining
else:
size = min(size, self.remaining)
data = self.rfile.readline(size)
self.remaining -= len(data)
return data
def readlines(self, sizehint=0):
# Shamelessly stolen from StringIO
total = 0
lines = []
line = self.readline(sizehint)
while line:
lines.append(line)
total += len(line)
if 0 < sizehint <= total:
break
line = self.readline(sizehint)
return lines
def close(self):
self.rfile.close()
def __iter__(self):
return self
    def next(self):
        data = self.rfile.next()
        self.remaining -= len(data)
        return data
class ChunkedRFile(object):
"""Wraps a file-like object, returning an empty string when exhausted.
This class is intended to provide a conforming wsgi.input value for
request entities that have been encoded with the 'chunked' transfer
encoding.
"""
def __init__(self, rfile, maxlen, bufsize=8192):
self.rfile = rfile
self.maxlen = maxlen
self.bytes_read = 0
self.buffer = ''
self.bufsize = bufsize
self.closed = False
def _fetch(self):
if self.closed:
return
line = self.rfile.readline()
self.bytes_read += len(line)
if self.maxlen and self.bytes_read > self.maxlen:
raise MaxSizeExceeded("Request Entity Too Large", self.maxlen)
line = line.strip().split(";", 1)
try:
chunk_size = line.pop(0)
chunk_size = int(chunk_size, 16)
except ValueError:
raise ValueError("Bad chunked transfer size: " + repr(chunk_size))
if chunk_size <= 0:
self.closed = True
return
## if line: chunk_extension = line[0]
if self.maxlen and self.bytes_read + chunk_size > self.maxlen:
raise IOError("Request Entity Too Large")
chunk = self.rfile.read(chunk_size)
self.bytes_read += len(chunk)
self.buffer += chunk
crlf = self.rfile.read(2)
if crlf != CRLF:
raise ValueError(
"Bad chunked transfer coding (expected '\\r\\n', "
"got " + repr(crlf) + ")")
def read(self, size=None):
data = ''
while True:
if size and len(data) >= size:
return data
if not self.buffer:
self._fetch()
if not self.buffer:
# EOF
return data
if size:
remaining = size - len(data)
data += self.buffer[:remaining]
self.buffer = self.buffer[remaining:]
else:
data += self.buffer
def readline(self, size=None):
data = ''
while True:
if size and len(data) >= size:
return data
if not self.buffer:
self._fetch()
if not self.buffer:
# EOF
return data
newline_pos = self.buffer.find('\n')
if size:
if newline_pos == -1:
remaining = size - len(data)
data += self.buffer[:remaining]
self.buffer = self.buffer[remaining:]
else:
remaining = min(size - len(data), newline_pos)
data += self.buffer[:remaining]
self.buffer = self.buffer[remaining:]
else:
if newline_pos == -1:
data += self.buffer
else:
data += self.buffer[:newline_pos]
self.buffer = self.buffer[newline_pos:]
def readlines(self, sizehint=0):
# Shamelessly stolen from StringIO
total = 0
lines = []
line = self.readline(sizehint)
while line:
lines.append(line)
total += len(line)
if 0 < sizehint <= total:
break
line = self.readline(sizehint)
return lines
def read_trailer_lines(self):
if not self.closed:
raise ValueError(
"Cannot read trailers until the request body has been read.")
while True:
line = self.rfile.readline()
if not line:
# No more data--illegal end of headers
raise ValueError("Illegal end of headers.")
self.bytes_read += len(line)
if self.maxlen and self.bytes_read > self.maxlen:
raise IOError("Request Entity Too Large")
if line == CRLF:
# Normal end of headers
break
if not line.endswith(CRLF):
raise ValueError("HTTP requires CRLF terminators")
yield line
def close(self):
self.rfile.close()
    def __iter__(self):
        # Shamelessly stolen from StringIO
        line = self.readline()
        while line:
            yield line
            line = self.readline()
class HTTPRequest(object):
"""An HTTP Request (and response).
A single HTTP connection may consist of multiple request/response pairs.
"""
server = None
"""The HTTPServer object which is receiving this request."""
conn = None
"""The HTTPConnection object on which this request connected."""
inheaders = {}
"""A dict of request headers."""
outheaders = []
"""A list of header tuples to write in the response."""
ready = False
"""When True, the request has been parsed and is ready to begin generating
the response. When False, signals the calling Connection that the response
should not be generated and the connection should close."""
close_connection = False
"""Signals the calling Connection that the request should close. This does
not imply an error! The client and/or server may each request that the
connection be closed."""
chunked_write = False
"""If True, output will be encoded with the "chunked" transfer-coding.
This value is set automatically inside send_headers."""
def __init__(self, server, conn):
self.server= server
self.conn = conn
self.ready = False
self.started_request = False
self.scheme = "http"
if self.server.ssl_adapter is not None:
self.scheme = "https"
# Use the lowest-common protocol in case read_request_line errors.
self.response_protocol = 'HTTP/1.0'
self.inheaders = {}
self.status = ""
self.outheaders = []
self.sent_headers = False
self.close_connection = self.__class__.close_connection
self.chunked_read = False
self.chunked_write = self.__class__.chunked_write
def parse_request(self):
"""Parse the next HTTP request start-line and message-headers."""
self.rfile = SizeCheckWrapper(self.conn.rfile,
self.server.max_request_header_size)
try:
self.read_request_line()
except MaxSizeExceeded:
self.simple_response("414 Request-URI Too Long",
"The Request-URI sent with the request exceeds the maximum "
"allowed bytes.")
return
try:
success = self.read_request_headers()
except MaxSizeExceeded:
self.simple_response("413 Request Entity Too Large",
"The headers sent with the request exceed the maximum "
"allowed bytes.")
return
else:
if not success:
return
self.ready = True
def read_request_line(self):
# HTTP/1.1 connections are persistent by default. If a client
# requests a page, then idles (leaves the connection open),
# then rfile.readline() will raise socket.error("timed out").
# Note that it does this based on the value given to settimeout(),
# and doesn't need the client to request or acknowledge the close
# (although your TCP stack might suffer for it: cf Apache's history
# with FIN_WAIT_2).
request_line = self.rfile.readline()
# Set started_request to True so communicate() knows to send 408
# from here on out.
self.started_request = True
if not request_line:
# Force self.ready = False so the connection will close.
self.ready = False
return
if request_line == CRLF:
# RFC 2616 sec 4.1: "...if the server is reading the protocol
# stream at the beginning of a message and receives a CRLF
# first, it should ignore the CRLF."
# But only ignore one leading line! else we enable a DoS.
request_line = self.rfile.readline()
if not request_line:
self.ready = False
return
if not request_line.endswith(CRLF):
self.simple_response("400 Bad Request", "HTTP requires CRLF terminators")
return
try:
method, uri, req_protocol = request_line.strip().split(" ", 2)
rp = int(req_protocol[5]), int(req_protocol[7])
except (ValueError, IndexError):
self.simple_response("400 Bad Request", "Malformed Request-Line")
return
self.uri = uri
self.method = method
# uri may be an abs_path (including "http://host.domain.tld");
scheme, authority, path = self.parse_request_uri(uri)
if '#' in path:
self.simple_response("400 Bad Request",
"Illegal #fragment in Request-URI.")
return
if scheme:
self.scheme = scheme
qs = ''
if '?' in path:
path, qs = path.split('?', 1)
# Unquote the path+params (e.g. "/this%20path" -> "/this path").
# http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
#
# But note that "...a URI must be separated into its components
# before the escaped characters within those components can be
# safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2
# Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path".
try:
atoms = [unquote(x) for x in quoted_slash.split(path)]
except ValueError, ex:
self.simple_response("400 Bad Request", ex.args[0])
return
path = "%2F".join(atoms)
self.path = path
# Note that, like wsgiref and most other HTTP servers,
# we "% HEX HEX"-unquote the path but not the query string.
self.qs = qs
# Compare request and server HTTP protocol versions, in case our
# server does not support the requested protocol. Limit our output
# to min(req, server). We want the following output:
# request server actual written supported response
# protocol protocol response protocol feature set
# a 1.0 1.0 1.0 1.0
# b 1.0 1.1 1.1 1.0
# c 1.1 1.0 1.0 1.0
# d 1.1 1.1 1.1 1.1
# Notice that, in (b), the response will be "HTTP/1.1" even though
# the client only understands 1.0. RFC 2616 10.5.6 says we should
# only return 505 if the _major_ version is different.
sp = int(self.server.protocol[5]), int(self.server.protocol[7])
if sp[0] != rp[0]:
self.simple_response("505 HTTP Version Not Supported")
return
self.request_protocol = req_protocol
self.response_protocol = "HTTP/%s.%s" % min(rp, sp)
def read_request_headers(self):
"""Read self.rfile into self.inheaders. Return success."""
# then all the http headers
try:
read_headers(self.rfile, self.inheaders)
except ValueError, ex:
self.simple_response("400 Bad Request", ex.args[0])
return False
mrbs = self.server.max_request_body_size
if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs:
self.simple_response("413 Request Entity Too Large",
"The entity sent with the request exceeds the maximum "
"allowed bytes.")
return False
# Persistent connection support
if self.response_protocol == "HTTP/1.1":
# Both server and client are HTTP/1.1
if self.inheaders.get("Connection", "") == "close":
self.close_connection = True
else:
# Either the server or client (or both) are HTTP/1.0
if self.inheaders.get("Connection", "") != "Keep-Alive":
self.close_connection = True
# Transfer-Encoding support
te = None
if self.response_protocol == "HTTP/1.1":
te = self.inheaders.get("Transfer-Encoding")
if te:
te = [x.strip().lower() for x in te.split(",") if x.strip()]
self.chunked_read = False
if te:
for enc in te:
if enc == "chunked":
self.chunked_read = True
else:
# Note that, even if we see "chunked", we must reject
# if there is an extension we don't recognize.
self.simple_response("501 Unimplemented")
self.close_connection = True
return False
# From PEP 333:
# "Servers and gateways that implement HTTP 1.1 must provide
# transparent support for HTTP 1.1's "expect/continue" mechanism.
# This may be done in any of several ways:
# 1. Respond to requests containing an Expect: 100-continue request
# with an immediate "100 Continue" response, and proceed normally.
# 2. Proceed with the request normally, but provide the application
# with a wsgi.input stream that will send the "100 Continue"
# response if/when the application first attempts to read from
# the input stream. The read request must then remain blocked
# until the client responds.
# 3. Wait until the client decides that the server does not support
# expect/continue, and sends the request body on its own.
# (This is suboptimal, and is not recommended.)
#
# We used to do 3, but are now doing 1. Maybe we'll do 2 someday,
# but it seems like it would be a big slowdown for such a rare case.
if self.inheaders.get("Expect", "") == "100-continue":
# Don't use simple_response here, because it emits headers
# we don't want. See http://www.cherrypy.org/ticket/951
msg = self.server.protocol + " 100 Continue\r\n\r\n"
try:
self.conn.wfile.sendall(msg)
except socket.error, x:
if x.args[0] not in socket_errors_to_ignore:
raise
return True
def parse_request_uri(self, uri):
"""Parse a Request-URI into (scheme, authority, path).
Note that Request-URIs must be one of::
Request-URI = "*" | absoluteURI | abs_path | authority
Therefore, a Request-URI which starts with a double forward-slash
cannot be a "net_path"::
net_path = "//" authority [ abs_path ]
Instead, it must be interpreted as an "abs_path" with an empty first
path segment::
abs_path = "/" path_segments
path_segments = segment *( "/" segment )
segment = *pchar *( ";" param )
param = *pchar
"""
if uri == "*":
return None, None, uri
i = uri.find('://')
if i > 0 and '?' not in uri[:i]:
# An absoluteURI.
# If there's a scheme (and it must be http or https), then:
# http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]]
scheme, remainder = uri[:i].lower(), uri[i + 3:]
authority, path = remainder.split("/", 1)
return scheme, authority, path
if uri.startswith('/'):
# An abs_path.
return None, None, uri
else:
# An authority.
return None, uri, None
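# Illustrative return values (not part of the original source; URIs below are
# placeholders, results follow the parsing logic above):
#   parse_request_uri('*')                        -> (None, None, '*')
#   parse_request_uri('http://example.com/a?x=1') -> ('http', 'example.com', 'a?x=1')
#   parse_request_uri('/a/b?x=1')                 -> (None, None, '/a/b?x=1')
#   parse_request_uri('example.com:8080')         -> (None, 'example.com:8080', None)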
def respond(self):
"""Call the gateway and write its iterable output."""
mrbs = self.server.max_request_body_size
if self.chunked_read:
self.rfile = ChunkedRFile(self.conn.rfile, mrbs)
else:
cl = int(self.inheaders.get("Content-Length", 0))
if mrbs and mrbs < cl:
if not self.sent_headers:
self.simple_response("413 Request Entity Too Large",
"The entity sent with the request exceeds the maximum "
"allowed bytes.")
return
self.rfile = KnownLengthRFile(self.conn.rfile, cl)
self.server.gateway(self).respond()
if (self.ready and not self.sent_headers):
self.sent_headers = True
self.send_headers()
if self.chunked_write:
self.conn.wfile.sendall("0\r\n\r\n")
def simple_response(self, status, msg=""):
"""Write a simple response back to the client."""
status = str(status)
buf = [self.server.protocol + " " +
status + CRLF,
"Content-Length: %s\r\n" % len(msg),
"Content-Type: text/plain\r\n"]
if status[:3] in ("413", "414"):
# Request Entity Too Large / Request-URI Too Long
self.close_connection = True
if self.response_protocol == 'HTTP/1.1':
# This will not be true for 414, since read_request_line
# usually raises 414 before reading the whole line, and we
# therefore cannot know the proper response_protocol.
buf.append("Connection: close\r\n")
else:
# HTTP/1.0 had no 413/414 status nor Connection header.
# Emit 400 instead and trust the message body is enough.
status = "400 Bad Request"
buf.append(CRLF)
if msg:
if isinstance(msg, unicode):
msg = msg.encode("ISO-8859-1")
buf.append(msg)
try:
self.conn.wfile.sendall("".join(buf))
except socket.error, x:
if x.args[0] not in socket_errors_to_ignore:
raise
def write(self, chunk):
"""Write unbuffered data to the client."""
if self.chunked_write and chunk:
buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF]
self.conn.wfile.sendall("".join(buf))
else:
self.conn.wfile.sendall(chunk)
def send_headers(self):
"""Assert, process, and send the HTTP response message-headers.
You must set self.status, and self.outheaders before calling this.
"""
hkeys = [key.lower() for key, value in self.outheaders]
status = int(self.status[:3])
if status == 413:
# Request Entity Too Large. Close conn to avoid garbage.
self.close_connection = True
elif "content-length" not in hkeys:
# "All 1xx (informational), 204 (no content),
# and 304 (not modified) responses MUST NOT
# include a message-body." So no point chunking.
if status < 200 or status in (204, 205, 304):
pass
else:
if (self.response_protocol == 'HTTP/1.1'
and self.method != 'HEAD'):
# Use the chunked transfer-coding
self.chunked_write = True
self.outheaders.append(("Transfer-Encoding", "chunked"))
else:
# Closing the conn is the only way to determine len.
self.close_connection = True
if "connection" not in hkeys:
if self.response_protocol == 'HTTP/1.1':
# Both server and client are HTTP/1.1 or better
if self.close_connection:
self.outheaders.append(("Connection", "close"))
else:
# Server and/or client are HTTP/1.0
if not self.close_connection:
self.outheaders.append(("Connection", "Keep-Alive"))
if (not self.close_connection) and (not self.chunked_read):
# Read any remaining request body data on the socket.
# "If an origin server receives a request that does not include an
# Expect request-header field with the "100-continue" expectation,
# the request includes a request body, and the server responds
# with a final status code before reading the entire request body
# from the transport connection, then the server SHOULD NOT close
# the transport connection until it has read the entire request,
# or until the client closes the connection. Otherwise, the client
# might not reliably receive the response message. However, this
requirement is not to be construed as preventing a server from
# defending itself against denial-of-service attacks, or from
# badly broken client implementations."
remaining = getattr(self.rfile, 'remaining', 0)
if remaining > 0:
self.rfile.read(remaining)
if "date" not in hkeys:
self.outheaders.append(("Date", rfc822.formatdate()))
if "server" not in hkeys:
self.outheaders.append(("Server", self.server.server_name))
buf = [self.server.protocol + " " + self.status + CRLF]
for k, v in self.outheaders:
buf.append(k + ": " + v + CRLF)
buf.append(CRLF)
self.conn.wfile.sendall("".join(buf))
class NoSSLError(Exception):
"""Exception raised when a client speaks HTTP to an HTTPS socket."""
pass
class FatalSSLAlert(Exception):
"""Exception raised when the SSL implementation signals a fatal alert."""
pass
class CP_fileobject(socket._fileobject):
"""Faux file object attached to a socket object."""
def __init__(self, *args, **kwargs):
self.bytes_read = 0
self.bytes_written = 0
socket._fileobject.__init__(self, *args, **kwargs)
def sendall(self, data):
"""Sendall for non-blocking sockets."""
while data:
try:
bytes_sent = self.send(data)
data = data[bytes_sent:]
except socket.error, e:
if e.args[0] not in socket_errors_nonblocking:
raise
def send(self, data):
bytes_sent = self._sock.send(data)
self.bytes_written += bytes_sent
return bytes_sent
def flush(self):
if self._wbuf:
buffer = "".join(self._wbuf)
self._wbuf = []
self.sendall(buffer)
def recv(self, size):
while True:
try:
data = self._sock.recv(size)
self.bytes_read += len(data)
return data
except socket.error, e:
if (e.args[0] not in socket_errors_nonblocking
and e.args[0] not in socket_error_eintr):
raise
if not _fileobject_uses_str_type:
def read(self, size=-1):
# Use max, disallow tiny reads in a loop as they are very inefficient.
# We never leave read() with any leftover data from a new recv() call
# in our internal buffer.
rbufsize = max(self._rbufsize, self.default_bufsize)
# Our use of StringIO rather than lists of string objects returned by
# recv() minimizes memory usage and fragmentation that occurs when
# rbufsize is large compared to the typical return value of recv().
buf = self._rbuf
buf.seek(0, 2) # seek end
if size < 0:
# Read until EOF
self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
while True:
data = self.recv(rbufsize)
if not data:
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = buf.tell()
if buf_len >= size:
# Already have size bytes in our buffer? Extract and return.
buf.seek(0)
rv = buf.read(size)
self._rbuf = StringIO.StringIO()
self._rbuf.write(buf.read())
return rv
self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
while True:
left = size - buf_len
# recv() will malloc the amount of memory given as its
# parameter even though it often returns much less data
# than that. The returned data string is short lived
# as we copy it into a StringIO and free it. This avoids
# fragmentation issues on many platforms.
data = self.recv(left)
if not data:
break
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid buffer data copies when:
# - We have no data in our buffer.
# AND
# - Our call to recv returned exactly the
# number of bytes we were asked to read.
return data
if n == left:
buf.write(data)
del data # explicit free
break
assert n <= left, "recv(%d) returned %d bytes" % (left, n)
buf.write(data)
buf_len += n
del data # explicit free
#assert buf_len == buf.tell()
return buf.getvalue()
def readline(self, size=-1):
buf = self._rbuf
buf.seek(0, 2) # seek end
if buf.tell() > 0:
# check if we already have it in our buffer
buf.seek(0)
bline = buf.readline(size)
if bline.endswith('\n') or len(bline) == size:
self._rbuf = StringIO.StringIO()
self._rbuf.write(buf.read())
return bline
del bline
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
buf.seek(0)
buffers = [buf.read()]
self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
data = None
recv = self.recv
while data != "\n":
data = recv(1)
if not data:
break
buffers.append(data)
return "".join(buffers)
buf.seek(0, 2) # seek end
self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
while True:
data = self.recv(self._rbufsize)
if not data:
break
nl = data.find('\n')
if nl >= 0:
nl += 1
buf.write(data[:nl])
self._rbuf.write(data[nl:])
del data
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or \n or EOF seen, whichever comes first
buf.seek(0, 2) # seek end
buf_len = buf.tell()
if buf_len >= size:
buf.seek(0)
rv = buf.read(size)
self._rbuf = StringIO.StringIO()
self._rbuf.write(buf.read())
return rv
self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
while True:
data = self.recv(self._rbufsize)
if not data:
break
left = size - buf_len
# did we just receive a newline?
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
# save the excess data to _rbuf
self._rbuf.write(data[nl:])
if buf_len:
buf.write(data[:nl])
break
else:
# Shortcut. Avoid data copy through buf when returning
# a substring of our first recv().
return data[:nl]
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid data copy through buf when
# returning exactly all of our first recv().
return data
if n >= left:
buf.write(data[:left])
self._rbuf.write(data[left:])
break
buf.write(data)
buf_len += n
#assert buf_len == buf.tell()
return buf.getvalue()
else:
def read(self, size=-1):
if size < 0:
# Read until EOF
buffers = [self._rbuf]
self._rbuf = ""
if self._rbufsize <= 1:
recv_size = self.default_bufsize
else:
recv_size = self._rbufsize
while True:
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
return "".join(buffers)
else:
# Read until size bytes or EOF seen, whichever comes first
data = self._rbuf
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
left = size - buf_len
recv_size = max(self._rbufsize, left)
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
def readline(self, size=-1):
data = self._rbuf
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
assert data == ""
buffers = []
while data != "\n":
data = self.recv(1)
if not data:
break
buffers.append(data)
return "".join(buffers)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
return "".join(buffers)
else:
# Read until size bytes or \n or EOF seen, whichever comes first
nl = data.find('\n', 0, size)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ""
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
left = size - buf_len
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return "".join(buffers)
class HTTPConnection(object):
"""An HTTP connection (active socket).
server: the Server object which received this connection.
socket: the raw socket object (usually TCP) for this connection.
makefile: a fileobject class for reading from the socket.
"""
remote_addr = None
remote_port = None
ssl_env = None
rbufsize = DEFAULT_BUFFER_SIZE
wbufsize = DEFAULT_BUFFER_SIZE
RequestHandlerClass = HTTPRequest
def __init__(self, server, sock, makefile=CP_fileobject):
self.server = server
self.socket = sock
self.rfile = makefile(sock, "rb", self.rbufsize)
self.wfile = makefile(sock, "wb", self.wbufsize)
self.requests_seen = 0
def communicate(self):
"""Read each request and respond appropriately."""
request_seen = False
try:
while True:
# (re)set req to None so that if something goes wrong in
# the RequestHandlerClass constructor, the error doesn't
# get written to the previous request.
req = None
req = self.RequestHandlerClass(self.server, self)
# This order of operations should guarantee correct pipelining.
req.parse_request()
if self.server.stats['Enabled']:
self.requests_seen += 1
if not req.ready:
# Something went wrong in the parsing (and the server has
# probably already made a simple_response). Return and
# let the conn close.
return
request_seen = True
req.respond()
if req.close_connection:
return
except socket.error, e:
errnum = e.args[0]
# Sadly, SSL sockets return a different (longer) timeout string.
if errnum == 'timed out' or errnum == 'The read operation timed out':
# Don't error if we're between requests; only error
# if 1) no request has been started at all, or 2) we're
# in the middle of a request.
# See http://www.cherrypy.org/ticket/853
if (not request_seen) or (req and req.started_request):
# Don't bother writing the 408 if the response
# has already started being written.
if req and not req.sent_headers:
try:
req.simple_response("408 Request Timeout")
except FatalSSLAlert:
# Close the connection.
return
elif errnum not in socket_errors_to_ignore:
if req and not req.sent_headers:
try:
req.simple_response("500 Internal Server Error",
format_exc())
except FatalSSLAlert:
# Close the connection.
return
return
except (KeyboardInterrupt, SystemExit):
raise
except FatalSSLAlert:
# Close the connection.
return
except NoSSLError:
if req and not req.sent_headers:
# Unwrap our wfile
self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize)
req.simple_response("400 Bad Request",
"The client sent a plain HTTP request, but "
"this server only speaks HTTPS on this port.")
self.linger = True
except Exception:
if req and not req.sent_headers:
try:
req.simple_response("500 Internal Server Error", format_exc())
except FatalSSLAlert:
# Close the connection.
return
linger = False
def close(self):
"""Close the socket underlying this connection."""
self.rfile.close()
if not self.linger:
# Python's socket module does NOT call close on the kernel socket
# when you call socket.close(). We do so manually here because we
# want this server to send a FIN TCP segment immediately. Note this
# must be called *before* calling socket.close(), because the latter
# drops its reference to the kernel socket.
if hasattr(self.socket, '_sock'):
self.socket._sock.close()
self.socket.close()
else:
# On the other hand, sometimes we want to hang around for a bit
# to make sure the client has a chance to read our entire
# response. Skipping the close() calls here delays the FIN
# packet until the socket object is garbage-collected later.
# Someday, perhaps, we'll do the full lingering_close that
# Apache does, but not today.
pass
_SHUTDOWNREQUEST = None
class WorkerThread(threading.Thread):
"""Thread which continuously polls a Queue for Connection objects.
Due to the timing issues of polling a Queue, a WorkerThread does not
check its own 'ready' flag after it has started. To stop the thread,
it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
(one for each running WorkerThread).
"""
conn = None
"""The current connection pulled off the Queue, or None."""
server = None
"""The HTTP Server which spawned this thread, and which owns the
Queue and is placing active connections into it."""
ready = False
"""A simple flag for the calling server to know when this thread
has begun polling the Queue."""
def __init__(self, server):
self.ready = False
self.server = server
self.requests_seen = 0
self.bytes_read = 0
self.bytes_written = 0
self.start_time = None
self.work_time = 0
self.stats = {
'Requests': lambda s: self.requests_seen + ((self.start_time is None) and 0 or self.conn.requests_seen),
'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and 0 or self.conn.rfile.bytes_read),
'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and 0 or self.conn.wfile.bytes_written),
'Work Time': lambda s: self.work_time + ((self.start_time is None) and 0 or time.time() - self.start_time),
'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6),
'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6),
}
threading.Thread.__init__(self)
def run(self):
self.server.stats['Worker Threads'][self.getName()] = self.stats
try:
self.ready = True
while True:
conn = self.server.requests.get()
if conn is _SHUTDOWNREQUEST:
return
self.conn = conn
if self.server.stats['Enabled']:
self.start_time = time.time()
try:
conn.communicate()
finally:
conn.close()
if self.server.stats['Enabled']:
self.requests_seen += self.conn.requests_seen
self.bytes_read += self.conn.rfile.bytes_read
self.bytes_written += self.conn.wfile.bytes_written
self.work_time += time.time() - self.start_time
self.start_time = None
self.conn = None
except (KeyboardInterrupt, SystemExit), exc:
self.server.interrupt = exc
class ThreadPool(object):
"""A Request Queue for the CherryPyWSGIServer which pools threads.
ThreadPool objects must provide min, get(), put(obj), start()
and stop(timeout) attributes.
"""
def __init__(self, server, min=10, max=-1):
self.server = server
self.min = min
self.max = max
self._threads = []
self._queue = Queue.Queue()
self.get = self._queue.get
def start(self):
"""Start the pool of threads."""
for i in range(self.min):
self._threads.append(WorkerThread(self.server))
for worker in self._threads:
worker.setName("CP Server " + worker.getName())
worker.start()
for worker in self._threads:
while not worker.ready:
time.sleep(.1)
def _get_idle(self):
"""Number of worker threads which are idle. Read-only."""
return len([t for t in self._threads if t.conn is None])
idle = property(_get_idle, doc=_get_idle.__doc__)
def put(self, obj):
self._queue.put(obj)
if obj is _SHUTDOWNREQUEST:
return
def grow(self, amount):
"""Spawn new worker threads (not above self.max)."""
for i in range(amount):
if self.max > 0 and len(self._threads) >= self.max:
break
worker = WorkerThread(self.server)
worker.setName("CP Server " + worker.getName())
self._threads.append(worker)
worker.start()
def shrink(self, amount):
"""Kill off worker threads (not below self.min)."""
# Grow/shrink the pool if necessary.
# Remove any dead threads from our list
for t in self._threads:
if not t.isAlive():
self._threads.remove(t)
amount -= 1
if amount > 0:
for i in range(min(amount, len(self._threads) - self.min)):
# Put a number of shutdown requests on the queue equal
# to 'amount'. Once each of those is processed by a worker,
# that worker will terminate and be culled from our list
# in self.put.
self._queue.put(_SHUTDOWNREQUEST)
def stop(self, timeout=5):
# Must shut down threads here so the code that calls
# this method can know when all threads are stopped.
for worker in self._threads:
self._queue.put(_SHUTDOWNREQUEST)
# Don't join currentThread (when stop is called inside a request).
current = threading.currentThread()
if timeout and timeout >= 0:
endtime = time.time() + timeout
while self._threads:
worker = self._threads.pop()
if worker is not current and worker.isAlive():
try:
if timeout is None or timeout < 0:
worker.join()
else:
remaining_time = endtime - time.time()
if remaining_time > 0:
worker.join(remaining_time)
if worker.isAlive():
# We exhausted the timeout.
# Forcibly shut down the socket.
c = worker.conn
if c and not c.rfile.closed:
try:
c.socket.shutdown(socket.SHUT_RD)
except TypeError:
# pyOpenSSL sockets don't take an arg
c.socket.shutdown()
worker.join()
except (AssertionError,
# Ignore repeated Ctrl-C.
# See http://www.cherrypy.org/ticket/691.
KeyboardInterrupt), exc1:
pass
def _get_qsize(self):
return self._queue.qsize()
qsize = property(_get_qsize)
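# Rough usage sketch (illustrative, not from the original source): HTTPServer
# owns a ThreadPool and drives it as shown below; application code normally
# never touches the pool directly.
#   pool = ThreadPool(server, min=10, max=-1)
#   pool.start()            # spawn 'min' WorkerThreads and wait until ready
#   pool.put(conn)          # queue an HTTPConnection for an idle worker
#   pool.stop(timeout=5)    # enqueue shutdown requests and join the workers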
try:
import fcntl
except ImportError:
try:
from ctypes import windll, WinError
except ImportError:
def prevent_socket_inheritance(sock):
"""Dummy function, since neither fcntl nor ctypes are available."""
pass
else:
def prevent_socket_inheritance(sock):
"""Mark the given socket fd as non-inheritable (Windows)."""
if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0):
raise WinError()
else:
def prevent_socket_inheritance(sock):
"""Mark the given socket fd as non-inheritable (POSIX)."""
fd = sock.fileno()
old_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
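# Illustrative note (sketch): the server applies this helper to both the
# listening socket (in bind()) and each accepted socket (in tick()), e.g.
#   prevent_socket_inheritance(self.socket)
# so that child processes spawned by the application do not inherit the
# server's file descriptors and keep the port open.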
class SSLAdapter(object):
"""Base class for SSL driver library adapters.
Required methods:
* ``wrap(sock) -> (wrapped socket, ssl environ dict)``
* ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object``
"""
def __init__(self, certificate, private_key, certificate_chain=None):
self.certificate = certificate
self.private_key = private_key
self.certificate_chain = certificate_chain
def wrap(self, sock):
raise NotImplementedError
def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
raise NotImplementedError
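# Illustrative sketch (not from the original source): concrete adapters such as
# BuiltinSSLAdapter (ssl_builtin.py) or pyOpenSSLAdapter (ssl_pyopenssl.py)
# subclass SSLAdapter and are attached to a server before start(), roughly:
#   server.ssl_adapter = BuiltinSSLAdapter('server.crt', 'server.key')
# (certificate and key file names here are placeholders).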
class HTTPServer(object):
"""An HTTP server."""
_bind_addr = "127.0.0.1"
_interrupt = None
gateway = None
"""A Gateway instance."""
minthreads = None
"""The minimum number of worker threads to create (default 10)."""
maxthreads = None
"""The maximum number of worker threads to create (default -1 = no limit)."""
server_name = None
"""The name of the server; defaults to socket.gethostname()."""
protocol = "HTTP/1.1"
"""The version string to write in the Status-Line of all HTTP responses.
For example, "HTTP/1.1" is the default. This also limits the supported
features used in the response."""
request_queue_size = 5
"""The 'backlog' arg to socket.listen(); max queued connections (default 5)."""
shutdown_timeout = 5
"""The total time, in seconds, to wait for worker threads to cleanly exit."""
timeout = 10
"""The timeout in seconds for accepted connections (default 10)."""
version = "CherryPy/3.2.0"
"""A version string for the HTTPServer."""
software = None
"""The value to set for the SERVER_SOFTWARE entry in the WSGI environ.
If None, this defaults to ``'%s Server' % self.version``."""
ready = False
"""An internal flag which marks whether the socket is accepting connections."""
max_request_header_size = 0
"""The maximum size, in bytes, for request headers, or 0 for no limit."""
max_request_body_size = 0
"""The maximum size, in bytes, for request bodies, or 0 for no limit."""
nodelay = True
"""If True (the default since 3.1), sets the TCP_NODELAY socket option."""
ConnectionClass = HTTPConnection
"""The class to use for handling HTTP connections."""
ssl_adapter = None
"""An instance of SSLAdapter (or a subclass).
You must have the corresponding SSL driver library installed."""
def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1,
server_name=None):
self.bind_addr = bind_addr
self.gateway = gateway
self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads)
if not server_name:
server_name = socket.gethostname()
self.server_name = server_name
self.clear_stats()
def clear_stats(self):
self._start_time = None
self._run_time = 0
self.stats = {
'Enabled': False,
'Bind Address': lambda s: repr(self.bind_addr),
'Run time': lambda s: (not s['Enabled']) and 0 or self.runtime(),
'Accepts': 0,
'Accepts/sec': lambda s: s['Accepts'] / self.runtime(),
'Queue': lambda s: getattr(self.requests, "qsize", None),
'Threads': lambda s: len(getattr(self.requests, "_threads", [])),
'Threads Idle': lambda s: getattr(self.requests, "idle", None),
'Socket Errors': 0,
'Requests': lambda s: (not s['Enabled']) and 0 or sum([w['Requests'](w) for w
in s['Worker Threads'].values()], 0),
'Bytes Read': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Read'](w) for w
in s['Worker Threads'].values()], 0),
'Bytes Written': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Written'](w) for w
in s['Worker Threads'].values()], 0),
'Work Time': lambda s: (not s['Enabled']) and 0 or sum([w['Work Time'](w) for w
in s['Worker Threads'].values()], 0),
'Read Throughput': lambda s: (not s['Enabled']) and 0 or sum(
[w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0),
'Write Throughput': lambda s: (not s['Enabled']) and 0 or sum(
[w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
for w in s['Worker Threads'].values()], 0),
'Worker Threads': {},
}
logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats
def runtime(self):
if self._start_time is None:
return self._run_time
else:
return self._run_time + (time.time() - self._start_time)
def __str__(self):
return "%s.%s(%r)" % (self.__module__, self.__class__.__name__,
self.bind_addr)
def _get_bind_addr(self):
return self._bind_addr
def _set_bind_addr(self, value):
if isinstance(value, tuple) and value[0] in ('', None):
# Despite the socket module docs, using '' does not
# allow AI_PASSIVE to work. Passing None instead
# returns '0.0.0.0' like we want. In other words:
#   host    AI_PASSIVE   result
#   ''          Y        192.168.x.y
#   ''          N        192.168.x.y
#   None        Y        0.0.0.0
#   None        N        127.0.0.1
# But since you can get the same effect with an explicit
# '0.0.0.0', we deny both the empty string and None as values.
raise ValueError("Host values of '' or None are not allowed. "
"Use '0.0.0.0' (IPv4) or '::' (IPv6) instead "
"to listen on all active interfaces.")
self._bind_addr = value
bind_addr = property(_get_bind_addr, _set_bind_addr,
doc="""The interface on which to listen for connections.
For TCP sockets, a (host, port) tuple. Host values may be any IPv4
or IPv6 address, or any valid hostname. The string 'localhost' is a
synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6).
The string '0.0.0.0' is a special IPv4 entry meaning "any active
interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for
IPv6. The empty string or None are not allowed.
For UNIX sockets, supply the filename as a string.""")
def start(self):
"""Run the server forever."""
# We don't have to trap KeyboardInterrupt or SystemExit here,
# because cherrypy.server already does so, calling self.stop() for us.
# If you're using this server with another framework, you should
# trap those exceptions in whatever code block calls start().
self._interrupt = None
if self.software is None:
self.software = "%s Server" % self.version
# SSL backward compatibility
if (self.ssl_adapter is None and
getattr(self, 'ssl_certificate', None) and
getattr(self, 'ssl_private_key', None)):
warnings.warn(
"SSL attributes are deprecated in CherryPy 3.2, and will "
"be removed in CherryPy 3.3. Use an ssl_adapter attribute "
"instead.",
DeprecationWarning
)
try:
from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
except ImportError:
pass
else:
self.ssl_adapter = pyOpenSSLAdapter(
self.ssl_certificate, self.ssl_private_key,
getattr(self, 'ssl_certificate_chain', None))
# Select the appropriate socket
if isinstance(self.bind_addr, basestring):
# AF_UNIX socket
# So we can reuse the socket...
try: os.unlink(self.bind_addr)
except: pass
# So everyone can access the socket...
try: os.chmod(self.bind_addr, 0777)
except: pass
info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)]
else:
# AF_INET or AF_INET6 socket
# Get the correct address family for our host (allows IPv6 addresses)
host, port = self.bind_addr
try:
info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
except socket.gaierror:
if ':' in self.bind_addr[0]:
info = [(socket.AF_INET6, socket.SOCK_STREAM,
0, "", self.bind_addr + (0, 0))]
else:
info = [(socket.AF_INET, socket.SOCK_STREAM,
0, "", self.bind_addr)]
self.socket = None
msg = "No socket could be created"
for res in info:
af, socktype, proto, canonname, sa = res
try:
self.bind(af, socktype, proto)
except socket.error:
if self.socket:
self.socket.close()
self.socket = None
continue
break
if not self.socket:
raise socket.error(msg)
# Timeout so KeyboardInterrupt can be caught on Win32
self.socket.settimeout(1)
self.socket.listen(self.request_queue_size)
# Create worker threads
self.requests.start()
self.ready = True
self._start_time = time.time()
while self.ready:
self.tick()
if self.interrupt:
while self.interrupt is True:
# Wait for self.stop() to complete. See _set_interrupt.
time.sleep(0.1)
if self.interrupt:
raise self.interrupt
def bind(self, family, type, proto=0):
"""Create (or recreate) the actual socket object."""
self.socket = socket.socket(family, type, proto)
prevent_socket_inheritance(self.socket)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if self.nodelay and not isinstance(self.bind_addr, str):
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if self.ssl_adapter is not None:
self.socket = self.ssl_adapter.bind(self.socket)
# If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
# activate dual-stack. See http://www.cherrypy.org/ticket/871.
if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6
and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')):
try:
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
except (AttributeError, socket.error):
# Apparently, the socket option is not available in
# this machine's TCP stack
pass
self.socket.bind(self.bind_addr)
def tick(self):
"""Accept a new connection and put it on the Queue."""
try:
s, addr = self.socket.accept()
if self.stats['Enabled']:
self.stats['Accepts'] += 1
if not self.ready:
return
prevent_socket_inheritance(s)
if hasattr(s, 'settimeout'):
s.settimeout(self.timeout)
makefile = CP_fileobject
ssl_env = {}
# if ssl cert and key are set, we try to be a secure HTTP server
if self.ssl_adapter is not None:
try:
s, ssl_env = self.ssl_adapter.wrap(s)
except NoSSLError:
msg = ("The client sent a plain HTTP request, but "
"this server only speaks HTTPS on this port.")
buf = ["%s 400 Bad Request\r\n" % self.protocol,
"Content-Length: %s\r\n" % len(msg),
"Content-Type: text/plain\r\n\r\n",
msg]
wfile = CP_fileobject(s, "wb", DEFAULT_BUFFER_SIZE)
try:
wfile.sendall("".join(buf))
except socket.error, x:
if x.args[0] not in socket_errors_to_ignore:
raise
return
if not s:
return
makefile = self.ssl_adapter.makefile
# Re-apply our timeout since we may have a new socket object
if hasattr(s, 'settimeout'):
s.settimeout(self.timeout)
conn = self.ConnectionClass(self, s, makefile)
if not isinstance(self.bind_addr, basestring):
# optional values
# Until we do DNS lookups, omit REMOTE_HOST
if addr is None: # sometimes this can happen
# figure out if AF_INET or AF_INET6.
if len(s.getsockname()) == 2:
# AF_INET
addr = ('0.0.0.0', 0)
else:
# AF_INET6
addr = ('::', 0)
conn.remote_addr = addr[0]
conn.remote_port = addr[1]
conn.ssl_env = ssl_env
self.requests.put(conn)
except socket.timeout:
# The only reason for the timeout in start() is so we can
# notice keyboard interrupts on Win32, which don't interrupt
# accept() by default
return
except socket.error, x:
if self.stats['Enabled']:
self.stats['Socket Errors'] += 1
if x.args[0] in socket_error_eintr:
# I *think* this is right. EINTR should occur when a signal
# is received during the accept() call; all docs say retry
# the call, and I *think* I'm reading it right that Python
# will then go ahead and poll for and handle the signal
# elsewhere. See http://www.cherrypy.org/ticket/707.
return
if x.args[0] in socket_errors_nonblocking:
# Just try again. See http://www.cherrypy.org/ticket/479.
return
if x.args[0] in socket_errors_to_ignore:
# Our socket was closed.
# See http://www.cherrypy.org/ticket/686.
return
raise
def _get_interrupt(self):
return self._interrupt
def _set_interrupt(self, interrupt):
self._interrupt = True
self.stop()
self._interrupt = interrupt
interrupt = property(_get_interrupt, _set_interrupt,
doc="Set this to an Exception instance to "
"interrupt the server.")
def stop(self):
"""Gracefully shutdown a server that is serving forever."""
self.ready = False
if self._start_time is not None:
self._run_time += (time.time() - self._start_time)
self._start_time = None
sock = getattr(self, "socket", None)
if sock:
if not isinstance(self.bind_addr, basestring):
# Touch our own socket to make accept() return immediately.
try:
host, port = sock.getsockname()[:2]
except socket.error, x:
if x.args[0] not in socket_errors_to_ignore:
# Changed to use error code and not message
# See http://www.cherrypy.org/ticket/860.
raise
else:
# Note that we're explicitly NOT using AI_PASSIVE,
# here, because we want an actual IP to touch.
# localhost won't work if we've bound to a public IP,
# but it will if we bound to '0.0.0.0' (INADDR_ANY).
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
s = None
try:
s = socket.socket(af, socktype, proto)
# See http://groups.google.com/group/cherrypy-users/
# browse_frm/thread/bbfe5eb39c904fe0
s.settimeout(1.0)
s.connect((host, port))
s.close()
except socket.error:
if s:
s.close()
if hasattr(sock, "close"):
sock.close()
self.socket = None
self.requests.stop(self.shutdown_timeout)
class Gateway(object):
def __init__(self, req):
self.req = req
def respond(self):
raise NotImplementedError
# These may either be wsgiserver.SSLAdapter subclasses or the string names
# of such classes (in which case they will be lazily loaded).
ssl_adapters = {
'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
}
def get_ssl_adapter_class(name='pyopenssl'):
adapter = ssl_adapters[name.lower()]
if isinstance(adapter, basestring):
last_dot = adapter.rfind(".")
attr_name = adapter[last_dot + 1:]
mod_path = adapter[:last_dot]
try:
mod = sys.modules[mod_path]
if mod is None:
raise KeyError()
except KeyError:
# The last [''] is important.
mod = __import__(mod_path, globals(), locals(), [''])
# Let an AttributeError propagate outward.
try:
adapter = getattr(mod, attr_name)
except AttributeError:
raise AttributeError("'%s' object has no attribute '%s'"
% (mod_path, attr_name))
return adapter
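# Illustrative usage (sketch; certificate and key file names are placeholders):
#   adapter_cls = get_ssl_adapter_class('builtin')
#   server.ssl_adapter = adapter_cls('server.crt', 'server.key')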
# -------------------------------- WSGI Stuff -------------------------------- #
class CherryPyWSGIServer(HTTPServer):
wsgi_version = (1, 0)
def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
self.requests = ThreadPool(self, min=numthreads or 1, max=max)
self.wsgi_app = wsgi_app
self.gateway = wsgi_gateways[self.wsgi_version]
self.bind_addr = bind_addr
if not server_name:
server_name = socket.gethostname()
self.server_name = server_name
self.request_queue_size = request_queue_size
self.timeout = timeout
self.shutdown_timeout = shutdown_timeout
self.clear_stats()
def _get_numthreads(self):
return self.requests.min
def _set_numthreads(self, value):
self.requests.min = value
numthreads = property(_get_numthreads, _set_numthreads)
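# Minimal usage sketch (illustrative; host, port and app name are placeholders):
#   def hello_app(environ, start_response):
#       start_response('200 OK', [('Content-Type', 'text/plain')])
#       return ['Hello, world!\n']
#   server = CherryPyWSGIServer(('0.0.0.0', 8080), hello_app,
#                               server_name='localhost')
#   try:
#       server.start()
#   except KeyboardInterrupt:
#       server.stop()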
class WSGIGateway(Gateway):
def __init__(self, req):
self.req = req
self.started_response = False
self.env = self.get_environ()
self.remaining_bytes_out = None
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version"""
raise NotImplementedError
def respond(self):
response = self.req.server.wsgi_app(self.env, self.start_response)
try:
for chunk in response:
# "The start_response callable must not actually transmit
# the response headers. Instead, it must store them for the
# server or gateway to transmit only after the first
# iteration of the application return value that yields
# a NON-EMPTY string, or upon the application's first
# invocation of the write() callable." (PEP 333)
if chunk:
if isinstance(chunk, unicode):
chunk = chunk.encode('ISO-8859-1')
self.write(chunk)
finally:
if hasattr(response, "close"):
response.close()
def start_response(self, status, headers, exc_info = None):
"""WSGI callable to begin the HTTP response."""
# "The application may call start_response more than once,
# if and only if the exc_info argument is provided."
if self.started_response and not exc_info:
raise AssertionError("WSGI start_response called a second "
"time with no exc_info.")
self.started_response = True
# "if exc_info is provided, and the HTTP headers have already been
# sent, start_response must raise an error, and should raise the
# exc_info tuple."
if self.req.sent_headers:
try:
raise exc_info[0], exc_info[1], exc_info[2]
finally:
exc_info = None
self.req.status = status
for k, v in headers:
if not isinstance(k, str):
raise TypeError("WSGI response header key %r is not a byte string." % k)
if not isinstance(v, str):
raise TypeError("WSGI response header value %r is not a byte string." % v)
if k.lower() == 'content-length':
self.remaining_bytes_out = int(v)
self.req.outheaders.extend(headers)
return self.write
def write(self, chunk):
"""WSGI callable to write unbuffered data to the client.
This method is also used internally by start_response (to write
data from the iterable returned by the WSGI application).
"""
if not self.started_response:
raise AssertionError("WSGI write called before start_response.")
chunklen = len(chunk)
rbo = self.remaining_bytes_out
if rbo is not None and chunklen > rbo:
if not self.req.sent_headers:
# Whew. We can send a 500 to the client.
self.req.simple_response("500 Internal Server Error",
"The requested resource returned more bytes than the "
"declared Content-Length.")
else:
# Dang. We have probably already sent data. Truncate the chunk
# to fit (so the client doesn't hang) and raise an error later.
chunk = chunk[:rbo]
if not self.req.sent_headers:
self.req.sent_headers = True
self.req.send_headers()
self.req.write(chunk)
if rbo is not None:
rbo -= chunklen
if rbo < 0:
raise ValueError(
"Response body exceeds the declared Content-Length.")
class WSGIGateway_10(WSGIGateway):
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version"""
req = self.req
env = {
# set a non-standard environ entry so the WSGI app can know what
# the *real* server protocol is (and what features to support).
# See http://www.faqs.org/rfcs/rfc2145.html.
'ACTUAL_SERVER_PROTOCOL': req.server.protocol,
'PATH_INFO': req.path,
'QUERY_STRING': req.qs,
'REMOTE_ADDR': req.conn.remote_addr or '',
'REMOTE_PORT': str(req.conn.remote_port or ''),
'REQUEST_METHOD': req.method,
'REQUEST_URI': req.uri,
'SCRIPT_NAME': '',
'SERVER_NAME': req.server.server_name,
# Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol.
'SERVER_PROTOCOL': req.request_protocol,
'SERVER_SOFTWARE': req.server.software,
'wsgi.errors': sys.stderr,
'wsgi.input': req.rfile,
'wsgi.multiprocess': False,
'wsgi.multithread': True,
'wsgi.run_once': False,
'wsgi.url_scheme': req.scheme,
'wsgi.version': (1, 0),
}
if isinstance(req.server.bind_addr, basestring):
# AF_UNIX. This isn't really allowed by WSGI, which doesn't
# address unix domain sockets. But it's better than nothing.
env["SERVER_PORT"] = ""
else:
env["SERVER_PORT"] = str(req.server.bind_addr[1])
# Request headers
for k, v in req.inheaders.iteritems():
env["HTTP_" + k.upper().replace("-", "_")] = v
# CONTENT_TYPE/CONTENT_LENGTH
ct = env.pop("HTTP_CONTENT_TYPE", None)
if ct is not None:
env["CONTENT_TYPE"] = ct
cl = env.pop("HTTP_CONTENT_LENGTH", None)
if cl is not None:
env["CONTENT_LENGTH"] = cl
if req.conn.ssl_env:
env.update(req.conn.ssl_env)
return env
class WSGIGateway_u0(WSGIGateway_10):
def get_environ(self):
"""Return a new environ dict targeting the given wsgi.version"""
req = self.req
env_10 = WSGIGateway_10.get_environ(self)
env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()])
env[u'wsgi.version'] = ('u', 0)
# Request-URI
env.setdefault(u'wsgi.url_encoding', u'utf-8')
try:
for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
except UnicodeDecodeError:
# Fall back to latin 1 so apps can transcode if needed.
env[u'wsgi.url_encoding'] = u'ISO-8859-1'
for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
for k, v in sorted(env.items()):
if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'):
env[k] = v.decode('ISO-8859-1')
return env
wsgi_gateways = {
(1, 0): WSGIGateway_10,
('u', 0): WSGIGateway_u0,
}
class WSGIPathInfoDispatcher(object):
"""A WSGI dispatcher for dispatch based on the PATH_INFO.
apps: a dict or list of (path_prefix, app) pairs.
"""
def __init__(self, apps):
try:
apps = apps.items()
except AttributeError:
pass
# Sort the apps by len(path), descending
apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0])))
apps.reverse()
# The path_prefix strings must start, but not end, with a slash.
# Use "" instead of "/".
self.apps = [(p.rstrip("/"), a) for p, a in apps]
def __call__(self, environ, start_response):
path = environ["PATH_INFO"] or "/"
for p, app in self.apps:
# The apps list should be sorted by length, descending.
if path.startswith(p + "/") or path == p:
environ = environ.copy()
environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p
environ["PATH_INFO"] = path[len(p):]
return app(environ, start_response)
start_response('404 Not Found', [('Content-Type', 'text/plain'),
('Content-Length', '0')])
return ['']
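# Illustrative usage (sketch; the mounted apps are placeholders). Prefixes must
# start, but not end, with a slash; use "" for the root app:
#   dispatcher = WSGIPathInfoDispatcher([('/blog', blog_app), ('', root_app)])
#   server = CherryPyWSGIServer(('0.0.0.0', 8080), dispatcher)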
webpy/web/wsgiserver/LICENSE.txt 0000644 0001750 0001750 00000003006 11773133455 015150 0 ustar wmb wmb Copyright (c) 2004-2007, CherryPy Team (team@cherrypy.org)
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the CherryPy Team nor the names of its contributors
may be used to endorse or promote products derived from this software
without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
webpy/web/wsgiserver/ssl_builtin.py 0000644 0001750 0001750 00000005035 11773133455 016232 0 ustar wmb wmb """A library for integrating Python's builtin ``ssl`` library with CherryPy.
The ssl module must be importable for SSL functionality.
To use this module, set ``CherryPyWSGIServer.ssl_adapter`` to an instance of
``BuiltinSSLAdapter``.
"""
try:
import ssl
except ImportError:
ssl = None
from cherrypy import wsgiserver
class BuiltinSSLAdapter(wsgiserver.SSLAdapter):
"""A wrapper for integrating Python's builtin ssl module with CherryPy."""
certificate = None
"""The filename of the server SSL certificate."""
private_key = None
"""The filename of the server's private key file."""
def __init__(self, certificate, private_key, certificate_chain=None):
if ssl is None:
raise ImportError("You must install the ssl module to use HTTPS.")
self.certificate = certificate
self.private_key = private_key
self.certificate_chain = certificate_chain
def bind(self, sock):
"""Wrap and return the given socket."""
return sock
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
try:
s = ssl.wrap_socket(sock, do_handshake_on_connect=True,
server_side=True, certfile=self.certificate,
keyfile=self.private_key, ssl_version=ssl.PROTOCOL_SSLv23)
except ssl.SSLError, e:
if e.errno == ssl.SSL_ERROR_EOF:
# This is almost certainly due to the cherrypy engine
# 'pinging' the socket to assert it's connectable;
# the 'ping' isn't SSL.
return None, {}
elif e.errno == ssl.SSL_ERROR_SSL:
if e.args[1].endswith('http request'):
# The client is speaking HTTP to an HTTPS server.
raise wsgiserver.NoSSLError
raise
return s, self.get_environ(s)
# TODO: fill this out more with mod ssl env
def get_environ(self, sock):
"""Create WSGI environ entries to be merged into each request."""
cipher = sock.cipher()
ssl_environ = {
"wsgi.url_scheme": "https",
"HTTPS": "on",
'SSL_PROTOCOL': cipher[1],
'SSL_CIPHER': cipher[0]
## SSL_VERSION_INTERFACE string The mod_ssl program version
## SSL_VERSION_LIBRARY string The OpenSSL program version
}
return ssl_environ
def makefile(self, sock, mode='r', bufsize=-1):
return wsgiserver.CP_fileobject(sock, mode, bufsize)
webpy/web/wsgiserver/ssl_pyopenssl.py 0000644 0001750 0001750 00000022605 11773133455 016622 0 ustar wmb wmb """A library for integrating pyOpenSSL with CherryPy.
The OpenSSL module must be importable for SSL functionality.
You can obtain it from http://pyopenssl.sourceforge.net/
To use this module, set CherryPyWSGIServer.ssl_adapter to an instance of
SSLAdapter. There are two ways to use SSL:
Method One
----------
* ``ssl_adapter.context``: an instance of SSL.Context.
If this is not None, it is assumed to be an SSL.Context instance,
and will be passed to SSL.Connection on bind(). The developer is
responsible for forming a valid Context object. This approach is
to be preferred for more flexibility, e.g. if the cert and key are
streams instead of files, or need decryption, or SSL.SSLv3_METHOD
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
the pyOpenSSL documentation for complete options.
Method Two (shortcut)
---------------------
* ``ssl_adapter.certificate``: the filename of the server SSL certificate.
* ``ssl_adapter.private_key``: the filename of the server's private key file.
Both are None by default. If ssl_adapter.context is None, but .private_key
and .certificate are both given and valid, they will be read, and the
context will be automatically created from them.
"""
import socket
import threading
import time
from cherrypy import wsgiserver
try:
from OpenSSL import SSL
from OpenSSL import crypto
except ImportError:
SSL = None
class SSL_fileobject(wsgiserver.CP_fileobject):
"""SSL file object attached to a socket object."""
ssl_timeout = 3
ssl_retry = .01
def _safe_call(self, is_reader, call, *args, **kwargs):
"""Wrap the given call with SSL error-trapping.
is_reader: if False EOF errors will be raised. If True, EOF errors
will return "" (to emulate normal sockets).
"""
start = time.time()
while True:
try:
return call(*args, **kwargs)
except SSL.WantReadError:
# Sleep and try again. This is dangerous, because it means
# the rest of the stack has no way of differentiating
# between a "new handshake" error and "client dropped".
# Note this isn't an endless loop: there's a timeout below.
time.sleep(self.ssl_retry)
except SSL.WantWriteError:
time.sleep(self.ssl_retry)
except SSL.SysCallError, e:
if is_reader and e.args == (-1, 'Unexpected EOF'):
return ""
errnum = e.args[0]
if is_reader and errnum in wsgiserver.socket_errors_to_ignore:
return ""
raise socket.error(errnum)
except SSL.Error, e:
if is_reader and e.args == (-1, 'Unexpected EOF'):
return ""
thirdarg = None
try:
thirdarg = e.args[0][0][2]
except IndexError:
pass
if thirdarg == 'http request':
# The client is talking HTTP to an HTTPS server.
raise wsgiserver.NoSSLError()
raise wsgiserver.FatalSSLAlert(*e.args)
except:
raise
if time.time() - start > self.ssl_timeout:
raise socket.timeout("timed out")
def recv(self, *args, **kwargs):
buf = []
r = super(SSL_fileobject, self).recv
while True:
data = self._safe_call(True, r, *args, **kwargs)
buf.append(data)
p = self._sock.pending()
if not p:
return "".join(buf)
def sendall(self, *args, **kwargs):
return self._safe_call(False, super(SSL_fileobject, self).sendall,
*args, **kwargs)
def send(self, *args, **kwargs):
return self._safe_call(False, super(SSL_fileobject, self).send,
*args, **kwargs)
class SSLConnection:
"""A thread-safe wrapper for an SSL.Connection.
``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
"""
def __init__(self, *args):
self._ssl_conn = SSL.Connection(*args)
self._lock = threading.RLock()
for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
'renegotiate', 'bind', 'listen', 'connect', 'accept',
'setblocking', 'fileno', 'close', 'get_cipher_list',
'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
'makefile', 'get_app_data', 'set_app_data', 'state_string',
'sock_shutdown', 'get_peer_certificate', 'want_read',
'want_write', 'set_connect_state', 'set_accept_state',
'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
exec("""def %s(self, *args):
self._lock.acquire()
try:
return self._ssl_conn.%s(*args)
finally:
self._lock.release()
""" % (f, f))
def shutdown(self, *args):
self._lock.acquire()
try:
# pyOpenSSL.socket.shutdown takes no args
return self._ssl_conn.shutdown()
finally:
self._lock.release()
class pyOpenSSLAdapter(wsgiserver.SSLAdapter):
"""A wrapper for integrating pyOpenSSL with CherryPy."""
context = None
"""An instance of SSL.Context."""
certificate = None
"""The filename of the server SSL certificate."""
private_key = None
"""The filename of the server's private key file."""
certificate_chain = None
"""Optional. The filename of CA's intermediate certificate bundle.
This is needed for cheaper "chained root" SSL certificates, and should be
left as None if not required."""
def __init__(self, certificate, private_key, certificate_chain=None):
if SSL is None:
raise ImportError("You must install pyOpenSSL to use HTTPS.")
self.context = None
self.certificate = certificate
self.private_key = private_key
self.certificate_chain = certificate_chain
self._environ = None
def bind(self, sock):
"""Wrap and return the given socket."""
if self.context is None:
self.context = self.get_context()
conn = SSLConnection(self.context, sock)
self._environ = self.get_environ()
return conn
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
return sock, self._environ.copy()
def get_context(self):
"""Return an SSL.Context from self attributes."""
# See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
c = SSL.Context(SSL.SSLv23_METHOD)
c.use_privatekey_file(self.private_key)
if self.certificate_chain:
c.load_verify_locations(self.certificate_chain)
c.use_certificate_file(self.certificate)
return c
def get_environ(self):
"""Return WSGI environ entries to be merged into each request."""
ssl_environ = {
"HTTPS": "on",
# pyOpenSSL doesn't provide access to any of these AFAICT
## 'SSL_PROTOCOL': 'SSLv2',
## SSL_CIPHER string The cipher specification name
## SSL_VERSION_INTERFACE string The mod_ssl program version
## SSL_VERSION_LIBRARY string The OpenSSL program version
}
if self.certificate:
# Server certificate attributes
cert = open(self.certificate, 'rb').read()
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
ssl_environ.update({
'SSL_SERVER_M_VERSION': cert.get_version(),
'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
## 'SSL_SERVER_V_START': Validity of server's certificate (start time),
## 'SSL_SERVER_V_END': Validity of server's certificate (end time),
})
for prefix, dn in [("I", cert.get_issuer()),
("S", cert.get_subject())]:
# X509Name objects don't seem to have a way to get the
# complete DN string. Use str() and slice it instead,
# because str(dn) returns something like "<X509Name object '/C=.../CN=...'>".
dnstr = str(dn)[18:-2]
wsgikey = 'SSL_SERVER_%s_DN' % prefix
ssl_environ[wsgikey] = dnstr
# The DN should be of the form: /k1=v1/k2=v2, but we must allow
# for any value to contain slashes itself (in a URL).
while dnstr:
pos = dnstr.rfind("=")
dnstr, value = dnstr[:pos], dnstr[pos + 1:]
pos = dnstr.rfind("/")
dnstr, key = dnstr[:pos], dnstr[pos + 1:]
if key and value:
wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
ssl_environ[wsgikey] = value
return ssl_environ
def makefile(self, sock, mode='r', bufsize=-1):
if SSL and isinstance(sock, SSL.ConnectionType):
timeout = sock.gettimeout()
f = SSL_fileobject(sock, mode, bufsize)
f.ssl_timeout = timeout
return f
else:
return wsgiserver.CP_fileobject(sock, mode, bufsize)
webpy/web/contrib/ 0002755 0001750 0001750 00000000000 11773133455 012570 5 ustar wmb wmb webpy/web/contrib/__init__.py 0000644 0001750 0001750 00000000000 11773133455 014665 0 ustar wmb wmb webpy/web/contrib/template.py 0000644 0001750 0001750 00000006571 11773133455 014764 0 ustar wmb wmb """
Interface to various templating engines.
"""
import os.path
__all__ = [
"render_cheetah", "render_genshi", "render_mako",
"cache",
]
class render_cheetah:
"""Rendering interface to Cheetah Templates.
Example:
render = render_cheetah('templates')
render.hello(name="cheetah")
"""
def __init__(self, path):
# raise ImportError early if Cheetah is not installed
from Cheetah.Template import Template
self.path = path
def __getattr__(self, name):
from Cheetah.Template import Template
path = os.path.join(self.path, name + ".html")
def template(**kw):
t = Template(file=path, searchList=[kw])
return t.respond()
return template
class render_genshi:
"""Rendering interface genshi templates.
Example:
for xml/html templates.
render = render_genshi(['templates/'])
render.hello(name='genshi')
For text templates:
render = render_genshi(['templates/'], type='text')
render.hello(name='genshi')
"""
def __init__(self, *a, **kwargs):
from genshi.template import TemplateLoader
self._type = kwargs.pop('type', None)
self._loader = TemplateLoader(*a, **kwargs)
def __getattr__(self, name):
# Assuming all templates are html
path = name + ".html"
if self._type == "text":
from genshi.template import TextTemplate
cls = TextTemplate
type = "text"
else:
cls = None
type = None
t = self._loader.load(path, cls=cls)
def template(**kw):
stream = t.generate(**kw)
if type:
return stream.render(type)
else:
return stream.render()
return template
class render_jinja:
"""Rendering interface to Jinja2 Templates
Example:
render= render_jinja('templates')
render.hello(name='jinja2')
"""
def __init__(self, *a, **kwargs):
extensions = kwargs.pop('extensions', [])
globals = kwargs.pop('globals', {})
from jinja2 import Environment,FileSystemLoader
self._lookup = Environment(loader=FileSystemLoader(*a, **kwargs), extensions=extensions)
self._lookup.globals.update(globals)
def __getattr__(self, name):
# Assuming all templates end with .html
path = name + '.html'
t = self._lookup.get_template(path)
return t.render
class render_mako:
"""Rendering interface to Mako Templates.
Example:
render = render_mako(directories=['templates'])
render.hello(name="mako")
"""
def __init__(self, *a, **kwargs):
from mako.lookup import TemplateLookup
self._lookup = TemplateLookup(*a, **kwargs)
def __getattr__(self, name):
# Assuming all templates are html
path = name + ".html"
t = self._lookup.get_template(path)
return t.render
class cache:
"""Cache for any rendering interface.
Example:
render = cache(render_cheetah("templates/"))
render.hello(name='cache')
"""
def __init__(self, render):
self._render = render
self._cache = {}
def __getattr__(self, name):
if name not in self._cache:
self._cache[name] = getattr(self._render, name)
return self._cache[name]
webpy/web/browser.py 0000644 0001750 0001750 00000016777 11773133455 013205 0 ustar wmb wmb """Browser to test web applications.
(from web.py)
"""
from utils import re_compile
from net import htmlunquote
import httplib, urllib, urllib2
import copy
from StringIO import StringIO
DEBUG = False
__all__ = [
"BrowserError",
"Browser", "AppBrowser",
"AppHandler"
]
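# Illustrative usage (sketch; the URL and link text are placeholders):
#   b = Browser()
#   b.open('http://example.com/')
#   b.follow_link(text='More information...')
#   print b.status, b.path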
class BrowserError(Exception):
pass
class Browser:
def __init__(self):
import cookielib
self.cookiejar = cookielib.CookieJar()
self._cookie_processor = urllib2.HTTPCookieProcessor(self.cookiejar)
self.form = None
self.url = "http://0.0.0.0:8080/"
self.path = "/"
self.status = None
self.data = None
self._response = None
self._forms = None
def reset(self):
"""Clears all cookies and history."""
self.cookiejar.clear()
def build_opener(self):
"""Builds the opener using urllib2.build_opener.
Subclasses can override this function to provide custom openers.
"""
return urllib2.build_opener()
def do_request(self, req):
if DEBUG:
print 'requesting', req.get_method(), req.get_full_url()
opener = self.build_opener()
opener.add_handler(self._cookie_processor)
try:
self._response = opener.open(req)
except urllib2.HTTPError, e:
self._response = e
self.url = self._response.geturl()
self.path = urllib2.Request(self.url).get_selector()
self.data = self._response.read()
self.status = self._response.code
self._forms = None
self.form = None
return self.get_response()
def open(self, url, data=None, headers={}):
"""Opens the specified url."""
url = urllib.basejoin(self.url, url)
req = urllib2.Request(url, data, headers)
return self.do_request(req)
def show(self):
"""Opens the current page in real web browser."""
f = open('page.html', 'w')
f.write(self.data)
f.close()
import webbrowser, os
url = 'file://' + os.path.abspath('page.html')
webbrowser.open(url)
def get_response(self):
"""Returns a copy of the current response."""
return urllib.addinfourl(StringIO(self.data), self._response.info(), self._response.geturl())
def get_soup(self):
"""Returns beautiful soup of the current document."""
import BeautifulSoup
return BeautifulSoup.BeautifulSoup(self.data)
def get_text(self, e=None):
"""Returns content of e or the current document as plain text."""
e = e or self.get_soup()
return ''.join([htmlunquote(c) for c in e.recursiveChildGenerator() if isinstance(c, unicode)])
def _get_links(self):
soup = self.get_soup()
return [a for a in soup.findAll(name='a')]
def get_links(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None):
"""Returns all links in the document."""
return self._filter_links(self._get_links(),
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate)
def follow_link(self, link=None, text=None, text_regex=None, url=None, url_regex=None, predicate=None):
if link is None:
links = self._filter_links(self.get_links(),
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate)
link = links and links[0]
if link:
return self.open(link['href'])
else:
raise BrowserError("No link found")
def find_link(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None):
links = self._filter_links(self.get_links(),
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate)
return links and links[0] or None
def _filter_links(self, links,
text=None, text_regex=None,
url=None, url_regex=None,
predicate=None):
predicates = []
if text is not None:
predicates.append(lambda link: link.string == text)
if text_regex is not None:
predicates.append(lambda link: re_compile(text_regex).search(link.string or ''))
if url is not None:
predicates.append(lambda link: link.get('href') == url)
if url_regex is not None:
predicates.append(lambda link: re_compile(url_regex).search(link.get('href', '')))
if predicate:
predicates.append(predicate)
def f(link):
for p in predicates:
if not p(link):
return False
return True
return [link for link in links if f(link)]
def get_forms(self):
"""Returns all forms in the current document.
The returned form objects implement the ClientForm.HTMLForm interface.
"""
if self._forms is None:
import ClientForm
self._forms = ClientForm.ParseResponse(self.get_response(), backwards_compat=False)
return self._forms
def select_form(self, name=None, predicate=None, index=0):
"""Selects the specified form."""
forms = self.get_forms()
if name is not None:
forms = [f for f in forms if f.name == name]
if predicate:
forms = [f for f in forms if predicate(f)]
if forms:
self.form = forms[index]
return self.form
else:
raise BrowserError("No form selected.")
def submit(self, **kw):
"""submits the currently selected form."""
if self.form is None:
raise BrowserError("No form selected.")
req = self.form.click(**kw)
return self.do_request(req)
def __getitem__(self, key):
return self.form[key]
def __setitem__(self, key, value):
self.form[key] = value
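# Usage sketch (illustrative only): driving a site with Browser. The URL, link
# text and form fields below are hypothetical.
def _example_browser_session():
    b = Browser()
    b.open('http://example.com/')     # absolute URL for the first request
    b.follow_link(text='Sign in')     # follows the first matching <a> tag
    b.select_form(name='login')       # a ClientForm.HTMLForm instance
    b['username'] = 'joe'             # item access proxies to the selected form
    b['password'] = 'secret'
    b.submit()
    return b.status, b.get_text()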
class AppBrowser(Browser):
"""Browser interface to test web.py apps.
b = AppBrowser(app)
b.open('/')
b.follow_link(text='Login')
b.select_form(name='login')
b['username'] = 'joe'
b['password'] = 'secret'
b.submit()
assert b.path == '/'
assert 'Welcome joe' in b.get_text()
"""
def __init__(self, app):
Browser.__init__(self)
self.app = app
def build_opener(self):
return urllib2.build_opener(AppHandler(self.app))
class AppHandler(urllib2.HTTPHandler):
"""urllib2 handler to handle requests using web.py application."""
handler_order = 100
def __init__(self, app):
self.app = app
def http_open(self, req):
result = self.app.request(
localpart=req.get_selector(),
method=req.get_method(),
host=req.get_host(),
data=req.get_data(),
headers=dict(req.header_items()),
https=req.get_type() == "https"
)
return self._make_response(result, req.get_full_url())
def https_open(self, req):
return self.http_open(req)
try:
https_request = urllib2.HTTPHandler.do_request_
except AttributeError:
# for python 2.3
pass
def _make_response(self, result, url):
data = "\r\n".join(["%s: %s" % (k, v) for k, v in result.header_items])
headers = httplib.HTTPMessage(StringIO(data))
response = urllib.addinfourl(StringIO(result.data), headers, url)
code, msg = result.status.split(None, 1)
response.code, response.msg = int(code), msg
return response
webpy/web/debugerror.py 0000644 0001750 0001750 00000030072 11773133455 013642 0 ustar wmb wmb """
pretty debug errors
(part of web.py)
portions adapted from Django
Copyright (c) 2005, the Lawrence Journal-World
Used under the modified BSD license:
http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5
"""
__all__ = ["debugerror", "djangoerror", "emailerrors"]
import sys, urlparse, pprint, traceback
from template import Template
from net import websafe
from utils import sendmail, safestr
import webapi as web
import os, os.path
whereami = os.path.join(os.getcwd(), __file__)
whereami = os.path.sep.join(whereami.split(os.path.sep)[:-1])
djangoerror_t = """\
$def with (exception_type, exception_value, frames)
$exception_type at $ctx.path
$def dicttable (d, kls='req', id=None):
$ items = d and d.items() or []
$items.sort()
$:dicttable_items(items, kls, id)
$def dicttable_items(items, kls='req', id=None):
$if items:
Variable
Value
$for k, v in items:
$k
$prettify(v)
$else:
No data.
$exception_type at $ctx.path
$exception_value
Python
$frames[0].filename in $frames[0].function, line $frames[0].lineno
Web
$ctx.method $ctx.home$ctx.path
Traceback (innermost first)
$for frame in frames:
$frame.filename in $frame.function
$if frame.context_line is not None:
$if frame.pre_context:
$for line in frame.pre_context:
$line
$frame.context_line ...
$if frame.post_context:
$for line in frame.post_context:
$line
$if frame.vars:
▶ Local vars
$# $inspect.formatargvalues(*inspect.getargvalues(frame['tb'].tb_frame))
$ newctx = [(k, v) for (k, v) in ctx.iteritems() if not k.startswith('_') and not isinstance(v, dict)]
$:dicttable(dict(newctx))
ENVIRONMENT
$:dicttable(ctx.env)
You're seeing this error because you have web.config.debug
set to True. Set that to False if you don't want to see this.
"""
djangoerror_r = None
def djangoerror():
def _get_lines_from_file(filename, lineno, context_lines):
"""
Returns context_lines before and after lineno from file.
Returns (pre_context_lineno, pre_context, context_line, post_context).
"""
try:
source = open(filename).readlines()
lower_bound = max(0, lineno - context_lines)
upper_bound = lineno + context_lines
pre_context = \
[line.strip('\n') for line in source[lower_bound:lineno]]
context_line = source[lineno].strip('\n')
post_context = \
[line.strip('\n') for line in source[lineno + 1:upper_bound]]
return lower_bound, pre_context, context_line, post_context
except (OSError, IOError, IndexError):
return None, [], None, []
exception_type, exception_value, tback = sys.exc_info()
frames = []
while tback is not None:
filename = tback.tb_frame.f_code.co_filename
function = tback.tb_frame.f_code.co_name
lineno = tback.tb_lineno - 1
# hack to get correct line number for templates
lineno += tback.tb_frame.f_locals.get("__lineoffset__", 0)
pre_context_lineno, pre_context, context_line, post_context = \
_get_lines_from_file(filename, lineno, 7)
if '__hidetraceback__' not in tback.tb_frame.f_locals:
frames.append(web.storage({
'tback': tback,
'filename': filename,
'function': function,
'lineno': lineno,
'vars': tback.tb_frame.f_locals,
'id': id(tback),
'pre_context': pre_context,
'context_line': context_line,
'post_context': post_context,
'pre_context_lineno': pre_context_lineno,
}))
tback = tback.tb_next
frames.reverse()
urljoin = urlparse.urljoin
def prettify(x):
try:
out = pprint.pformat(x)
except Exception, e:
out = '[could not display: <' + e.__class__.__name__ + \
': '+str(e)+'>]'
return out
global djangoerror_r
if djangoerror_r is None:
djangoerror_r = Template(djangoerror_t, filename=__file__, filter=websafe)
t = djangoerror_r
globals = {'ctx': web.ctx, 'web':web, 'dict':dict, 'str':str, 'prettify': prettify}
t.t.func_globals.update(globals)
return t(exception_type, exception_value, frames)
def debugerror():
"""
A replacement for `internalerror` that presents a nice page with lots
of debug information for the programmer.
(Based on the beautiful 500 page from [Django](http://djangoproject.com/),
designed by [Wilson Miner](http://wilsonminer.com/).)
"""
return web._InternalError(djangoerror())
def emailerrors(to_address, olderror, from_address=None):
"""
Wraps the old `internalerror` handler (pass as `olderror`) to
additionally email all errors to `to_address`, to aid in
debugging production websites.
Emails contain a normal text traceback as well as an
attachment containing the nice `debugerror` page.
"""
from_address = from_address or to_address
def emailerrors_internal():
error = olderror()
tb = sys.exc_info()
error_name = tb[0]
error_value = tb[1]
tb_txt = ''.join(traceback.format_exception(*tb))
path = web.ctx.path
request = web.ctx.method + ' ' + web.ctx.home + web.ctx.fullpath
message = "\n%s\n\n%s\n\n" % (request, tb_txt)
sendmail(
"your buggy site <%s>" % from_address,
"the bugfixer <%s>" % to_address,
"bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(),
message,
attachments=[
dict(filename="bug.html", content=safestr(djangoerror()))
],
)
return error
return emailerrors_internal
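# Usage sketch (the address is a placeholder; assumes an existing `app`):
# email every unhandled exception in production, attaching the debugerror page.
def _example_emailerrors(app):
    app.internalerror = emailerrors('bugs@example.com', debugerror)
    return app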
if __name__ == "__main__":
urls = (
'/', 'index'
)
from application import application
app = application(urls, globals())
app.internalerror = debugerror
class index:
def GET(self):
thisdoesnotexist
app.run()
webpy/web/webopenid.py 0000644 0001750 0001750 00000007173 11773133455 013464 0 ustar wmb wmb """openid.py: an openid library for web.py
Notes:
- This will create a file called .openid_secret_key in the
current directory with your secret key in it. If someone
has access to this file they can log in as any user. And
if the app can't find this file for any reason (e.g. you
moved the app somewhere else) then each currently logged
in user will get logged out.
- State must be maintained through the entire auth process
-- this means that if you have multiple web.py processes
serving one set of URLs or if you restart your app often
then logins will fail. You have to replace sessions and
store for things to work.
- We set cookies starting with "openid_".
"""
import os
import random
import hmac
import __init__ as web
import openid.consumer.consumer
import openid.store.memstore
sessions = {}
store = openid.store.memstore.MemoryStore()
def _secret():
try:
secret = file('.openid_secret_key').read()
except IOError:
# file doesn't exist
secret = os.urandom(20)
file('.openid_secret_key', 'w').write(secret)
return secret
def _hmac(identity_url):
return hmac.new(_secret(), identity_url).hexdigest()
def _random_session():
n = random.random()
while n in sessions:
n = random.random()
n = str(n)
return n
def status():
oid_hash = web.cookies().get('openid_identity_hash', '').split(',', 1)
if len(oid_hash) > 1:
oid_hash, identity_url = oid_hash
if oid_hash == _hmac(identity_url):
return identity_url
return None
def form(openid_loc):
oid = status()
if oid:
return '''<form method="post" action="%s">
<strong>%s</strong>
<input type="hidden" name="action" value="logout" />
<input type="hidden" name="return_to" value="%s" />
<button type="submit">log out</button>
</form>''' % (openid_loc, oid, web.ctx.fullpath)
else:
return '''<form method="post" action="%s">
<input type="text" name="openid" value="" />
<input type="hidden" name="return_to" value="%s" />
<button type="submit">log in</button>
</form>''' % (openid_loc, web.ctx.fullpath)
def logout():
web.setcookie('openid_identity_hash', '', expires=-1)
class host:
def POST(self):
# unlike the usual scheme of things, the POST is actually called
# first here
i = web.input(return_to='/')
if i.get('action') == 'logout':
logout()
return web.redirect(i.return_to)
i = web.input('openid', return_to='/')
n = _random_session()
sessions[n] = {'webpy_return_to': i.return_to}
c = openid.consumer.consumer.Consumer(sessions[n], store)
a = c.begin(i.openid)
f = a.redirectURL(web.ctx.home, web.ctx.home + web.ctx.fullpath)
web.setcookie('openid_session_id', n)
return web.redirect(f)
def GET(self):
n = web.cookies('openid_session_id').openid_session_id
web.setcookie('openid_session_id', '', expires=-1)
return_to = sessions[n]['webpy_return_to']
c = openid.consumer.consumer.Consumer(sessions[n], store)
a = c.complete(web.input(), web.ctx.home + web.ctx.fullpath)
if a.status.lower() == 'success':
web.setcookie('openid_identity_hash', _hmac(a.identity_url) + ',' + a.identity_url)
del sessions[n]
return web.redirect(return_to)
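# Usage sketch (illustrative; the URL layout and handler are assumptions, not
# part of this module): mount `host` and show the login/logout form elsewhere.
def _example_openid_app():
    class index:
        def GET(self):
            user = status()    # identity URL if the signed cookie verifies, else None
            # form() renders a logout form when logged in, a login form otherwise
            return (user or 'not logged in') + form('/openid')
    urls = (
        '/openid', host,       # POST starts the auth dance, GET completes it
        '/', index,
    )
    return web.application(urls, locals())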
webpy/web/test.py 0000644 0001750 0001750 00000002561 11773133455 012463 0 ustar wmb wmb """test utilities
(part of web.py)
"""
import unittest
import sys, os
import web
TestCase = unittest.TestCase
TestSuite = unittest.TestSuite
def load_modules(names):
return [__import__(name, None, None, "x") for name in names]
def module_suite(module, classnames=None):
"""Makes a suite from a module."""
if classnames:
return unittest.TestLoader().loadTestsFromNames(classnames, module)
elif hasattr(module, 'suite'):
return module.suite()
else:
return unittest.TestLoader().loadTestsFromModule(module)
def doctest_suite(module_names):
"""Makes a test suite from doctests."""
import doctest
suite = TestSuite()
for mod in load_modules(module_names):
suite.addTest(doctest.DocTestSuite(mod))
return suite
def suite(module_names):
"""Creates a suite from multiple modules."""
suite = TestSuite()
for mod in load_modules(module_names):
suite.addTest(module_suite(mod))
return suite
def runTests(suite):
runner = unittest.TextTestRunner()
return runner.run(suite)
def main(suite=None):
if not suite:
main_module = __import__('__main__')
# allow command line switches
args = [a for a in sys.argv[1:] if not a.startswith('-')]
suite = module_suite(main_module, args or None)
result = runTests(suite)
sys.exit(not result.wasSuccessful())
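# Usage sketch (the test module names are placeholders): combine doctest-based
# and TestCase-based suites and run them together.
def _example_test_run():
    all_tests = TestSuite()
    all_tests.addTest(doctest_suite(["web.utils", "web.template"]))
    all_tests.addTest(suite(["tests.application", "tests.db"]))
    return runTests(all_tests)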
webpy/web/utils.py 0000755 0001750 0001750 00000123541 11773133455 012651 0 ustar wmb wmb #!/usr/bin/env python
"""
General Utilities
(part of web.py)
"""
__all__ = [
"Storage", "storage", "storify",
"Counter", "counter",
"iters",
"rstrips", "lstrips", "strips",
"safeunicode", "safestr", "utf8",
"TimeoutError", "timelimit",
"Memoize", "memoize",
"re_compile", "re_subm",
"group", "uniq", "iterview",
"IterBetter", "iterbetter",
"safeiter", "safewrite",
"dictreverse", "dictfind", "dictfindall", "dictincr", "dictadd",
"requeue", "restack",
"listget", "intget", "datestr",
"numify", "denumify", "commify", "dateify",
"nthstr", "cond",
"CaptureStdout", "capturestdout", "Profile", "profile",
"tryall",
"ThreadedDict", "threadeddict",
"autoassign",
"to36",
"safemarkdown",
"sendmail"
]
import re, sys, time, threading, itertools, traceback, os
try:
import subprocess
except ImportError:
subprocess = None
try: import datetime
except ImportError: pass
try: set
except NameError:
from sets import Set as set
try:
from threading import local as threadlocal
except ImportError:
from python23 import threadlocal
class Storage(dict):
"""
A Storage object is like a dictionary except `obj.foo` can be used
in addition to `obj['foo']`.
>>> o = storage(a=1)
>>> o.a
1
>>> o['a']
1
>>> o.a = 2
>>> o['a']
2
>>> del o.a
>>> o.a
Traceback (most recent call last):
...
AttributeError: 'a'
"""
def __getattr__(self, key):
try:
return self[key]
except KeyError, k:
raise AttributeError, k
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
try:
del self[key]
except KeyError, k:
raise AttributeError, k
def __repr__(self):
return '<Storage ' + dict.__repr__(self) + '>'
storage = Storage
def storify(mapping, *requireds, **defaults):
"""
Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
`mapping` doesn't have all of the keys in `requireds` and using the default
values for keys found in `defaults`.
For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
`storage({'a':1, 'b':2, 'c':3})`.
If a `storify` value is a list (e.g. multiple values in a form submission),
`storify` returns the last element of the list, unless the key appears in
`defaults` as a list. Thus:
>>> storify({'a':[1, 2]}).a
2
>>> storify({'a':[1, 2]}, a=[]).a
[1, 2]
>>> storify({'a':1}, a=[]).a
[1]
>>> storify({}, a=[]).a
[]
Similarly, if the value has a `value` attribute, `storify` will return _its_
value, unless the key appears in `defaults` as a dictionary.
>>> storify({'a':storage(value=1)}).a
1
>>> storify({'a':storage(value=1)}, a={}).a
<Storage {'value': 1}>
>>> storify({}, a={}).a
{}
Optionally, keyword parameter `_unicode` can be passed to convert all values to unicode.
>>> storify({'x': 'a'}, _unicode=True)
<Storage {'x': u'a'}>
>>> storify({'x': storage(value='a')}, x={}, _unicode=True)
<Storage {'x': <Storage {'value': 'a'}>}>
>>> storify({'x': storage(value='a')}, _unicode=True)
<Storage {'x': u'a'}>
"""
_unicode = defaults.pop('_unicode', False)
# if _unicode is a callable object, use it to convert a string to unicode.
to_unicode = safeunicode
if _unicode is not False and hasattr(_unicode, "__call__"):
to_unicode = _unicode
def unicodify(s):
if _unicode and isinstance(s, str): return to_unicode(s)
else: return s
def getvalue(x):
if hasattr(x, 'file') and hasattr(x, 'value'):
return x.value
elif hasattr(x, 'value'):
return unicodify(x.value)
else:
return unicodify(x)
stor = Storage()
for key in requireds + tuple(mapping.keys()):
value = mapping[key]
if isinstance(value, list):
if isinstance(defaults.get(key), list):
value = [getvalue(x) for x in value]
else:
value = value[-1]
if not isinstance(defaults.get(key), dict):
value = getvalue(value)
if isinstance(defaults.get(key), list) and not isinstance(value, list):
value = [value]
setattr(stor, key, value)
for (key, value) in defaults.iteritems():
result = value
if hasattr(stor, key):
result = stor[key]
if value == () and not isinstance(result, tuple):
result = (result,)
setattr(stor, key, result)
return stor
class Counter(storage):
"""Keeps count of how many times something is added.
>>> c = counter()
>>> c.add('x')
>>> c.add('x')
>>> c.add('x')
>>> c.add('x')
>>> c.add('x')
>>> c.add('y')
>>> c
<Counter {'y': 1, 'x': 5}>
>>> c.most()
['x']
"""
def add(self, n):
self.setdefault(n, 0)
self[n] += 1
def most(self):
"""Returns the keys with maximum count."""
m = max(self.itervalues())
return [k for k, v in self.iteritems() if v == m]
def least(self):
"""Returns the keys with mininum count."""
m = min(self.itervalues())
return [k for k, v in self.iteritems() if v == m]
def percent(self, key):
"""Returns what percentage a certain key is of all entries.
>>> c = counter()
>>> c.add('x')
>>> c.add('x')
>>> c.add('x')
>>> c.add('y')
>>> c.percent('x')
0.75
>>> c.percent('y')
0.25
"""
return float(self[key])/sum(self.values())
def sorted_keys(self):
"""Returns keys sorted by value.
>>> c = counter()
>>> c.add('x')
>>> c.add('x')
>>> c.add('y')
>>> c.sorted_keys()
['x', 'y']
"""
return sorted(self.keys(), key=lambda k: self[k], reverse=True)
def sorted_values(self):
"""Returns values sorted by value.
>>> c = counter()
>>> c.add('x')
>>> c.add('x')
>>> c.add('y')
>>> c.sorted_values()
[2, 1]
"""
return [self[k] for k in self.sorted_keys()]
def sorted_items(self):
"""Returns items sorted by value.
>>> c = counter()
>>> c.add('x')
>>> c.add('x')
>>> c.add('y')
>>> c.sorted_items()
[('x', 2), ('y', 1)]
"""
return [(k, self[k]) for k in self.sorted_keys()]
def __repr__(self):
return '<Counter ' + dict.__repr__(self) + '>'
counter = Counter
iters = [list, tuple]
import __builtin__
if hasattr(__builtin__, 'set'):
iters.append(set)
if hasattr(__builtin__, 'frozenset'):
iters.append(frozenset)
if sys.version_info < (2,6): # sets module deprecated in 2.6
try:
from sets import Set
iters.append(Set)
except ImportError:
pass
class _hack(tuple): pass
iters = _hack(iters)
iters.__doc__ = """
A list of iterable items (like lists, but not strings). Includes whichever
of lists, tuples, sets, and Sets are available in this version of Python.
"""
def _strips(direction, text, remove):
if isinstance(remove, iters):
for subr in remove:
text = _strips(direction, text, subr)
return text
if direction == 'l':
if text.startswith(remove):
return text[len(remove):]
elif direction == 'r':
if text.endswith(remove):
return text[:-len(remove)]
else:
raise ValueError, "Direction needs to be r or l."
return text
def rstrips(text, remove):
"""
removes the string `remove` from the right of `text`
>>> rstrips("foobar", "bar")
'foo'
"""
return _strips('r', text, remove)
def lstrips(text, remove):
"""
removes the string `remove` from the left of `text`
>>> lstrips("foobar", "foo")
'bar'
>>> lstrips('http://foo.org/', ['http://', 'https://'])
'foo.org/'
>>> lstrips('FOOBARBAZ', ['FOO', 'BAR'])
'BAZ'
>>> lstrips('FOOBARBAZ', ['BAR', 'FOO'])
'BARBAZ'
"""
return _strips('l', text, remove)
def strips(text, remove):
"""
removes the string `remove` from both sides of `text`
>>> strips("foobarfoo", "foo")
'bar'
"""
return rstrips(lstrips(text, remove), remove)
def safeunicode(obj, encoding='utf-8'):
r"""
Converts any given object to a unicode string.
>>> safeunicode('hello')
u'hello'
>>> safeunicode(2)
u'2'
>>> safeunicode('\xe1\x88\xb4')
u'\u1234'
"""
t = type(obj)
if t is unicode:
return obj
elif t is str:
return obj.decode(encoding)
elif t in [int, float, bool]:
return unicode(obj)
elif hasattr(obj, '__unicode__') or isinstance(obj, unicode):
return unicode(obj)
else:
return str(obj).decode(encoding)
def safestr(obj, encoding='utf-8'):
r"""
Converts any given object to a utf-8 encoded string.
>>> safestr('hello')
'hello'
>>> safestr(u'\u1234')
'\xe1\x88\xb4'
>>> safestr(2)
'2'
"""
if isinstance(obj, unicode):
return obj.encode(encoding)
elif isinstance(obj, str):
return obj
elif hasattr(obj, 'next'): # iterator
return itertools.imap(safestr, obj)
else:
return str(obj)
# for backward-compatibility
utf8 = safestr
class TimeoutError(Exception): pass
def timelimit(timeout):
"""
A decorator to limit a function to `timeout` seconds, raising `TimeoutError`
if it takes longer.
>>> import time
>>> def meaningoflife():
... time.sleep(.2)
... return 42
>>>
>>> timelimit(.1)(meaningoflife)()
Traceback (most recent call last):
...
TimeoutError: took too long
>>> timelimit(1)(meaningoflife)()
42
_Caveat:_ The function isn't stopped after `timeout` seconds but continues
executing in a separate thread. (There seems to be no way to kill a thread.)
inspired by
"""
def _1(function):
def _2(*args, **kw):
class Dispatch(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.result = None
self.error = None
self.setDaemon(True)
self.start()
def run(self):
try:
self.result = function(*args, **kw)
except:
self.error = sys.exc_info()
c = Dispatch()
c.join(timeout)
if c.isAlive():
raise TimeoutError, 'took too long'
if c.error:
raise c.error[0], c.error[1]
return c.result
return _2
return _1
class Memoize:
"""
'Memoizes' a function, caching its return values for each input.
If `expires` is specified, values are recalculated after `expires` seconds.
If `background` is specified, values are recalculated in a separate thread.
>>> calls = 0
>>> def howmanytimeshaveibeencalled():
... global calls
... calls += 1
... return calls
>>> fastcalls = memoize(howmanytimeshaveibeencalled)
>>> howmanytimeshaveibeencalled()
1
>>> howmanytimeshaveibeencalled()
2
>>> fastcalls()
3
>>> fastcalls()
3
>>> import time
>>> fastcalls = memoize(howmanytimeshaveibeencalled, .1, background=False)
>>> fastcalls()
4
>>> fastcalls()
4
>>> time.sleep(.2)
>>> fastcalls()
5
>>> def slowfunc():
... time.sleep(.1)
... return howmanytimeshaveibeencalled()
>>> fastcalls = memoize(slowfunc, .2, background=True)
>>> fastcalls()
6
>>> timelimit(.05)(fastcalls)()
6
>>> time.sleep(.2)
>>> timelimit(.05)(fastcalls)()
6
>>> timelimit(.05)(fastcalls)()
6
>>> time.sleep(.2)
>>> timelimit(.05)(fastcalls)()
7
>>> fastcalls = memoize(slowfunc, None, background=True)
>>> threading.Thread(target=fastcalls).start()
>>> time.sleep(.01)
>>> fastcalls()
9
"""
def __init__(self, func, expires=None, background=True):
self.func = func
self.cache = {}
self.expires = expires
self.background = background
self.running = {}
def __call__(self, *args, **keywords):
key = (args, tuple(keywords.items()))
if not self.running.get(key):
self.running[key] = threading.Lock()
def update(block=False):
if self.running[key].acquire(block):
try:
self.cache[key] = (self.func(*args, **keywords), time.time())
finally:
self.running[key].release()
if key not in self.cache:
update(block=True)
elif self.expires and (time.time() - self.cache[key][1]) > self.expires:
if self.background:
threading.Thread(target=update).start()
else:
update()
return self.cache[key][0]
memoize = Memoize
re_compile = memoize(re.compile) #@@ threadsafe?
re_compile.__doc__ = """
A memoized version of re.compile.
"""
class _re_subm_proxy:
def __init__(self):
self.match = None
def __call__(self, match):
self.match = match
return ''
def re_subm(pat, repl, string):
"""
Like re.sub, but returns the replacement _and_ the match object.
>>> t, m = re_subm('g(oo+)fball', r'f\\1lish', 'goooooofball')
>>> t
'foooooolish'
>>> m.groups()
('oooooo',)
"""
compiled_pat = re_compile(pat)
proxy = _re_subm_proxy()
compiled_pat.sub(proxy.__call__, string)
return compiled_pat.sub(repl, string), proxy.match
def group(seq, size):
"""
Returns an iterator over a series of lists of length `size` from the iterable `seq`.
>>> list(group([1,2,3,4], 2))
[[1, 2], [3, 4]]
>>> list(group([1,2,3,4,5], 2))
[[1, 2], [3, 4], [5]]
"""
def take(seq, n):
for i in xrange(n):
yield seq.next()
if not hasattr(seq, 'next'):
seq = iter(seq)
while True:
x = list(take(seq, size))
if x:
yield x
else:
break
def uniq(seq, key=None):
"""
Removes duplicate elements from a list while preserving the order of the rest.
>>> uniq([9,0,2,1,0])
[9, 0, 2, 1]
The value of the optional `key` parameter should be a function that
takes a single argument and returns a key to test the uniqueness.
>>> uniq(["Foo", "foo", "bar"], key=lambda s: s.lower())
['Foo', 'bar']
"""
key = key or (lambda x: x)
seen = set()
result = []
for v in seq:
k = key(v)
if k in seen:
continue
seen.add(k)
result.append(v)
return result
def iterview(x):
"""
Takes an iterable `x` and returns an iterator over it
which prints its progress to stderr as it iterates through.
"""
WIDTH = 70
def plainformat(n, lenx):
return '%5.1f%% (%*d/%d)' % ((float(n)/lenx)*100, len(str(lenx)), n, lenx)
def bars(size, n, lenx):
val = int((float(n)*size)/lenx + 0.5)
if size - val:
spacing = ">" + (" "*(size-val))[1:]
else:
spacing = ""
return "[%s%s]" % ("="*val, spacing)
def eta(elapsed, n, lenx):
if n == 0:
return '--:--:--'
if n == lenx:
secs = int(elapsed)
else:
secs = int((elapsed/n) * (lenx-n))
mins, secs = divmod(secs, 60)
hrs, mins = divmod(mins, 60)
return '%02d:%02d:%02d' % (hrs, mins, secs)
def format(starttime, n, lenx):
out = plainformat(n, lenx) + ' '
if n == lenx:
end = ' '
else:
end = ' ETA '
end += eta(time.time() - starttime, n, lenx)
out += bars(WIDTH - len(out) - len(end), n, lenx)
out += end
return out
starttime = time.time()
lenx = len(x)
for n, y in enumerate(x):
sys.stderr.write('\r' + format(starttime, n, lenx))
yield y
sys.stderr.write('\r' + format(starttime, n+1, lenx) + '\n')
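# Usage sketch: iterview only needs len() of its argument, so any sequence can
# be wrapped to get a progress bar with an ETA on stderr while looping.
def _example_iterview(items, process):
    results = []
    for item in iterview(items):   # prints e.g. ' 42.0% (42/100) [===>   ] ETA 00:00:12'
        results.append(process(item))
    return results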
class IterBetter:
"""
Returns an object that can be used as an iterator
but can also be used via __getitem__ (although it
cannot go backwards -- that is, you cannot request
`iterbetter[0]` after requesting `iterbetter[1]`).
>>> import itertools
>>> c = iterbetter(itertools.count())
>>> c[1]
1
>>> c[5]
5
>>> c[3]
Traceback (most recent call last):
...
IndexError: already passed 3
For boolean test, IterBetter peeks at the first value in the iterator without affecting the iteration.
>>> c = iterbetter(iter(range(5)))
>>> bool(c)
True
>>> list(c)
[0, 1, 2, 3, 4]
>>> c = iterbetter(iter([]))
>>> bool(c)
False
>>> list(c)
[]
"""
def __init__(self, iterator):
self.i, self.c = iterator, 0
def __iter__(self):
if hasattr(self, "_head"):
yield self._head
while 1:
yield self.i.next()
self.c += 1
def __getitem__(self, i):
#todo: slices
if i < self.c:
raise IndexError, "already passed "+str(i)
try:
while i > self.c:
self.i.next()
self.c += 1
# now self.c == i
self.c += 1
return self.i.next()
except StopIteration:
raise IndexError, str(i)
def __nonzero__(self):
if hasattr(self, "__len__"):
return len(self) != 0
elif hasattr(self, "_head"):
return True
else:
try:
self._head = self.i.next()
except StopIteration:
return False
else:
return True
iterbetter = IterBetter
def safeiter(it, cleanup=None, ignore_errors=True):
"""Makes an iterator safe by ignoring the exceptions occured during the iteration.
"""
def next():
while True:
try:
return it.next()
except StopIteration:
raise
except:
traceback.print_exc()
it = iter(it)
while True:
yield next()
def safewrite(filename, content):
"""Writes the content to a temp file and then moves the temp file to
given filename to avoid overwriting the existing file in case of errors.
"""
f = file(filename + '.tmp', 'w')
f.write(content)
f.close()
os.rename(f.name, filename)
def dictreverse(mapping):
"""
Returns a new dictionary with keys and values swapped.
>>> dictreverse({1: 2, 3: 4})
{2: 1, 4: 3}
"""
return dict([(value, key) for (key, value) in mapping.iteritems()])
def dictfind(dictionary, element):
"""
Returns a key whose value in `dictionary` is `element`
or, if none exists, None.
>>> d = {1:2, 3:4}
>>> dictfind(d, 4)
3
>>> dictfind(d, 5)
"""
for (key, value) in dictionary.iteritems():
if element is value:
return key
def dictfindall(dictionary, element):
"""
Returns the keys whose values in `dictionary` are `element`
or, if none exists, [].
>>> d = {1:4, 3:4}
>>> dictfindall(d, 4)
[1, 3]
>>> dictfindall(d, 5)
[]
"""
res = []
for (key, value) in dictionary.iteritems():
if element is value:
res.append(key)
return res
def dictincr(dictionary, element):
"""
Increments `element` in `dictionary`,
setting it to one if it doesn't exist.
>>> d = {1:2, 3:4}
>>> dictincr(d, 1)
3
>>> d[1]
3
>>> dictincr(d, 5)
1
>>> d[5]
1
"""
dictionary.setdefault(element, 0)
dictionary[element] += 1
return dictionary[element]
def dictadd(*dicts):
"""
Returns a dictionary consisting of the keys in the argument dictionaries.
If they share a key, the value from the last argument is used.
>>> dictadd({1: 0, 2: 0}, {2: 1, 3: 1})
{1: 0, 2: 1, 3: 1}
"""
result = {}
for dct in dicts:
result.update(dct)
return result
def requeue(queue, index=-1):
"""Returns the element at index after moving it to the beginning of the queue.
>>> x = [1, 2, 3, 4]
>>> requeue(x)
4
>>> x
[4, 1, 2, 3]
"""
x = queue.pop(index)
queue.insert(0, x)
return x
def restack(stack, index=0):
"""Returns the element at index after moving it to the top of stack.
>>> x = [1, 2, 3, 4]
>>> restack(x)
1
>>> x
[2, 3, 4, 1]
"""
x = stack.pop(index)
stack.append(x)
return x
def listget(lst, ind, default=None):
"""
Returns `lst[ind]` if it exists, `default` otherwise.
>>> listget(['a'], 0)
'a'
>>> listget(['a'], 1)
>>> listget(['a'], 1, 'b')
'b'
"""
if len(lst)-1 < ind:
return default
return lst[ind]
def intget(integer, default=None):
"""
Returns `integer` as an int or `default` if it can't.
>>> intget('3')
3
>>> intget('3a')
>>> intget('3a', 0)
0
"""
try:
return int(integer)
except (TypeError, ValueError):
return default
def datestr(then, now=None):
"""
Converts a (UTC) datetime object to a nice string representation.
>>> from datetime import datetime, timedelta
>>> d = datetime(1970, 5, 1)
>>> datestr(d, now=d)
'0 microseconds ago'
>>> for t, v in {
... timedelta(microseconds=1): '1 microsecond ago',
... timedelta(microseconds=2): '2 microseconds ago',
... -timedelta(microseconds=1): '1 microsecond from now',
... -timedelta(microseconds=2): '2 microseconds from now',
... timedelta(microseconds=2000): '2 milliseconds ago',
... timedelta(seconds=2): '2 seconds ago',
... timedelta(seconds=2*60): '2 minutes ago',
... timedelta(seconds=2*60*60): '2 hours ago',
... timedelta(days=2): '2 days ago',
... }.iteritems():
... assert datestr(d, now=d+t) == v
>>> datestr(datetime(1970, 1, 1), now=d)
'January 1'
>>> datestr(datetime(1969, 1, 1), now=d)
'January 1, 1969'
>>> datestr(datetime(1970, 6, 1), now=d)
'June 1, 1970'
>>> datestr(None)
''
"""
def agohence(n, what, divisor=None):
if divisor: n = n // divisor
out = str(abs(n)) + ' ' + what # '2 day'
if abs(n) != 1: out += 's' # '2 days'
out += ' ' # '2 days '
if n < 0:
out += 'from now'
else:
out += 'ago'
return out # '2 days ago'
oneday = 24 * 60 * 60
if not then: return ""
if not now: now = datetime.datetime.utcnow()
if type(now).__name__ == "DateTime":
now = datetime.datetime.fromtimestamp(now)
if type(then).__name__ == "DateTime":
then = datetime.datetime.fromtimestamp(then)
elif type(then).__name__ == "date":
then = datetime.datetime(then.year, then.month, then.day)
delta = now - then
deltaseconds = int(delta.days * oneday + delta.seconds + delta.microseconds * 1e-06)
deltadays = abs(deltaseconds) // oneday
if deltaseconds < 0: deltadays *= -1 # fix for oddity of floor
if deltadays:
if abs(deltadays) < 4:
return agohence(deltadays, 'day')
try:
out = then.strftime('%B %e') # e.g. 'June 3'
except ValueError:
# %e doesn't work on Windows.
out = then.strftime('%B %d') # e.g. 'June 03'
if then.year != now.year or deltadays < 0:
out += ', %s' % then.year
return out
if int(deltaseconds):
if abs(deltaseconds) > (60 * 60):
return agohence(deltaseconds, 'hour', 60 * 60)
elif abs(deltaseconds) > 60:
return agohence(deltaseconds, 'minute', 60)
else:
return agohence(deltaseconds, 'second')
deltamicroseconds = delta.microseconds
if delta.days: deltamicroseconds = int(delta.microseconds - 1e6) # datetime oddity
if abs(deltamicroseconds) > 1000:
return agohence(deltamicroseconds, 'millisecond', 1000)
return agohence(deltamicroseconds, 'microsecond')
def numify(string):
"""
Removes all non-digit characters from `string`.
>>> numify('800-555-1212')
'8005551212'
>>> numify('800.555.1212')
'8005551212'
"""
return ''.join([c for c in str(string) if c.isdigit()])
def denumify(string, pattern):
"""
Formats `string` according to `pattern`, where the letter X gets replaced
by characters from `string`.
>>> denumify("8005551212", "(XXX) XXX-XXXX")
'(800) 555-1212'
"""
out = []
for c in pattern:
if c == "X":
out.append(string[0])
string = string[1:]
else:
out.append(c)
return ''.join(out)
def commify(n):
"""
Add commas to an integer `n`.
>>> commify(1)
'1'
>>> commify(123)
'123'
>>> commify(1234)
'1,234'
>>> commify(1234567890)
'1,234,567,890'
>>> commify(123.0)
'123.0'
>>> commify(1234.5)
'1,234.5'
>>> commify(1234.56789)
'1,234.56789'
>>> commify('%.2f' % 1234.5)
'1,234.50'
>>> commify(None)
>>>
"""
if n is None: return None
n = str(n)
if '.' in n:
dollars, cents = n.split('.')
else:
dollars, cents = n, None
r = []
for i, c in enumerate(str(dollars)[::-1]):
if i and (not (i % 3)):
r.insert(0, ',')
r.insert(0, c)
out = ''.join(r)
if cents:
out += '.' + cents
return out
def dateify(datestring):
"""
Formats a numified `datestring` properly.
"""
return denumify(datestring, "XXXX-XX-XX XX:XX:XX")
def nthstr(n):
"""
Formats an ordinal.
Doesn't handle negative numbers.
>>> nthstr(1)
'1st'
>>> nthstr(0)
'0th'
>>> [nthstr(x) for x in [2, 3, 4, 5, 10, 11, 12, 13, 14, 15]]
['2nd', '3rd', '4th', '5th', '10th', '11th', '12th', '13th', '14th', '15th']
>>> [nthstr(x) for x in [91, 92, 93, 94, 99, 100, 101, 102]]
['91st', '92nd', '93rd', '94th', '99th', '100th', '101st', '102nd']
>>> [nthstr(x) for x in [111, 112, 113, 114, 115]]
['111th', '112th', '113th', '114th', '115th']
"""
assert n >= 0
if n % 100 in [11, 12, 13]: return '%sth' % n
return {1: '%sst', 2: '%snd', 3: '%srd'}.get(n % 10, '%sth') % n
def cond(predicate, consequence, alternative=None):
"""
Function replacement for if-else to use in expressions.
>>> x = 2
>>> cond(x % 2 == 0, "even", "odd")
'even'
>>> cond(x % 2 == 0, "even", "odd") + '_row'
'even_row'
"""
if predicate:
return consequence
else:
return alternative
class CaptureStdout:
"""
Captures everything `func` prints to stdout and returns it instead.
>>> def idiot():
... print "foo"
>>> capturestdout(idiot)()
'foo\\n'
**WARNING:** Not threadsafe!
"""
def __init__(self, func):
self.func = func
def __call__(self, *args, **keywords):
from cStringIO import StringIO
# Not threadsafe!
out = StringIO()
oldstdout = sys.stdout
sys.stdout = out
try:
self.func(*args, **keywords)
finally:
sys.stdout = oldstdout
return out.getvalue()
capturestdout = CaptureStdout
class Profile:
"""
Profiles `func` and returns a tuple containing its output
and a string with human-readable profiling information.
>>> import time
>>> out, inf = profile(time.sleep)(.001)
>>> out
>>> inf[:10].strip()
'took 0.0'
"""
def __init__(self, func):
self.func = func
def __call__(self, *args): ##, **kw): kw unused
import hotshot, hotshot.stats, os, tempfile ##, time already imported
f, filename = tempfile.mkstemp()
os.close(f)
prof = hotshot.Profile(filename)
stime = time.time()
result = prof.runcall(self.func, *args)
stime = time.time() - stime
prof.close()
import cStringIO
out = cStringIO.StringIO()
stats = hotshot.stats.load(filename)
stats.stream = out
stats.strip_dirs()
stats.sort_stats('time', 'calls')
stats.print_stats(40)
stats.print_callers()
x = '\n\ntook '+ str(stime) + ' seconds\n'
x += out.getvalue()
# remove the tempfile
try:
os.remove(filename)
except IOError:
pass
return result, x
profile = Profile
import traceback
# hack for compatibility with Python 2.3:
if not hasattr(traceback, 'format_exc'):
from cStringIO import StringIO
def format_exc(limit=None):
strbuf = StringIO()
traceback.print_exc(limit, strbuf)
return strbuf.getvalue()
traceback.format_exc = format_exc
def tryall(context, prefix=None):
"""
Tries a series of functions and prints their results.
`context` is a dictionary mapping names to values;
the value will only be tried if it's callable.
>>> tryall(dict(j=lambda: True))
j: True
----------------------------------------
results:
True: 1
For example, you might have a file `test/stuff.py`
with a series of functions testing various things in it.
At the bottom, have a line:
if __name__ == "__main__": tryall(globals())
Then you can run `python test/stuff.py` and get the results of
all the tests.
"""
context = context.copy() # vars() would update
results = {}
for (key, value) in context.iteritems():
if not hasattr(value, '__call__'):
continue
if prefix and not key.startswith(prefix):
continue
print key + ':',
try:
r = value()
dictincr(results, r)
print r
except:
print 'ERROR'
dictincr(results, 'ERROR')
print ' ' + '\n '.join(traceback.format_exc().split('\n'))
print '-'*40
print 'results:'
for (key, value) in results.iteritems():
print ' '*2, str(key)+':', value
class ThreadedDict(threadlocal):
"""
Thread local storage.
>>> d = ThreadedDict()
>>> d.x = 1
>>> d.x
1
>>> import threading
>>> def f(): d.x = 2
...
>>> t = threading.Thread(target=f)
>>> t.start()
>>> t.join()
>>> d.x
1
"""
_instances = set()
def __init__(self):
ThreadedDict._instances.add(self)
def __del__(self):
ThreadedDict._instances.remove(self)
def __hash__(self):
return id(self)
def clear_all():
"""Clears all ThreadedDict instances.
"""
for t in list(ThreadedDict._instances):
t.clear()
clear_all = staticmethod(clear_all)
# Define all these methods to more or less fully emulate dict -- attribute access
# is built into threading.local.
def __getitem__(self, key):
return self.__dict__[key]
def __setitem__(self, key, value):
self.__dict__[key] = value
def __delitem__(self, key):
del self.__dict__[key]
def __contains__(self, key):
return key in self.__dict__
has_key = __contains__
def clear(self):
self.__dict__.clear()
def copy(self):
return self.__dict__.copy()
def get(self, key, default=None):
return self.__dict__.get(key, default)
def items(self):
return self.__dict__.items()
def iteritems(self):
return self.__dict__.iteritems()
def keys(self):
return self.__dict__.keys()
def iterkeys(self):
return self.__dict__.iterkeys()
iter = iterkeys
def values(self):
return self.__dict__.values()
def itervalues(self):
return self.__dict__.itervalues()
def pop(self, key, *args):
return self.__dict__.pop(key, *args)
def popitem(self):
return self.__dict__.popitem()
def setdefault(self, key, default=None):
return self.__dict__.setdefault(key, default)
def update(self, *args, **kwargs):
self.__dict__.update(*args, **kwargs)
def __repr__(self):
return '<ThreadedDict %r>' % self.__dict__
__str__ = __repr__
threadeddict = ThreadedDict
def autoassign(self, locals):
"""
Automatically assigns local variables to `self`.
>>> self = storage()
>>> autoassign(self, dict(a=1, b=2))
>>> self
<Storage {'a': 1, 'b': 2}>
Generally used in `__init__` methods, as in:
def __init__(self, foo, bar, baz=1): autoassign(self, locals())
"""
for (key, value) in locals.iteritems():
if key == 'self':
continue
setattr(self, key, value)
def to36(q):
"""
Converts an integer to base 36 (a useful scheme for human-sayable IDs).
>>> to36(35)
'z'
>>> to36(119292)
'2k1o'
>>> int(to36(939387374), 36)
939387374
>>> to36(0)
'0'
>>> to36(-393)
Traceback (most recent call last):
...
ValueError: must supply a positive integer
"""
if q < 0: raise ValueError, "must supply a positive integer"
letters = "0123456789abcdefghijklmnopqrstuvwxyz"
converted = []
while q != 0:
q, r = divmod(q, 36)
converted.insert(0, letters[r])
return "".join(converted) or '0'
r_url = re_compile('(?<!\()(http://(\S+))')
def safemarkdown(text):
"""
Converts text to HTML following the rules of Markdown, but blocking any
outside HTML input, so that only the things supported by Markdown
can be used. Also converts raw URLs to links.
(requires [markdown.py](http://webpy.org/markdown.py))
"""
from markdown import markdown
if text:
text = text.replace('<', '&lt;')
# TODO: automatically get page title?
text = r_url.sub(r'<\1>', text)
text = markdown(text)
return text
def sendmail(from_address, to_address, subject, message, headers=None, **kw):
"""
Sends the email message `message` with mail and envelope headers
from `from_address` to `to_address` with `subject`.
Additional email headers can be specified with the dictionary
`headers`.
Optionally cc, bcc and attachments can be specified as keyword arguments.
Attachments must be an iterable and each attachment can be either a
filename or a file object or a dictionary with filename, content and
optionally content_type keys.
If `web.config.smtp_server` is set, it will send the message
to that SMTP server. Otherwise it will look for
`/usr/sbin/sendmail`, the typical location for the sendmail-style
binary. To use sendmail from a different path, set `web.config.sendmail_path`.
"""
attachments = kw.pop("attachments", [])
mail = _EmailMessage(from_address, to_address, subject, message, headers, **kw)
for a in attachments:
if isinstance(a, dict):
mail.attach(a['filename'], a['content'], a.get('content_type'))
elif hasattr(a, 'read'): # file
filename = os.path.basename(getattr(a, "name", ""))
content_type = getattr(a, 'content_type', None)
mail.attach(filename, a.read(), content_type)
elif isinstance(a, basestring):
f = open(a, 'rb')
content = f.read()
f.close()
filename = os.path.basename(a)
mail.attach(filename, content, None)
else:
raise ValueError, "Invalid attachment: %s" % repr(a)
mail.send()
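# Usage sketch (addresses, host and file names are placeholders): configure an
# SMTP server and send a message with cc and both supported attachment forms.
def _example_sendmail():
    import webapi
    webapi.config.smtp_server = 'smtp.example.com'   # otherwise /usr/sbin/sendmail is used
    webapi.config.smtp_port = 587
    webapi.config.smtp_starttls = True
    webapi.config.smtp_username = 'user@example.com'
    webapi.config.smtp_password = 'secret'
    sendmail('noreply@example.com', 'user@example.com', 'report', 'see attachments',
             cc='boss@example.com',
             attachments=['report.csv',                                   # a filename
                          dict(filename='log.txt', content='hello')])    # a dict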
class _EmailMessage:
def __init__(self, from_address, to_address, subject, message, headers=None, **kw):
def listify(x):
if not isinstance(x, list):
return [safestr(x)]
else:
return [safestr(a) for a in x]
subject = safestr(subject)
message = safestr(message)
from_address = safestr(from_address)
to_address = listify(to_address)
cc = listify(kw.get('cc', []))
bcc = listify(kw.get('bcc', []))
recipients = to_address + cc + bcc
import email.Utils
self.from_address = email.Utils.parseaddr(from_address)[1]
self.recipients = [email.Utils.parseaddr(r)[1] for r in recipients]
self.headers = dictadd({
'From': from_address,
'To': ", ".join(to_address),
'Subject': subject
}, headers or {})
if cc:
self.headers['Cc'] = ", ".join(cc)
self.message = self.new_message()
self.message.add_header("Content-Transfer-Encoding", "7bit")
self.message.add_header("Content-Disposition", "inline")
self.message.add_header("MIME-Version", "1.0")
self.message.set_payload(message, 'utf-8')
self.multipart = False
def new_message(self):
from email.Message import Message
return Message()
def attach(self, filename, content, content_type=None):
if not self.multipart:
msg = self.new_message()
msg.add_header("Content-Type", "multipart/mixed")
msg.attach(self.message)
self.message = msg
self.multipart = True
import mimetypes
try:
from email import encoders
except:
from email import Encoders as encoders
content_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
msg = self.new_message()
msg.set_payload(content)
msg.add_header('Content-Type', content_type)
msg.add_header('Content-Disposition', 'attachment', filename=filename)
if not content_type.startswith("text/"):
encoders.encode_base64(msg)
self.message.attach(msg)
def prepare_message(self):
for k, v in self.headers.iteritems():
if k.lower() == "content-type":
self.message.set_type(v)
else:
self.message.add_header(k, v)
self.headers = {}
def send(self):
try:
import webapi
except ImportError:
webapi = Storage(config=Storage())
self.prepare_message()
message_text = self.message.as_string()
if webapi.config.get('smtp_server'):
server = webapi.config.get('smtp_server')
port = webapi.config.get('smtp_port', 0)
username = webapi.config.get('smtp_username')
password = webapi.config.get('smtp_password')
debug_level = webapi.config.get('smtp_debuglevel', None)
starttls = webapi.config.get('smtp_starttls', False)
import smtplib
smtpserver = smtplib.SMTP(server, port)
if debug_level:
smtpserver.set_debuglevel(debug_level)
if starttls:
smtpserver.ehlo()
smtpserver.starttls()
smtpserver.ehlo()
if username and password:
smtpserver.login(username, password)
smtpserver.sendmail(self.from_address, self.recipients, message_text)
smtpserver.quit()
elif webapi.config.get('email_engine') == 'aws':
import boto.ses
c = boto.ses.SESConnection(
aws_access_key_id=webapi.config.get('aws_access_key_id'),
aws_secret_access_key=webapi.config.get('aws_secret_access_key'))
c.send_raw_email(message_text, self.from_address, self.recipients)
else:
sendmail = webapi.config.get('sendmail_path', '/usr/sbin/sendmail')
assert not self.from_address.startswith('-'), 'security'
for r in self.recipients:
assert not r.startswith('-'), 'security'
cmd = [sendmail, '-f', self.from_address] + self.recipients
if subprocess:
p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
p.stdin.write(message_text)
p.stdin.close()
p.wait()
else:
i, o = os.popen2(cmd)
i.write(message_text)
i.close()
o.close()
del i, o
def __repr__(self):
return ""
def __str__(self):
return self.message.as_string()
if __name__ == "__main__":
import doctest
doctest.testmod()
webpy/web/httpserver.py 0000644 0001750 0001750 00000026337 11773133455 013721 0 ustar wmb wmb __all__ = ["runsimple"]
import sys, os
from SimpleHTTPServer import SimpleHTTPRequestHandler
import urllib
import posixpath
import webapi as web
import net
import utils
def runbasic(func, server_address=("0.0.0.0", 8080)):
"""
Runs a simple HTTP server hosting WSGI app `func`. The directory `static/`
is hosted statically.
Based on [WsgiServer][ws] from [Colin Stewart][cs].
[ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html
[cs]: http://www.owlfish.com/
"""
# Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/)
# Modified somewhat for simplicity
# Used under the modified BSD license:
# http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5
import SimpleHTTPServer, SocketServer, BaseHTTPServer, urlparse
import socket, errno
import traceback
class WSGIHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
def run_wsgi_app(self):
protocol, host, path, parameters, query, fragment = \
urlparse.urlparse('http://dummyhost%s' % self.path)
# we only use path, query
env = {'wsgi.version': (1, 0)
,'wsgi.url_scheme': 'http'
,'wsgi.input': self.rfile
,'wsgi.errors': sys.stderr
,'wsgi.multithread': 1
,'wsgi.multiprocess': 0
,'wsgi.run_once': 0
,'REQUEST_METHOD': self.command
,'REQUEST_URI': self.path
,'PATH_INFO': path
,'QUERY_STRING': query
,'CONTENT_TYPE': self.headers.get('Content-Type', '')
,'CONTENT_LENGTH': self.headers.get('Content-Length', '')
,'REMOTE_ADDR': self.client_address[0]
,'SERVER_NAME': self.server.server_address[0]
,'SERVER_PORT': str(self.server.server_address[1])
,'SERVER_PROTOCOL': self.request_version
}
for http_header, http_value in self.headers.items():
env ['HTTP_%s' % http_header.replace('-', '_').upper()] = \
http_value
# Setup the state
self.wsgi_sent_headers = 0
self.wsgi_headers = []
try:
# We have the environment, now invoke the application
result = self.server.app(env, self.wsgi_start_response)
try:
try:
for data in result:
if data:
self.wsgi_write_data(data)
finally:
if hasattr(result, 'close'):
result.close()
except socket.error, socket_err:
# Catch common network errors and suppress them
if (socket_err.args[0] in \
(errno.ECONNABORTED, errno.EPIPE)):
return
except socket.timeout, socket_timeout:
return
except:
print >> web.debug, traceback.format_exc(),
if (not self.wsgi_sent_headers):
# We must write out something!
self.wsgi_write_data(" ")
return
do_POST = run_wsgi_app
do_PUT = run_wsgi_app
do_DELETE = run_wsgi_app
def do_GET(self):
if self.path.startswith('/static/'):
SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
else:
self.run_wsgi_app()
def wsgi_start_response(self, response_status, response_headers,
exc_info=None):
if (self.wsgi_sent_headers):
raise Exception \
("Headers already sent and start_response called again!")
# Should really take a copy to avoid changes in the application....
self.wsgi_headers = (response_status, response_headers)
return self.wsgi_write_data
def wsgi_write_data(self, data):
if (not self.wsgi_sent_headers):
status, headers = self.wsgi_headers
# Need to send header prior to data
status_code = status[:status.find(' ')]
status_msg = status[status.find(' ') + 1:]
self.send_response(int(status_code), status_msg)
for header, value in headers:
self.send_header(header, value)
self.end_headers()
self.wsgi_sent_headers = 1
# Send the data
self.wfile.write(data)
class WSGIServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
def __init__(self, func, server_address):
BaseHTTPServer.HTTPServer.__init__(self,
server_address,
WSGIHandler)
self.app = func
self.serverShuttingDown = 0
print "http://%s:%d/" % server_address
WSGIServer(func, server_address).serve_forever()
# The WSGIServer instance.
# Made global so that it can be stopped in embedded mode.
server = None
def runsimple(func, server_address=("0.0.0.0", 8080)):
"""
Runs [CherryPy][cp] WSGI server hosting WSGI app `func`.
The directory `static/` is hosted statically.
[cp]: http://www.cherrypy.org
"""
global server
func = StaticMiddleware(func)
func = LogMiddleware(func)
server = WSGIServer(server_address, func)
if server.ssl_adapter:
print "https://%s:%d/" % server_address
else:
print "http://%s:%d/" % server_address
try:
server.start()
except (KeyboardInterrupt, SystemExit):
server.stop()
server = None
def WSGIServer(server_address, wsgi_app):
"""Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`.
This function can be overwritten to customize the webserver or use a different webserver.
"""
import wsgiserver
# Default values of wsgiserver.ssl_adapters use the cherrypy.wsgiserver
# prefix. Overwriting it makes it work with web.wsgiserver.
wsgiserver.ssl_adapters = {
'builtin': 'web.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
'pyopenssl': 'web.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
}
server = wsgiserver.CherryPyWSGIServer(server_address, wsgi_app, server_name="localhost")
def create_ssl_adapter(cert, key):
# wsgiserver tries to import submodules as cherrypy.wsgiserver.foo.
# That doesn't work as now it is web.wsgiserver.
# Patching sys.modules temporarily to make it work.
import types
cherrypy = types.ModuleType('cherrypy')
cherrypy.wsgiserver = wsgiserver
sys.modules['cherrypy'] = cherrypy
sys.modules['cherrypy.wsgiserver'] = wsgiserver
from wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
adapter = pyOpenSSLAdapter(cert, key)
# We are done with our work. Cleanup the patches.
del sys.modules['cherrypy']
del sys.modules['cherrypy.wsgiserver']
return adapter
# SSL backward compatibility
if (server.ssl_adapter is None and
getattr(server, 'ssl_certificate', None) and
getattr(server, 'ssl_private_key', None)):
server.ssl_adapter = create_ssl_adapter(server.ssl_certificate, server.ssl_private_key)
server.nodelay = not sys.platform.startswith('java') # TCP_NODELAY isn't supported on the JVM
return server
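# Customization sketch (certificate paths are placeholders): the common recipe
# for HTTPS with the bundled CherryPy server is to set the certificate on the
# server class before calling runsimple(); the backward-compatibility block in
# WSGIServer() above then builds the SSL adapter.
def _example_enable_ssl():
    from web.wsgiserver import CherryPyWSGIServer
    CherryPyWSGIServer.ssl_certificate = 'ssl/server.crt'
    CherryPyWSGIServer.ssl_private_key = 'ssl/server.key'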
class StaticApp(SimpleHTTPRequestHandler):
"""WSGI application for serving static files."""
def __init__(self, environ, start_response):
self.headers = []
self.environ = environ
self.start_response = start_response
def send_response(self, status, msg=""):
self.status = str(status) + " " + msg
def send_header(self, name, value):
self.headers.append((name, value))
def end_headers(self):
pass
def log_message(*a): pass
def __iter__(self):
environ = self.environ
self.path = environ.get('PATH_INFO', '')
self.client_address = environ.get('REMOTE_ADDR','-'), \
environ.get('REMOTE_PORT','-')
self.command = environ.get('REQUEST_METHOD', '-')
from cStringIO import StringIO
self.wfile = StringIO() # for capturing error
try:
path = self.translate_path(self.path)
etag = '"%s"' % os.path.getmtime(path)
client_etag = environ.get('HTTP_IF_NONE_MATCH')
self.send_header('ETag', etag)
if etag == client_etag:
self.send_response(304, "Not Modified")
self.start_response(self.status, self.headers)
raise StopIteration
except OSError:
pass # Probably a 404
f = self.send_head()
self.start_response(self.status, self.headers)
if f:
block_size = 16 * 1024
while True:
buf = f.read(block_size)
if not buf:
break
yield buf
f.close()
else:
value = self.wfile.getvalue()
yield value
class StaticMiddleware:
"""WSGI middleware for serving static files."""
def __init__(self, app, prefix='/static/'):
self.app = app
self.prefix = prefix
def __call__(self, environ, start_response):
path = environ.get('PATH_INFO', '')
path = self.normpath(path)
if path.startswith(self.prefix):
return StaticApp(environ, start_response)
else:
return self.app(environ, start_response)
def normpath(self, path):
path2 = posixpath.normpath(urllib.unquote(path))
if path.endswith("/"):
path2 += "/"
return path2
class LogMiddleware:
"""WSGI middleware for logging the status."""
def __init__(self, app):
self.app = app
self.format = '%s - - [%s] "%s %s %s" - %s'
from BaseHTTPServer import BaseHTTPRequestHandler
import StringIO
f = StringIO.StringIO()
class FakeSocket:
def makefile(self, *a):
return f
# take log_date_time_string method from BaseHTTPRequestHandler
self.log_date_time_string = BaseHTTPRequestHandler(FakeSocket(), None, None).log_date_time_string
def __call__(self, environ, start_response):
def xstart_response(status, response_headers, *args):
out = start_response(status, response_headers, *args)
self.log(status, environ)
return out
return self.app(environ, xstart_response)
def log(self, status, environ):
outfile = environ.get('wsgi.errors', web.debug)
req = environ.get('PATH_INFO', '_')
protocol = environ.get('ACTUAL_SERVER_PROTOCOL', '-')
method = environ.get('REQUEST_METHOD', '-')
host = "%s:%s" % (environ.get('REMOTE_ADDR','-'),
environ.get('REMOTE_PORT','-'))
time = self.log_date_time_string()
msg = self.format % (host, time, protocol, method, req, status)
print >> outfile, utils.safestr(msg)
webpy/web/wsgi.py 0000644 0001750 0001750 00000004226 11773133455 012455 0 ustar wmb wmb """
WSGI Utilities
(from web.py)
"""
import os, sys
import http
import webapi as web
from utils import listget
from net import validaddr, validip
import httpserver
def runfcgi(func, addr=('localhost', 8000)):
"""Runs a WSGI function as a FastCGI server."""
import flup.server.fcgi as flups
return flups.WSGIServer(func, multiplexed=True, bindAddress=addr, debug=False).run()
def runscgi(func, addr=('localhost', 4000)):
"""Runs a WSGI function as an SCGI server."""
import flup.server.scgi as flups
return flups.WSGIServer(func, bindAddress=addr, debug=False).run()
def runwsgi(func):
"""
Runs a WSGI-compatible `func` using FCGI, SCGI, or a simple web server,
as appropriate based on context and `sys.argv`.
"""
if os.environ.has_key('SERVER_SOFTWARE'): # cgi
os.environ['FCGI_FORCE_CGI'] = 'Y'
if (os.environ.has_key('PHP_FCGI_CHILDREN') #lighttpd fastcgi
or os.environ.has_key('SERVER_SOFTWARE')):
return runfcgi(func, None)
if 'fcgi' in sys.argv or 'fastcgi' in sys.argv:
args = sys.argv[1:]
if 'fastcgi' in args: args.remove('fastcgi')
elif 'fcgi' in args: args.remove('fcgi')
if args:
return runfcgi(func, validaddr(args[0]))
else:
return runfcgi(func, None)
if 'scgi' in sys.argv:
args = sys.argv[1:]
args.remove('scgi')
if args:
return runscgi(func, validaddr(args[0]))
else:
return runscgi(func)
return httpserver.runsimple(func, validip(listget(sys.argv, 1, '')))
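# Illustrative note (not part of the original file): runwsgi is typically
# invoked for you by web.application.run(), but the dispatch above maps
# directly to the command line. Assuming an application script named app.py:
#
#     python app.py                       # built-in HTTP server on the default address
#     python app.py 0.0.0.0:9000          # built-in HTTP server on port 9000
#     python app.py fastcgi               # FastCGI via flup
#     python app.py scgi 127.0.0.1:4000   # SCGI via flup on the given address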
def _is_dev_mode():
# Some embedded python interpreters won't have sys.argv
# For details, see https://github.com/webpy/webpy/issues/87
argv = getattr(sys, "argv", [])
# quick hack to check if the program is running in dev mode.
if os.environ.has_key('SERVER_SOFTWARE') \
or os.environ.has_key('PHP_FCGI_CHILDREN') \
or 'fcgi' in argv or 'fastcgi' in argv \
or 'mod_wsgi' in argv:
return False
return True
# When running the built-in server, enable debug mode if not already set.
web.config.setdefault('debug', _is_dev_mode())
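# Because setdefault is used above, an application can still force the value
# explicitly regardless of how it is served, e.g. (illustrative):
#
#     import web
#     web.config.debug = False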
webpy/web/template.py

"""
simple, elegant templating
(part of web.py)
Template design:
Template string is split into tokens and the tokens are combined into nodes.
The parse tree is a nodelist. TextNode and ExpressionNode are simple nodes, while
for, if etc. are block nodes, which contain multiple child nodes.
Each node can emit a python string. The python string emitted by the
root node is validated for safe evaluation and executed using python in the given environment.
Care is taken to make sure the generated code and the template have a line-to-line match,
so that error messages can point to the exact line number in the template. (It still doesn't work in some cases.)
Grammar:
template -> defwith sections
defwith -> '$def with (' arguments ')' | ''
sections -> section*
section -> block | assignment | line
assignment -> '$ ' <assignment expression>
line -> (text|expr)*
text -> <any characters except $>
expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}'
pyexpr -> <python expression>
"""
__all__ = [
"Template",
"Render", "render", "frender",
"ParseError", "SecurityError",
"test"
]
import tokenize
import os
import sys
import glob
import re
from UserDict import DictMixin
import warnings
from utils import storage, safeunicode, safestr, re_compile
from webapi import config
from net import websafe
def splitline(text):
r"""
Splits the given text at newline.
>>> splitline('foo\nbar')
('foo\n', 'bar')
>>> splitline('foo')
('foo', '')
>>> splitline('')
('', '')
"""
index = text.find('\n') + 1
if index:
return text[:index], text[index:]
else:
return text, ''
class Parser:
"""Parser Base.
"""
def __init__(self):
self.statement_nodes = STATEMENT_NODES
self.keywords = KEYWORDS
def parse(self, text, name=""):
self.text = text
self.name = name
defwith, text = self.read_defwith(text)
suite = self.read_suite(text)
return DefwithNode(defwith, suite)
def read_defwith(self, text):
if text.startswith('$def with'):
defwith, text = splitline(text)
defwith = defwith[1:].strip() # strip $ and spaces
return defwith, text
else:
return '', text
def read_section(self, text):
r"""Reads one section from the given text.
section -> block | assignment | line
>>> read_section = Parser().read_section
>>> read_section('foo\nbar\n')
(<line: [t'foo\n']>, 'bar\n')
>>> read_section('$ a = b + 1\nfoo\n')
(<assignment: 'a = b + 1'>, 'foo\n')
read_section('$for i in range(10):\n hello $i\nfoo')
"""
if text.lstrip(' ').startswith('$'):
index = text.index('$')
begin_indent, text2 = text[:index], text[index+1:]
ahead = self.python_lookahead(text2)
if ahead == 'var':
return self.read_var(text2)
elif ahead in self.statement_nodes:
return self.read_block_section(text2, begin_indent)
elif ahead in self.keywords:
return self.read_keyword(text2)
elif ahead.strip() == '':
# assignments start with a space after $
# ex: $ a = b + 2
return self.read_assignment(text2)
return self.readline(text)
def read_var(self, text):
r"""Reads a var statement.
>>> read_var = Parser().read_var
>>> read_var('var x=10\nfoo')
(<var: x = 10>, 'foo')
>>> read_var('var x: hello $name\nfoo')
(<var: x = join_(u'hello ', escape_(name, True))>, 'foo')
"""
line, text = splitline(text)
tokens = self.python_tokens(line)
if len(tokens) < 4:
raise SyntaxError('Invalid var statement')
name = tokens[1]
sep = tokens[2]
value = line.split(sep, 1)[1].strip()
if sep == '=':
pass # no need to process value
elif sep == ':':
#@@ Hack for backward-compatibility
if tokens[3] == '\n': # multi-line var statement
block, text = self.read_indented_block(text, ' ')
lines = [self.readline(x)[0] for x in block.splitlines()]
nodes = []
for x in lines:
nodes.extend(x.nodes)
nodes.append(TextNode('\n'))
else: # single-line var statement
linenode, _ = self.readline(value)
nodes = linenode.nodes
parts = [node.emit('') for node in nodes]
value = "join_(%s)" % ", ".join(parts)
else:
raise SyntaxError('Invalid var statement')
return VarNode(name, value), text
def read_suite(self, text):
r"""Reads section by section till end of text.
>>> read_suite = Parser().read_suite
>>> read_suite('hello $name\nfoo\n')
[<line: [t'hello ', $name, t'\n']>, <line: [t'foo\n']>]
"""
sections = []
while text:
section, text = self.read_section(text)
sections.append(section)
return SuiteNode(sections)
def readline(self, text):
r"""Reads one line from the text. Newline is supressed if the line ends with \.
>>> readline = Parser().readline
>>> readline('hello $name!\nbye!')
(, 'bye!')
>>> readline('hello $name!\\\nbye!')
(, 'bye!')
>>> readline('$f()\n\n')
(, '\n')
"""
line, text = splitline(text)
# suppress new line if line ends with \
if line.endswith('\\\n'):
line = line[:-2]
nodes = []
while line:
node, line = self.read_node(line)
nodes.append(node)
return LineNode(nodes), text
def read_node(self, text):
r"""Reads a node from the given text and returns the node and remaining text.
>>> read_node = Parser().read_node
>>> read_node('hello $name')
(t'hello ', '$name')
>>> read_node('$name')
($name, '')
"""
if text.startswith('$$'):
return TextNode('$'), text[2:]
elif text.startswith('$#'): # comment
line, text = splitline(text)
return TextNode('\n'), text
elif text.startswith('$'):
text = text[1:] # strip $
if text.startswith(':'):
escape = False
text = text[1:] # strip :
else:
escape = True
return self.read_expr(text, escape=escape)
else:
return self.read_text(text)
def read_text(self, text):
r"""Reads a text node from the given text.
>>> read_text = Parser().read_text
>>> read_text('hello $name')
(t'hello ', '$name')
"""
index = text.find('$')
if index < 0:
return TextNode(text), ''
else:
return TextNode(text[:index]), text[index:]
def read_keyword(self, text):
line, text = splitline(text)
return StatementNode(line.strip() + "\n"), text
def read_expr(self, text, escape=True):
"""Reads a python expression from the text and returns the expression and remaining text.
expr -> simple_expr | paren_expr
simple_expr -> id extended_expr
extended_expr -> attr_access | paren_expr extended_expr | ''
attr_access -> dot id extended_expr
paren_expr -> [ tokens ] | ( tokens ) | { tokens }
>>> read_expr = Parser().read_expr
>>> read_expr("name")
($name, '')
>>> read_expr("a.b and c")
($a.b, ' and c')
>>> read_expr("a. b")
($a, '. b')
>>> read_expr("name")
($name, '')
>>> read_expr("(limit)ing")
($(limit), 'ing')
>>> read_expr('a[1, 2][:3].f(1+2, "weird string[).", 3 + 4) done.')
($a[1, 2][:3].f(1+2, "weird string[).", 3 + 4), ' done.')
"""
def simple_expr():
identifier()
extended_expr()
def identifier():
tokens.next()
def extended_expr():
lookahead = tokens.lookahead()
if lookahead is None:
return
elif lookahead.value == '.':
attr_access()
elif lookahead.value in parens:
paren_expr()
extended_expr()
else:
return
def attr_access():
from token import NAME # python token constants
dot = tokens.lookahead()
if tokens.lookahead2().type == NAME:
tokens.next() # consume dot
identifier()
extended_expr()
def paren_expr():
begin = tokens.next().value
end = parens[begin]
while True:
if tokens.lookahead().value in parens:
paren_expr()
else:
t = tokens.next()
if t.value == end:
break
return
parens = {
"(": ")",
"[": "]",
"{": "}"
}
def get_tokens(text):
"""tokenize text using python tokenizer.
Python tokenizer ignores spaces, but they might be important in some cases.
This function introduces dummy space tokens when it identifies any ignored space.
Each token is a storage object containing type, value, begin and end.
"""
readline = iter([text]).next
end = None
for t in tokenize.generate_tokens(readline):
t = storage(type=t[0], value=t[1], begin=t[2], end=t[3])
if end is not None and end != t.begin:
_, x1 = end
_, x2 = t.begin
yield storage(type=-1, value=text[x1:x2], begin=end, end=t.begin)
end = t.end
yield t
class BetterIter:
"""Iterator like object with 2 support for 2 look aheads."""
def __init__(self, items):
self.iteritems = iter(items)
self.items = []
self.position = 0
self.current_item = None
def lookahead(self):
if len(self.items) <= self.position:
self.items.append(self._next())
return self.items[self.position]
def _next(self):
try:
return self.iteritems.next()
except StopIteration:
return None
def lookahead2(self):
if len(self.items) <= self.position+1:
self.items.append(self._next())
return self.items[self.position+1]
def next(self):
self.current_item = self.lookahead()
self.position += 1
return self.current_item
tokens = BetterIter(get_tokens(text))
if tokens.lookahead().value in parens:
paren_expr()
else:
simple_expr()
row, col = tokens.current_item.end
return ExpressionNode(text[:col], escape=escape), text[col:]
def read_assignment(self, text):
r"""Reads assignment statement from text.
>>> read_assignment = Parser().read_assignment
>>> read_assignment('a = b + 1\nfoo')
(<assignment: 'a = b + 1'>, 'foo')
"""
line, text = splitline(text)
return AssignmentNode(line.strip()), text
def python_lookahead(self, text):
"""Returns the first python token from the given text.
>>> python_lookahead = Parser().python_lookahead
>>> python_lookahead('for i in range(10):')
'for'
>>> python_lookahead('else:')
'else'
>>> python_lookahead(' x = 1')
' '
"""
readline = iter([text]).next
tokens = tokenize.generate_tokens(readline)
return tokens.next()[1]
def python_tokens(self, text):
readline = iter([text]).next
tokens = tokenize.generate_tokens(readline)
return [t[1] for t in tokens]
def read_indented_block(self, text, indent):
r"""Read a block of text. A block is what typically follows a for or it statement.
It can be in the same line as that of the statement or an indented block.
>>> read_indented_block = Parser().read_indented_block
>>> read_indented_block(' a\n b\nc', ' ')
('a\nb\n', 'c')
>>> read_indented_block(' a\n b\n c\nd', ' ')
('a\n b\nc\n', 'd')
>>> read_indented_block(' a\n\n b\nc', ' ')
('a\n\n b\n', 'c')
"""
if indent == '':
return '', text
block = ""
while text:
line, text2 = splitline(text)
if line.strip() == "":
block += '\n'
elif line.startswith(indent):
block += line[len(indent):]
else:
break
text = text2
return block, text
def read_statement(self, text):
r"""Reads a python statement.
>>> read_statement = Parser().read_statement
>>> read_statement('for i in range(10): hello $name')
('for i in range(10):', ' hello $name')
"""
tok = PythonTokenizer(text)
tok.consume_till(':')
return text[:tok.index], text[tok.index:]
def read_block_section(self, text, begin_indent=''):
r"""
>>> read_block_section = Parser().read_block_section
>>> read_block_section('for i in range(10): hello $i\nfoo')
(<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
>>> read_block_section('for i in range(10):\n hello $i\n    foo', begin_indent='    ')
(<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, '    foo')
>>> read_block_section('for i in range(10):\n  hello $i\nfoo')
(<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
"""
line, text = splitline(text)
stmt, line = self.read_statement(line)
keyword = self.python_lookahead(stmt)
# if there is something left in the line
if line.strip():
block = line.lstrip()
else:
def find_indent(text):
rx = re_compile(' +')
match = rx.match(text)
first_indent = match and match.group(0)
return first_indent or ""
# find the indentation of the block by looking at the first line
first_indent = find_indent(text)[len(begin_indent):]
#TODO: fix this special case
if keyword == "code":
indent = begin_indent + first_indent
else:
indent = begin_indent + min(first_indent, INDENT)
block, text = self.read_indented_block(text, indent)
return self.create_block_node(keyword, stmt, block, begin_indent), text
def create_block_node(self, keyword, stmt, block, begin_indent):
if keyword in self.statement_nodes:
return self.statement_nodes[keyword](stmt, block, begin_indent)
else:
raise ParseError, 'Unknown statement: %s' % repr(keyword)
class PythonTokenizer:
"""Utility wrapper over python tokenizer."""
def __init__(self, text):
self.text = text
readline = iter([text]).next
self.tokens = tokenize.generate_tokens(readline)
self.index = 0
def consume_till(self, delim):
"""Consumes tokens till colon.
>>> tok = PythonTokenizer('for i in range(10): hello $i')
>>> tok.consume_till(':')
>>> tok.text[:tok.index]
'for i in range(10):'
>>> tok.text[tok.index:]
' hello $i'
"""
try:
while True:
t = self.next()
if t.value == delim:
break
elif t.value == '(':
self.consume_till(')')
elif t.value == '[':
self.consume_till(']')
elif t.value == '{':
self.consume_till('}')
# if end of line is found, it is an exception.
# Since there is no easy way to report the line number,
# leave the error reporting to the python parser later
#@@ This should be fixed.
if t.value == '\n':
break
except:
#raise ParseError, "Expected %s, found end of line." % repr(delim)
# raising ParseError doesn't show the line number.
# if this error is ignored, then it will be caught when compiling the python code.
return
def next(self):
type, t, begin, end, line = self.tokens.next()
row, col = end
self.index = col
return storage(type=type, value=t, begin=begin, end=end)
class DefwithNode:
def __init__(self, defwith, suite):
if defwith:
self.defwith = defwith.replace('with', '__template__') + ':'
# offset 4 lines. for encoding, __lineoffset__, loop and self.
self.defwith += "\n __lineoffset__ = -4"
else:
self.defwith = 'def __template__():'
# offset 4 lines for encoding, __template__, __lineoffset__, loop and self.
self.defwith += "\n __lineoffset__ = -5"
self.defwith += "\n loop = ForLoop()"
self.defwith += "\n self = TemplateResult(); extend_ = self.extend"
self.suite = suite
self.end = "\n return self"
def emit(self, indent):
encoding = "# coding: utf-8\n"
return encoding + self.defwith + self.suite.emit(indent + INDENT) + self.end
def __repr__(self):
return "" % (self.defwith, self.suite)
class TextNode:
def __init__(self, value):
self.value = value
def emit(self, indent, begin_indent=''):
return repr(safeunicode(self.value))
def __repr__(self):
return 't' + repr(self.value)
class ExpressionNode:
def __init__(self, value, escape=True):
self.value = value.strip()
# convert ${...} to $(...)
if value.startswith('{') and value.endswith('}'):
self.value = '(' + self.value[1:-1] + ')'
self.escape = escape
def emit(self, indent, begin_indent=''):
return 'escape_(%s, %s)' % (self.value, bool(self.escape))
def __repr__(self):
if self.escape:
escape = ''
else:
escape = ':'
return "$%s%s" % (escape, self.value)
class AssignmentNode:
def __init__(self, code):
self.code = code
def emit(self, indent, begin_indent=''):
return indent + self.code + "\n"
def __repr__(self):
return "" % repr(self.code)
class LineNode:
def __init__(self, nodes):
self.nodes = nodes
def emit(self, indent, text_indent='', name=''):
text = [node.emit('') for node in self.nodes]
if text_indent:
text = [repr(text_indent)] + text
return indent + "extend_([%s])\n" % ", ".join(text)
def __repr__(self):
return "" % repr(self.nodes)
INDENT = ' ' # 4 spaces
class BlockNode:
def __init__(self, stmt, block, begin_indent=''):
self.stmt = stmt
self.suite = Parser().read_suite(block)
self.begin_indent = begin_indent
def emit(self, indent, text_indent=''):
text_indent = self.begin_indent + text_indent
out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
return out
def __repr__(self):
return "" % (repr(self.stmt), repr(self.suite))
class ForNode(BlockNode):
def __init__(self, stmt, block, begin_indent=''):
self.original_stmt = stmt
tok = PythonTokenizer(stmt)
tok.consume_till('in')
a = stmt[:tok.index] # for i in
b = stmt[tok.index:-1] # rest of for stmt excluding :
stmt = a + ' loop.setup(' + b.strip() + '):'
BlockNode.__init__(self, stmt, block, begin_indent)
def __repr__(self):
return "" % (repr(self.original_stmt), repr(self.suite))
class CodeNode:
def __init__(self, stmt, block, begin_indent=''):
# compensate one line for $code:
self.code = "\n" + block
def emit(self, indent, text_indent=''):
import re
rx = re.compile('^', re.M)
return rx.sub(indent, self.code).rstrip(' ')
def __repr__(self):
return "" % repr(self.code)
class StatementNode:
def __init__(self, stmt):
self.stmt = stmt
def emit(self, indent, begin_indent=''):
return indent + self.stmt
def __repr__(self):
return "" % repr(self.stmt)
class IfNode(BlockNode):
pass
class ElseNode(BlockNode):
pass
class ElifNode(BlockNode):
pass
class DefNode(BlockNode):
def __init__(self, *a, **kw):
BlockNode.__init__(self, *a, **kw)
code = CodeNode("", "")
code.code = "self = TemplateResult(); extend_ = self.extend\n"
self.suite.sections.insert(0, code)
code = CodeNode("", "")
code.code = "return self\n"
self.suite.sections.append(code)
def emit(self, indent, text_indent=''):
text_indent = self.begin_indent + text_indent
out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
return indent + "__lineoffset__ -= 3\n" + out
class VarNode:
def __init__(self, name, value):
self.name = name
self.value = value
def emit(self, indent, text_indent):
return indent + "self[%s] = %s\n" % (repr(self.name), self.value)
def __repr__(self):
return "" % (self.name, self.value)
class SuiteNode:
"""Suite is a list of sections."""
def __init__(self, sections):
self.sections = sections
def emit(self, indent, text_indent=''):
return "\n" + "".join([s.emit(indent, text_indent) for s in self.sections])
def __repr__(self):
return repr(self.sections)
STATEMENT_NODES = {
'for': ForNode,
'while': BlockNode,
'if': IfNode,
'elif': ElifNode,
'else': ElseNode,
'def': DefNode,
'code': CodeNode
}
KEYWORDS = [
"pass",
"break",
"continue",
"return"
]
TEMPLATE_BUILTIN_NAMES = [
"dict", "enumerate", "float", "int", "bool", "list", "long", "reversed",
"set", "slice", "tuple", "xrange",
"abs", "all", "any", "callable", "chr", "cmp", "divmod", "filter", "hex",
"id", "isinstance", "iter", "len", "max", "min", "oct", "ord", "pow", "range",
"True", "False",
"None",
"__import__", # some c-libraries like datetime requires __import__ to present in the namespace
]
import __builtin__
TEMPLATE_BUILTINS = dict([(name, getattr(__builtin__, name)) for name in TEMPLATE_BUILTIN_NAMES if name in __builtin__.__dict__])
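# Effect of the restricted builtins (mirrors the doctests in test() below):
# only the names listed above are visible inside a template unless overridden.
#
#     Template('$min(1, 2)')()                # -> u'1\n'
#     Template('$min(1, 2)', builtins={})()   # raises NameError: 'min' is not defined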
class ForLoop:
"""
Wrapper for the expression in a for statement to support loop.xxx helpers.
>>> loop = ForLoop()
>>> for x in loop.setup(['a', 'b', 'c']):
... print loop.index, loop.revindex, loop.parity, x
...
1 3 odd a
2 2 even b
3 1 odd c
>>> loop.index
Traceback (most recent call last):
...
AttributeError: index
"""
def __init__(self):
self._ctx = None
def __getattr__(self, name):
if self._ctx is None:
raise AttributeError, name
else:
return getattr(self._ctx, name)
def setup(self, seq):
self._push()
return self._ctx.setup(seq)
def _push(self):
self._ctx = ForLoopContext(self, self._ctx)
def _pop(self):
self._ctx = self._ctx.parent
class ForLoopContext:
"""Stackable context for ForLoop to support nested for loops.
"""
def __init__(self, forloop, parent):
self._forloop = forloop
self.parent = parent
def setup(self, seq):
try:
self.length = len(seq)
except:
self.length = 0
self.index = 0
for a in seq:
self.index += 1
yield a
self._forloop._pop()
index0 = property(lambda self: self.index-1)
first = property(lambda self: self.index == 1)
last = property(lambda self: self.index == self.length)
odd = property(lambda self: self.index % 2 == 1)
even = property(lambda self: self.index % 2 == 0)
parity = property(lambda self: ['odd', 'even'][self.even])
revindex0 = property(lambda self: self.length - self.index)
revindex = property(lambda self: self.length - self.index + 1)
class BaseTemplate:
def __init__(self, code, filename, filter, globals, builtins):
self.filename = filename
self.filter = filter
self._globals = globals
self._builtins = builtins
if code:
self.t = self._compile(code)
else:
self.t = lambda: ''
def _compile(self, code):
env = self.make_env(self._globals or {}, self._builtins)
exec(code, env)
return env['__template__']
def __call__(self, *a, **kw):
__hidetraceback__ = True
return self.t(*a, **kw)
def make_env(self, globals, builtins):
return dict(globals,
__builtins__=builtins,
ForLoop=ForLoop,
TemplateResult=TemplateResult,
escape_=self._escape,
join_=self._join
)
def _join(self, *items):
return u"".join(items)
def _escape(self, value, escape=False):
if value is None:
value = ''
value = safeunicode(value)
if escape and self.filter:
value = self.filter(value)
return value
class Template(BaseTemplate):
CONTENT_TYPES = {
'.html' : 'text/html; charset=utf-8',
'.xhtml' : 'application/xhtml+xml; charset=utf-8',
'.txt' : 'text/plain',
}
FILTERS = {
'.html': websafe,
'.xhtml': websafe,
'.xml': websafe
}
globals = {}
def __init__(self, text, filename='', filter=None, globals=None, builtins=None, extensions=None):
self.extensions = extensions or []
text = Template.normalize_text(text)
code = self.compile_template(text, filename)
_, ext = os.path.splitext(filename)
filter = filter or self.FILTERS.get(ext, None)
self.content_type = self.CONTENT_TYPES.get(ext, None)
if globals is None:
globals = self.globals
if builtins is None:
builtins = TEMPLATE_BUILTINS
BaseTemplate.__init__(self, code=code, filename=filename, filter=filter, globals=globals, builtins=builtins)
def normalize_text(text):
"""Normalizes template text by correcting \r\n, tabs and BOM chars."""
text = text.replace('\r\n', '\n').replace('\r', '\n').expandtabs()
if not text.endswith('\n'):
text += '\n'
# ignore BOM chars at the beginning of template
BOM = '\xef\xbb\xbf'
if isinstance(text, str) and text.startswith(BOM):
text = text[len(BOM):]
# support for \$ for backward-compatibility
text = text.replace(r'\$', '$$')
return text
normalize_text = staticmethod(normalize_text)
def __call__(self, *a, **kw):
__hidetraceback__ = True
import webapi as web
if 'headers' in web.ctx and self.content_type:
web.header('Content-Type', self.content_type, unique=True)
return BaseTemplate.__call__(self, *a, **kw)
def generate_code(text, filename, parser=None):
# parse the text
parser = parser or Parser()
rootnode = parser.parse(text, filename)
# generate python code from the parse tree
code = rootnode.emit(indent="").strip()
return safestr(code)
generate_code = staticmethod(generate_code)
def create_parser(self):
p = Parser()
for ext in self.extensions:
p = ext(p)
return p
def compile_template(self, template_string, filename):
code = Template.generate_code(template_string, filename, parser=self.create_parser())
def get_source_line(filename, lineno):
try:
lines = open(filename).read().splitlines()
return lines[lineno]
except:
return None
try:
# compile the code first to report the errors, if any, with the filename
compiled_code = compile(code, filename, 'exec')
except SyntaxError, e:
# display template line that caused the error along with the traceback.
try:
e.msg += '\n\nTemplate traceback:\n File %s, line %s\n %s' % \
(repr(e.filename), e.lineno, get_source_line(e.filename, e.lineno-1))
except:
pass
raise
# make sure code is safe - but not with jython, it doesn't have a working compiler module
if not sys.platform.startswith('java'):
try:
import compiler
ast = compiler.parse(code)
SafeVisitor().walk(ast, filename)
except ImportError:
warnings.warn("Unabled to import compiler module. Unable to check templates for safety.")
else:
warnings.warn("SECURITY ISSUE: You are using Jython, which does not support checking templates for safety. Your templates can execute arbitrary code.")
return compiled_code
class CompiledTemplate(Template):
def __init__(self, f, filename):
Template.__init__(self, '', filename)
self.t = f
def compile_template(self, *a):
return None
def _compile(self, *a):
return None
class Render:
"""The most preferred way of using templates.
render = web.template.render('templates')
print render.foo()
The optional parameter `base` can be used to pass the output of
every template through the base template.
render = web.template.render('templates', base='layout')
"""
def __init__(self, loc='templates', cache=None, base=None, **keywords):
self._loc = loc
self._keywords = keywords
if cache is None:
cache = not config.get('debug', False)
if cache:
self._cache = {}
else:
self._cache = None
if base and not hasattr(base, '__call__'):
# make base a function, so that it can be passed to sub-renders
self._base = lambda page: self._template(base)(page)
else:
self._base = base
def _add_global(self, obj, name=None):
"""Add a global to this rendering instance."""
if 'globals' not in self._keywords: self._keywords['globals'] = {}
if not name:
name = obj.__name__
self._keywords['globals'][name] = obj
def _lookup(self, name):
path = os.path.join(self._loc, name)
if os.path.isdir(path):
return 'dir', path
else:
path = self._findfile(path)
if path:
return 'file', path
else:
return 'none', None
def _load_template(self, name):
kind, path = self._lookup(name)
if kind == 'dir':
return Render(path, cache=self._cache is not None, base=self._base, **self._keywords)
elif kind == 'file':
return Template(open(path).read(), filename=path, **self._keywords)
else:
raise AttributeError, "No template named " + name
def _findfile(self, path_prefix):
p = [f for f in glob.glob(path_prefix + '.*') if not f.endswith('~')] # skip backup files
p.sort() # sort the matches for deterministic order
return p and p[0]
def _template(self, name):
if self._cache is not None:
if name not in self._cache:
self._cache[name] = self._load_template(name)
return self._cache[name]
else:
return self._load_template(name)
def __getattr__(self, name):
t = self._template(name)
if self._base and isinstance(t, Template):
def template(*a, **kw):
return self._base(t(*a, **kw))
return template
else:
return self._template(name)
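# Illustrative usage of Render (see also the class docstring above); the
# templates/ directory and a layout.html base template are assumptions made
# only for this example:
#
#     render = Render('templates', base='layout')
#     print render.index()    # output of templates/index.* wrapped in the layout
#
# When config.debug is off, loaded templates are cached per Render instance.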
class GAE_Render(Render):
# Render gets over-written. make a copy here.
super = Render
def __init__(self, loc, *a, **kw):
GAE_Render.super.__init__(self, loc, *a, **kw)
import types
if isinstance(loc, types.ModuleType):
self.mod = loc
else:
name = loc.rstrip('/').replace('/', '.')
self.mod = __import__(name, None, None, ['x'])
self.mod.__dict__.update(kw.get('builtins', TEMPLATE_BUILTINS))
self.mod.__dict__.update(Template.globals)
self.mod.__dict__.update(kw.get('globals', {}))
def _load_template(self, name):
t = getattr(self.mod, name)
import types
if isinstance(t, types.ModuleType):
return GAE_Render(t, cache=self._cache is not None, base=self._base, **self._keywords)
else:
return t
render = Render
# setup render for Google App Engine.
try:
from google import appengine
render = Render = GAE_Render
except ImportError:
pass
def frender(path, **keywords):
"""Creates a template from the given file path.
"""
return Template(open(path).read(), filename=path, **keywords)
def compile_templates(root):
"""Compiles templates to python code."""
re_start = re_compile('^', re.M)
for dirpath, dirnames, filenames in os.walk(root):
filenames = [f for f in filenames if not f.startswith('.') and not f.endswith('~') and not f.startswith('__init__.py')]
for d in dirnames[:]:
if d.startswith('.'):
dirnames.remove(d) # don't visit this dir
out = open(os.path.join(dirpath, '__init__.py'), 'w')
out.write('from web.template import CompiledTemplate, ForLoop, TemplateResult\n\n')
if dirnames:
out.write("import " + ", ".join(dirnames))
out.write("\n")
for f in filenames:
path = os.path.join(dirpath, f)
if '.' in f:
name, _ = f.split('.', 1)
else:
name = f
text = open(path).read()
text = Template.normalize_text(text)
code = Template.generate_code(text, path)
code = code.replace("__template__", name, 1)
out.write(code)
out.write('\n\n')
out.write('%s = CompiledTemplate(%s, %s)\n' % (name, name, repr(path)))
out.write("join_ = %s._join; escape_ = %s._escape\n\n" % (name, name))
# create template to make sure it compiles
t = Template(open(path).read(), path)
out.close()
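# compile_templates is wired to this module's __main__ block (see the bottom of
# the file), so a directory of templates can be precompiled with, for example:
#
#     python template.py --compile templates/
#
# Each directory gets an __init__.py containing CompiledTemplate objects.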
class ParseError(Exception):
pass
class SecurityError(Exception):
"""The template seems to be trying to do something naughty."""
pass
# Enumerate all the allowed AST nodes
ALLOWED_AST_NODES = [
"Add", "And",
# "AssAttr",
"AssList", "AssName", "AssTuple",
# "Assert",
"Assign", "AugAssign",
# "Backquote",
"Bitand", "Bitor", "Bitxor", "Break",
"CallFunc","Class", "Compare", "Const", "Continue",
"Decorators", "Dict", "Discard", "Div",
"Ellipsis", "EmptyNode",
# "Exec",
"Expression", "FloorDiv", "For",
# "From",
"Function",
"GenExpr", "GenExprFor", "GenExprIf", "GenExprInner",
"Getattr",
# "Global",
"If", "IfExp",
# "Import",
"Invert", "Keyword", "Lambda", "LeftShift",
"List", "ListComp", "ListCompFor", "ListCompIf", "Mod",
"Module",
"Mul", "Name", "Not", "Or", "Pass", "Power",
# "Print", "Printnl", "Raise",
"Return", "RightShift", "Slice", "Sliceobj",
"Stmt", "Sub", "Subscript",
# "TryExcept", "TryFinally",
"Tuple", "UnaryAdd", "UnarySub",
"While", "With", "Yield",
]
class SafeVisitor(object):
"""
Make sure code is safe by walking through the AST.
Code is considered unsafe if:
* it has restricted AST nodes
* it is trying to access restricted attributes
Adopted from http://www.zafar.se/bkz/uploads/safe.txt (public domain, Babar K. Zafar)
"""
def __init__(self):
"Initialize visitor by generating callbacks for all AST node types."
self.errors = []
def walk(self, ast, filename):
"Validate each node in AST and raise SecurityError if the code is not safe."
self.filename = filename
self.visit(ast)
if self.errors:
raise SecurityError, '\n'.join([str(err) for err in self.errors])
def visit(self, node, *args):
"Recursively validate node and all of its children."
def classname(obj):
return obj.__class__.__name__
nodename = classname(node)
fn = getattr(self, 'visit' + nodename, None)
if fn:
fn(node, *args)
else:
if nodename not in ALLOWED_AST_NODES:
self.fail(node, *args)
for child in node.getChildNodes():
self.visit(child, *args)
def visitName(self, node, *args):
"Disallow any attempts to access a restricted attr."
#self.assert_attr(node.getChildren()[0], node)
pass
def visitGetattr(self, node, *args):
"Disallow any attempts to access a restricted attribute."
self.assert_attr(node.attrname, node)
def assert_attr(self, attrname, node):
if self.is_unallowed_attr(attrname):
lineno = self.get_node_lineno(node)
e = SecurityError("%s:%d - access to attribute '%s' is denied" % (self.filename, lineno, attrname))
self.errors.append(e)
def is_unallowed_attr(self, name):
return name.startswith('_') \
or name.startswith('func_') \
or name.startswith('im_')
def get_node_lineno(self, node):
return (node.lineno) and node.lineno or 0
def fail(self, node, *args):
"Default callback for unallowed AST nodes."
lineno = self.get_node_lineno(node)
nodename = node.__class__.__name__
e = SecurityError("%s:%d - execution of '%s' statements is denied" % (self.filename, lineno, nodename))
self.errors.append(e)
class TemplateResult(object, DictMixin):
"""Dictionary like object for storing template output.
The result of a template execution is usually a string, but sometimes it
contains attributes set using $var. This class provides a simple
dictionary-like interface for storing the output of the template and the
attributes. The output is stored with a special key __body__. Converting
the TemplateResult to string or unicode returns the value of __body__.
When the template is in execution, the output is generated part by part
and those parts are combined at the end. Parts are added to the
TemplateResult by calling the `extend` method and the parts are combined
seamlessly when __body__ is accessed.
>>> d = TemplateResult(__body__='hello, world', x='foo')
>>> d
<TemplateResult: {'__body__': 'hello, world', 'x': 'foo'}>
>>> print d
hello, world
>>> d.x
'foo'
>>> d = TemplateResult()
>>> d.extend([u'hello', u'world'])
>>> d
<TemplateResult: {'__body__': u'helloworld'}>
"""
def __init__(self, *a, **kw):
self.__dict__["_d"] = dict(*a, **kw)
self._d.setdefault("__body__", u'')
self.__dict__['_parts'] = []
self.__dict__["extend"] = self._parts.extend
self._d.setdefault("__body__", None)
def keys(self):
return self._d.keys()
def _prepare_body(self):
"""Prepare value of __body__ by joining parts.
"""
if self._parts:
value = u"".join(self._parts)
self._parts[:] = []
body = self._d.get('__body__')
if body:
self._d['__body__'] = body + value
else:
self._d['__body__'] = value
def __getitem__(self, name):
if name == "__body__":
self._prepare_body()
return self._d[name]
def __setitem__(self, name, value):
if name == "__body__":
self._prepare_body()
return self._d.__setitem__(name, value)
def __delitem__(self, name):
if name == "__body__":
self._prepare_body()
return self._d.__delitem__(name)
def __getattr__(self, key):
try:
return self[key]
except KeyError, k:
raise AttributeError, k
def __setattr__(self, key, value):
self[key] = value
def __delattr__(self, key):
try:
del self[key]
except KeyError, k:
raise AttributeError, k
def __unicode__(self):
self._prepare_body()
return self["__body__"]
def __str__(self):
self._prepare_body()
return self["__body__"].encode('utf-8')
def __repr__(self):
self._prepare_body()
return "" % self._d
def test():
r"""Doctest for testing template module.
Define a utility function to run template test.
>>> class TestResult:
... def __init__(self, t): self.t = t
... def __getattr__(self, name): return getattr(self.t, name)
... def __repr__(self): return repr(unicode(self))
...
>>> def t(code, **keywords):
... tmpl = Template(code, **keywords)
... return lambda *a, **kw: TestResult(tmpl(*a, **kw))
...
Simple tests.
>>> t('1')()
u'1\n'
>>> t('$def with ()\n1')()
u'1\n'
>>> t('$def with (a)\n$a')(1)
u'1\n'
>>> t('$def with (a=0)\n$a')(1)
u'1\n'
>>> t('$def with (a=0)\n$a')(a=1)
u'1\n'
Test complicated expressions.
>>> t('$def with (x)\n$x.upper()')('hello')
u'HELLO\n'
>>> t('$(2 * 3 + 4 * 5)')()
u'26\n'
>>> t('${2 * 3 + 4 * 5}')()
u'26\n'
>>> t('$def with (limit)\nkeep $(limit)ing.')('go')
u'keep going.\n'
>>> t('$def with (a)\n$a.b[0]')(storage(b=[1]))
u'1\n'
Test html escaping.
>>> t('$def with (x)\n$x', filename='a.html')('<html>')
u'&lt;html&gt;\n'
>>> t('$def with (x)\n$x', filename='a.txt')('<html>')
u'<html>\n'
Test if, for and while.
>>> t('$if 1: 1')()
u'1\n'
>>> t('$if 1:\n 1')()
u'1\n'
>>> t('$if 1:\n 1\\')()
u'1'
>>> t('$if 0: 0\n$elif 1: 1')()
u'1\n'
>>> t('$if 0: 0\n$elif None: 0\n$else: 1')()
u'1\n'
>>> t('$if 0 < 1 and 1 < 2: 1')()
u'1\n'
>>> t('$for x in [1, 2, 3]: $x')()
u'1\n2\n3\n'
>>> t('$def with (d)\n$for k, v in d.iteritems(): $k')({1: 1})
u'1\n'
>>> t('$for x in [1, 2, 3]:\n\t$x')()
u' 1\n 2\n 3\n'
>>> t('$def with (a)\n$while a and a.pop():1')([1, 2, 3])
u'1\n1\n1\n'
The space after : must be ignored.
>>> t('$if True: foo')()
u'foo\n'
Test loop.xxx.
>>> t("$for i in range(5):$loop.index, $loop.parity")()
u'1, odd\n2, even\n3, odd\n4, even\n5, odd\n'
>>> t("$for i in range(2):\n $for j in range(2):$loop.parent.parity $loop.parity")()
u'odd odd\nodd even\neven odd\neven even\n'
Test assignment.
>>> t('$ a = 1\n$a')()
u'1\n'
>>> t('$ a = [1]\n$a[0]')()
u'1\n'
>>> t('$ a = {1: 1}\n$a.keys()[0]')()
u'1\n'
>>> t('$ a = []\n$if not a: 1')()
u'1\n'
>>> t('$ a = {}\n$if not a: 1')()
u'1\n'
>>> t('$ a = -1\n$a')()
u'-1\n'
>>> t('$ a = "1"\n$a')()
u'1\n'
Test comments.
>>> t('$# 0')()
u'\n'
>>> t('hello$#comment1\nhello$#comment2')()
u'hello\nhello\n'
>>> t('$#comment0\nhello$#comment1\nhello$#comment2')()
u'\nhello\nhello\n'
Test unicode.
>>> t('$def with (a)\n$a')(u'\u203d')
u'\u203d\n'
>>> t('$def with (a)\n$a')(u'\u203d'.encode('utf-8'))
u'\u203d\n'
>>> t(u'$def with (a)\n$a $:a')(u'\u203d')
u'\u203d \u203d\n'
>>> t(u'$def with ()\nfoo')()
u'foo\n'
>>> def f(x): return x
...
>>> t(u'$def with (f)\n$:f("x")')(f)
u'x\n'
>>> t('$def with (f)\n$:f("x")')(f)
u'x\n'
Test dollar escaping.
>>> t("Stop, $$money isn't evaluated.")()
u"Stop, $money isn't evaluated.\n"
>>> t("Stop, \$money isn't evaluated.")()
u"Stop, $money isn't evaluated.\n"
Test space sensitivity.
>>> t('$def with (x)\n$x')(1)
u'1\n'
>>> t('$def with(x ,y)\n$x')(1, 1)
u'1\n'
>>> t('$(1 + 2*3 + 4)')()
u'11\n'
Make sure globals are working.
>>> t('$x')()
Traceback (most recent call last):
...
NameError: global name 'x' is not defined
>>> t('$x', globals={'x': 1})()
u'1\n'
Can't change globals.
>>> t('$ x = 2\n$x', globals={'x': 1})()
u'2\n'
>>> t('$ x = x + 1\n$x', globals={'x': 1})()
Traceback (most recent call last):
...
UnboundLocalError: local variable 'x' referenced before assignment
Make sure builtins are customizable.
>>> t('$min(1, 2)')()
u'1\n'
>>> t('$min(1, 2)', builtins={})()
Traceback (most recent call last):
...
NameError: global name 'min' is not defined
Test vars.
>>> x = t('$var x: 1')()
>>> x.x
u'1'
>>> x = t('$var x = 1')()
>>> x.x
1
>>> x = t('$var x: \n foo\n bar')()
>>> x.x
u'foo\nbar\n'
Test BOM chars.
>>> t('\xef\xbb\xbf$def with(x)\n$x')('foo')
u'foo\n'
Test for loops with weird cases.
>>> t('$for i in range(10)[1:5]:\n $i')()
u'1\n2\n3\n4\n'
>>> t("$for k, v in {'a': 1, 'b': 2}.items():\n $k $v")()
u'a 1\nb 2\n'
>>> t("$for k, v in ({'a': 1, 'b': 2}.items():\n $k $v")()
Traceback (most recent call last):
...
SyntaxError: invalid syntax
Test datetime.
>>> import datetime
>>> t("$def with (date)\n$date.strftime('%m %Y')")(datetime.datetime(2009, 1, 1))
u'01 2009\n'
"""
pass
if __name__ == "__main__":
import sys
if '--compile' in sys.argv:
compile_templates(sys.argv[2])
else:
import doctest
doctest.testmod()
webpy/ChangeLog.txt

# web.py changelog
## 2012-06-26 0.37
* Fixed datestr issue on Windows -- #155
* Fixed Python 2.4 compatibility issues (tx fredludlow)
* Fixed error in utils.safewrite (tx shuge) -- #95
* Allow use of web.data() with app.request() -- #105
* Fixed an issue with session initialization (tx beardedprojamz) -- #109
* Allow custom message on 400 Bad Request (tx patryk) -- #121
* Made djangoerror work on GAE. -- #80
* Handle malformatted data in the urls. -- #117
* Made it easier to stop the dev server -- #100, #122
* Added support for customizing cookie_path in session (tx larsga) -- #89
* Added exception for "415 Unsupported Media" (tx JirkaChadima) -- #145
* Added GroupedDropdown to support `