From bdf13516d2647064628546db895f31916efcdb29 Mon Sep 17 00:00:00 2001 From: unknown Date: Mon, 17 Dec 2012 17:49:38 +0100 Subject: [PATCH] Implemented encoder support for Python Set object and derivatives. Bumped version --- python/objToJSON.c | 75 ++++++++++++++++++++++++++++++++++++++++++++-- python/version.h | 2 +- tests/tests.py | 24 +++++++++++++-- 3 files changed, 96 insertions(+), 5 deletions(-) diff --git a/python/objToJSON.c b/python/objToJSON.c index f836e78..0ae0504 100644 --- a/python/objToJSON.c +++ b/python/objToJSON.c @@ -61,6 +61,7 @@ typedef struct __TypeContext PyObject *itemValue; PyObject *itemName; PyObject *attrList; + PyObject *iterator; JSINT64 longValue; @@ -212,6 +213,62 @@ char *Tuple_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) return NULL; } +//============================================================================= +// Iterator iteration functions +// itemValue is borrowed reference, no ref counting +//============================================================================= +void Iter_iterBegin(JSOBJ obj, JSONTypeContext *tc) +{ + GET_TC(tc)->itemValue = NULL; + GET_TC(tc)->iterator = PyObject_GetIter(obj); +} + +int Iter_iterNext(JSOBJ obj, JSONTypeContext *tc) +{ + PyObject *item; + + if (GET_TC(tc)->itemValue) + { + Py_DECREF(GET_TC(tc)->itemValue); + GET_TC(tc)->itemValue = NULL; + } + + item = PyIter_Next(GET_TC(tc)->iterator); + + if (item == NULL) + { + return 0; + } + + GET_TC(tc)->itemValue = item; + return 1; +} + +void Iter_iterEnd(JSOBJ obj, JSONTypeContext *tc) +{ + if (GET_TC(tc)->itemValue) + { + Py_DECREF(GET_TC(tc)->itemValue); + GET_TC(tc)->itemValue = NULL; + } + + if (GET_TC(tc)->iterator) + { + Py_DECREF(GET_TC(tc)->iterator); + GET_TC(tc)->iterator = NULL; + } +} + +JSOBJ Iter_iterGetValue(JSOBJ obj, JSONTypeContext *tc) +{ + return GET_TC(tc)->itemValue; +} + +char *Iter_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) +{ + return NULL; +} + //============================================================================= // Dir iteration functions // itemName ref is borrowed from PyObject_Dir (attrList). No refcount @@ -481,6 +538,7 @@ void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc) if (PyIter_Check(obj)) { + PRINTMARK(); goto ISITERABLE; } @@ -564,7 +622,6 @@ void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc) ISITERABLE: - if (PyDict_Check(obj)) { PRINTMARK(); @@ -603,6 +660,19 @@ ISITERABLE: pc->iterGetName = Tuple_iterGetName; return; } + else + if (PyAnySet_Check(obj)) + { + PRINTMARK(); + tc->type = JT_ARRAY; + pc->iterBegin = Iter_iterBegin; + pc->iterEnd = Iter_iterEnd; + pc->iterNext = Iter_iterNext; + pc->iterGetValue = Iter_iterGetValue; + pc->iterGetName = Iter_iterGetName; + return; + + } toDictFunc = PyObject_GetAttrString(obj, "toDict"); @@ -640,7 +710,8 @@ ISITERABLE: } PyErr_Clear(); - + + PRINTMARK(); tc->type = JT_OBJECT; pc->iterBegin = Dir_iterBegin; pc->iterEnd = Dir_iterEnd; diff --git a/python/version.h b/python/version.h index 741c08b..76f1127 100644 --- a/python/version.h +++ b/python/version.h @@ -31,4 +31,4 @@ Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights rese */ -#define UJSON_VERSION "1.26" +#define UJSON_VERSION "1.27" diff --git a/tests/tests.py b/tests/tests.py index 066a772..6a30d30 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -830,6 +830,25 @@ class UltraJSONTests(TestCase): self.assertEquals(1.7893, ujson.loads("1.7893")) self.assertEquals(1.893, ujson.loads("1.893")) self.assertEquals(1.3, ujson.loads("1.3")) + + def test_encodeBigSet(self): + s = set() + for x in xrange(0, 100000): + s.add(x) + ujson.encode(s) + + def test_encodeEmptySet(self): + s = set() + self.assertEquals("[]", ujson.encode(s)) + + def test_encodeSet(self): + s = set([1,2,3,4,5,6,7,8,9]) + enc = ujson.encode(s) + dec = ujson.decode(enc) + + for v in dec: + self.assertTrue(v in s) + """ def test_decodeNumericIntFrcOverflow(self): @@ -857,8 +876,9 @@ raise NotImplementedError("Implement this test!") if __name__ == "__main__": unittest.main() -""" + # Use this to look for memory leaks +""" if __name__ == '__main__': from guppy import hpy hp = hpy() @@ -870,4 +890,4 @@ if __name__ == '__main__': pass heap = hp.heapu() print heap -""" +"""