diff --git a/ate/__init__.py b/ate/__init__.py index 08d79c0e..f93e0653 100644 --- a/ate/__init__.py +++ b/ate/__init__.py @@ -1 +1 @@ -__version__ = '0.5.1' \ No newline at end of file +__version__ = '0.5.2' \ No newline at end of file diff --git a/ate/client.py b/ate/client.py index a6d29ea0..81ca0ace 100644 --- a/ate/client.py +++ b/ate/client.py @@ -1,17 +1,27 @@ +import json import logging import re import time import requests +from ate.exception import ParamsError from requests import Request, Response from requests.exceptions import (InvalidSchema, InvalidURL, MissingSchema, RequestException) -from ate.exception import ParamsError - absolute_http_url_regexp = re.compile(r"^https?://", re.I) +def process_kwargs(method, **kwargs): + if method == "POST": + # if request content-type is application/json, request data should be dumped + content_type = kwargs.get("headers", {}).get("content-type", "") + if content_type.startswith("application/json") and "data" in kwargs: + kwargs["data"] = json.dumps(kwargs["data"]) + + return kwargs + + class ApiResponse(Response): def raise_for_status(self): @@ -142,6 +152,7 @@ class HttpSession(requests.Session): Safe mode has been removed from requests 1.x. """ try: + kwargs = process_kwargs(method, **kwargs) return requests.Session.request(self, method, url, **kwargs) except (MissingSchema, InvalidSchema, InvalidURL): raise diff --git a/tests/test_client.py b/tests/test_client.py index b6ac1009..6e13f47a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,4 @@ -from ate.client import HttpSession +from ate.client import HttpSession, process_kwargs from tests.base import ApiServerUnittest class TestHttpClient(ApiServerUnittest): @@ -35,3 +35,16 @@ class TestHttpClient(ApiServerUnittest): resp = self.api_client.post(url, json=data, headers=self.headers) self.assertEqual(201, resp.status_code) self.assertEqual(True, resp.json()['success']) + + def test_process_kwargs(self): + kwargs = { + "headers": { + "content-type": "application/json; charset=utf-8" + }, + "data": { + "a": 1, + "b": 2 + } + } + kwargs = process_kwargs("POST", **kwargs) + self.assertEqual(kwargs["data"], '{"a": 1, "b": 2}')