1
0
Fork 0
mirror of https://github.com/ultrajson/ultrajson.git synced 2024-06-07 23:46:10 +02:00

Add a default keyword argument to dumps

dump and dumps functions in python json stdlib have a default keyword argument.
It's useful for serializing complex objects. Supporting this argument will improve compatibility and flexibility of ujson.
This commit is contained in:
garenchan 2021-08-31 17:11:25 +08:00
parent 8f5ad61f3d
commit b7fba98136
2 changed files with 90 additions and 4 deletions

View File

@ -72,6 +72,12 @@ typedef struct __TypeContext
#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv)) #define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))
// If newObj is set, we should use it rather than JSOBJ
#define GET_OBJ(__jsobj, __ptrtc) (GET_TC(__ptrtc)->newObj ? GET_TC(__ptrtc)->newObj : __jsobj)
// Avoid infinite loop caused by the default function
#define DEFAULT_FN_MAX_DEPTH 3
struct PyDictIterState struct PyDictIterState
{ {
PyObject *keys; PyObject *keys;
@ -432,7 +438,8 @@ static void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder
static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc) static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
{ {
PyObject *obj, *objRepr, *exc; PyObject *obj, *objRepr, *exc, *defaultFn, *newObj;
int level = 0;
TypeContext *pc; TypeContext *pc;
PRINTMARK(); PRINTMARK();
if (!_obj) if (!_obj)
@ -442,6 +449,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
} }
obj = (PyObject*) _obj; obj = (PyObject*) _obj;
defaultFn = (PyObject*) enc->prv;
tc->prv = PyObject_Malloc(sizeof(TypeContext)); tc->prv = PyObject_Malloc(sizeof(TypeContext));
pc = (TypeContext *) tc->prv; pc = (TypeContext *) tc->prv;
@ -462,6 +470,7 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
pc->longValue = 0; pc->longValue = 0;
pc->rawJSONValue = NULL; pc->rawJSONValue = NULL;
BEGIN:
if (PyIter_Check(obj)) if (PyIter_Check(obj))
{ {
PRINTMARK(); PRINTMARK();
@ -553,7 +562,6 @@ static void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObject
return; return;
} }
ISITERABLE: ISITERABLE:
if (PyDict_Check(obj)) if (PyDict_Check(obj))
{ {
@ -651,6 +659,31 @@ ISITERABLE:
return; return;
} }
DEFAULT:
if (defaultFn)
{
// Break infinite loop
if (level >= DEFAULT_FN_MAX_DEPTH)
{
PRINTMARK();
PyErr_Format(PyExc_TypeError, "maximum recursion depth exceeded");
goto INVALID;
}
newObj = PyObject_CallFunctionObjArgs(defaultFn, obj, NULL);
if (newObj)
{
PRINTMARK();
obj = pc->newObj = newObj;
level += 1;
goto BEGIN;
}
else
{
goto INVALID;
}
}
PRINTMARK(); PRINTMARK();
PyErr_Clear(); PyErr_Clear();
@ -682,12 +715,14 @@ static void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)
static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen) static const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
{ {
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen); return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
} }
static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc) static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
{ {
JSINT64 ret; JSINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret; return ret;
} }
@ -695,6 +730,7 @@ static JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc) static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
{ {
JSUINT64 ret; JSUINT64 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret; return ret;
} }
@ -702,6 +738,7 @@ static JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc) static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
{ {
JSINT32 ret; JSINT32 ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret; return ret;
} }
@ -709,6 +746,7 @@ static JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc) static double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
{ {
double ret; double ret;
obj = GET_OBJ(obj, tc);
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL); GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret; return ret;
} }
@ -720,27 +758,31 @@ static void Object_releaseObject(JSOBJ _obj)
static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc) static int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
{ {
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterNext(obj, tc); return GET_TC(tc)->iterNext(obj, tc);
} }
static void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc) static void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{ {
obj = GET_OBJ(obj, tc);
GET_TC(tc)->iterEnd(obj, tc); GET_TC(tc)->iterEnd(obj, tc);
} }
static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc) static JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{ {
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetValue(obj, tc); return GET_TC(tc)->iterGetValue(obj, tc);
} }
static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen) static char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{ {
obj = GET_OBJ(obj, tc);
return GET_TC(tc)->iterGetName(obj, tc, outLen); return GET_TC(tc)->iterGetName(obj, tc, outLen);
} }
PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs) PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
{ {
static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", NULL }; static char *kwlist[] = { "obj", "ensure_ascii", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", "allow_nan", "reject_bytes", "default", NULL };
char buffer[65536]; char buffer[65536];
char *ret; char *ret;
@ -751,6 +793,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
PyObject *oencodeHTMLChars = NULL; PyObject *oencodeHTMLChars = NULL;
PyObject *oescapeForwardSlashes = NULL; PyObject *oescapeForwardSlashes = NULL;
PyObject *osortKeys = NULL; PyObject *osortKeys = NULL;
PyObject *odefaultFn = NULL;
int allowNan = -1; int allowNan = -1;
int orejectBytes = -1; int orejectBytes = -1;
@ -785,7 +828,7 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
PRINTMARK(); PRINTMARK();
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiii", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes)) if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOOOiiiO", kwlist, &oinput, &oensureAscii, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent, &allowNan, &orejectBytes, &odefaultFn))
{ {
return NULL; return NULL;
} }
@ -815,6 +858,12 @@ PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
encoder.allowNan = allowNan; encoder.allowNan = allowNan;
} }
if (odefaultFn != NULL && odefaultFn != Py_None)
{
// Here use prv to store default function
encoder.prv = odefaultFn;
}
if (encoder.allowNan) if (encoder.allowNan)
{ {
csInf = "Inf"; csInf = "Inf";

View File

@ -1,9 +1,11 @@
import datetime as dt
import decimal import decimal
import io import io
import json import json
import math import math
import re import re
import sys import sys
import uuid
from collections import OrderedDict from collections import OrderedDict
import pytest import pytest
@ -828,6 +830,41 @@ def test_encode_none_key():
assert ujson.dumps(data) == '{"null":null}' assert ujson.dumps(data) == '{"null":null}'
def test_default_function():
iso8601_time_format = "%Y-%m-%dT%H:%M:%S.%f"
class CustomObject:
pass
class UnjsonableObject:
pass
def default(value):
if isinstance(value, dt.datetime):
return value.strftime(iso8601_time_format)
elif isinstance(value, uuid.UUID):
return value.hex
elif isinstance(value, CustomObject):
raise ValueError("invalid value")
return value
now = dt.datetime.now()
expected_output = '"%s"' % now.strftime(iso8601_time_format)
assert ujson.dumps(now, default=default) == expected_output
uuid4 = uuid.uuid4()
expected_output = '"%s"' % uuid4.hex
assert ujson.dumps(uuid4, default=default) == expected_output
custom_obj = CustomObject()
with pytest.raises(ValueError, match="invalid value"):
ujson.dumps(custom_obj, default=default)
unjsonable_obj = UnjsonableObject()
with pytest.raises(TypeError, match="maximum recursion depth exceeded"):
ujson.dumps(unjsonable_obj, default=default)
""" """
def test_decode_numeric_int_frc_overflow(): def test_decode_numeric_int_frc_overflow():
input = "X.Y" input = "X.Y"