# Source code for glare.tests.unit.test_utils

# Copyright 2016 OpenStack Foundation.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import os
import tempfile

from datetime import datetime
import mock
from OpenSSL import crypto
import six

from glare.common import exception as exc
from glare.common import utils
from glare.tests.unit import base


class TestUtils(base.BaseTestCase):
    """Test class for glare.common.utils"""

    def test_validate_quotes(self):
        """Well-formed quoting is accepted (validate_quotes returns None)."""
        self.assertIsNone(utils.validate_quotes('"classic"'))
        self.assertIsNone(utils.validate_quotes('This is a good string'))
        self.assertIsNone(utils.validate_quotes(
            '"comma after quotation mark should work",'))
        self.assertIsNone(utils.validate_quotes(
            ',"comma before quotation mark should work"'))
        self.assertIsNone(utils.validate_quotes('"we have quotes \\" inside"'))

    def test_validate_quotes_negative(self):
        """Malformed quoting raises InvalidParameterValue."""
        self.assertRaises(exc.InvalidParameterValue,
                          utils.validate_quotes, 'not_comma"blabla"')
        self.assertRaises(exc.InvalidParameterValue, utils.validate_quotes,
                          '"No comma after quotation mark"Not_comma')
        self.assertRaises(exc.InvalidParameterValue,
                          utils.validate_quotes, '"The quote is not closed')

    def test_no_4bytes_params(self):
        """no_4byte_params rejects 4-byte unicode in args, kwarg names and
        kwarg values, and passes everything else through unchanged."""
        @utils.no_4byte_params
        def test_func(*args, **kwargs):
            return args, kwargs

        bad_char = u'\U0001f62a'

        # params without 4bytes unicode are okay
        args, kwargs = test_func('val1', param='val2')
        self.assertEqual(('val1',), args)
        self.assertEqual({'param': 'val2'}, kwargs)

        # test various combinations with bad param
        self.assertRaises(exc.BadRequest, test_func, bad_char)
        self.assertRaises(exc.BadRequest, test_func, **{bad_char: 'val1'})
        self.assertRaises(exc.BadRequest, test_func, **{'param': bad_char})

    def test_split_filter_op(self):
        """split_filter_op returns a (combiner, operator, threshold) triple.

        Compared with ``assertEqual(expected, actual)`` (rather than
        ``assertEqual(True, expected == actual)``) so that a failure reports
        the offending expression and the actual result.
        """
        # ISO timestamps contain ':' themselves, so they exercise the
        # splitting logic on values that look like extra separators.
        timestamp = datetime.today().isoformat()
        cases = [
            ("5.0", ("and", "eq", "5.0")),
            ("or", ("and", "eq", "or")),
            (timestamp, ("and", "eq", timestamp)),
            ("or:5.0", ("or", "eq", "5.0")),
            ("or:eq", ("or", "eq", "eq")),
            ("lt:or", ("and", "lt", "or")),
            ("or:art_name", ("or", "eq", "art_name")),
            ("t:tt", ("and", "t", "tt")),
            ("or:" + timestamp, ("or", "eq", timestamp)),
            ("lt:" + timestamp, ("and", "lt", timestamp)),
            ("eq:tom and_jerry", ("and", "eq", "tom and_jerry")),
            ("or:eq:5.0", ("or", "eq", "5.0")),
            ("invalid_combiner:eq:5.0",
             ("and", "invalid_combiner", "eq:5.0")),
            ("or:eq:and", ("or", "eq", "and")),
            ("eq:or:and", ("and", "eq", "or:and")),
            ("or:lt:" + timestamp, ("or", "lt", timestamp)),
            ("or:in:1,2,3", ("or", "in", "1,2,3")),
            ("tt:t", ("and", "tt", "t")),
            ("or:tt:t", ("or", "tt", "t")),
            ("eq:rt:and", ("and", "eq", "rt:and")),
            ("and:tt:tt:t", ("and", "tt", "tt:t")),
            ("or:lt:tt:tt:tt:tt:tt:tt:tt",
             ("or", "lt", "tt:tt:tt:tt:tt:tt:tt")),
        ]
        for expression, expected in cases:
            self.assertEqual(expected, utils.split_filter_op(expression),
                             "unexpected split for %r" % expression)
class TestReaders(base.BaseTestCase):
    """Test various readers in glare.common.utils"""

    def test_cooperative_reader_iterator(self):
        """Ensure cooperative reader class accesses all bytes of file"""
        BYTES = 1024
        bytes_read = 0
        with tempfile.TemporaryFile('w+') as tmp_fd:
            tmp_fd.write('*' * BYTES)
            tmp_fd.seek(0)
            for chunk in utils.CooperativeReader(tmp_fd):
                bytes_read += len(chunk)

        self.assertEqual(BYTES, bytes_read)

    def test_cooperative_reader_explicit_read(self):
        """Reading a file one byte at a time yields every byte exactly once."""
        BYTES = 1024
        bytes_read = 0
        with tempfile.TemporaryFile('w+') as tmp_fd:
            tmp_fd.write('*' * BYTES)
            tmp_fd.seek(0)
            reader = utils.CooperativeReader(tmp_fd)
            byte = reader.read(1)
            while len(byte) != 0:
                bytes_read += 1
                byte = reader.read(1)

        self.assertEqual(BYTES, bytes_read)

    def test_cooperative_reader_no_read_method(self):
        """Backends without read() are consumed chunk by chunk, and data
        already sitting in the buffer is returned by a length-less read()."""
        BYTES = 1024
        stream = [b'*'] * BYTES
        reader = utils.CooperativeReader(stream)
        bytes_read = 0
        byte = reader.read()
        while len(byte) != 0:
            bytes_read += 1
            byte = reader.read()

        self.assertEqual(BYTES, bytes_read)

        # some data may be left in the buffer; the stream yields bytes, so
        # the buffer is populated with bytes as well
        reader = utils.CooperativeReader(stream)
        reader.buffer = b'some data'
        buffer_string = reader.read()
        self.assertEqual(b'some data', buffer_string)

    def test_cooperative_reader_no_read_method_buffer_size(self):
        """Buffering more than MAX_COOP_READER_BUFFER_SIZE raises 413."""
        # Decrease buffer size to 1000 bytes to test its overflow
        with mock.patch('glare.common.utils.MAX_COOP_READER_BUFFER_SIZE',
                        1000):
            BYTES = 1024
            stream = [b'*'] * BYTES
            reader = utils.CooperativeReader(stream)
            # Reading 1001 bytes to the buffer leads to 413 error
            self.assertRaises(exc.RequestEntityTooLarge, reader.read, 1001)

    def test_cooperative_reader_of_iterator(self):
        """Ensure cooperative reader supports iterator backends too"""
        data = b'abcdefgh'
        data_list = [data[i:i + 1] * 3 for i in range(len(data))]
        reader = utils.CooperativeReader(data_list)
        chunks = []
        while True:
            chunks.append(reader.read(3))
            if chunks[-1] == b'':
                break
        meat = b''.join(chunks)
        self.assertEqual(b'aaabbbcccdddeeefffggghhh', meat)

    def test_cooperative_reader_of_iterator_stop_iteration_err(self):
        """Ensure reading from an empty iterator backend returns b''
        (the exhausted iterator does not leak StopIteration)."""
        reader = utils.CooperativeReader([])
        chunks = []
        while True:
            chunks.append(reader.read(3))
            if chunks[-1] == b'':
                break
        meat = b''.join(chunks)
        self.assertEqual(b'', meat)

    def _create_generator(self, chunk_size, max_iterations):
        """Yield max_iterations chunks of chunk_size bytes, cycling through
        the bytes a, b, c (b'aa..', b'bb..', b'cc..', b'aa..', ...)."""
        chars = b'abc'
        for iteration in range(max_iterations):
            index = iteration % len(chars)
            yield chars[index:index + 1] * chunk_size

    def _test_reader_chunked(self, chunk_size, read_size, max_iterations=5):
        """Read a chunked generator through CooperativeReader with a fixed
        read size and check that no read exceeds read_size and that the
        reassembled data equals what the generator produced.
        """
        generator = self._create_generator(chunk_size, max_iterations)
        reader = utils.CooperativeReader(generator)
        result = bytearray()
        while True:
            data = reader.read(read_size)
            if len(data) == 0:
                break
            self.assertLessEqual(len(data), read_size)
            result += data
        # Derive the expectation from a fresh generator so it stays correct
        # for any max_iterations, not only the default of 5.
        expected = b''.join(self._create_generator(chunk_size,
                                                   max_iterations))
        self.assertEqual(expected, bytes(result))

    def test_cooperative_reader_preserves_size_chunk_less_then_read(self):
        self._test_reader_chunked(43, 101)

    def test_cooperative_reader_preserves_size_chunk_equals_read(self):
        self._test_reader_chunked(1024, 1024)

    def test_cooperative_reader_preserves_size_chunk_more_then_read(self):
        chunk_size = 16 * 1024 * 1024  # 16 Mb, as in remote http source
        read_size = 8 * 1024  # 8k, as in httplib
        self._test_reader_chunked(chunk_size, read_size)

    def test_limiting_reader(self):
        """Ensure limiting reader class accesses all bytes of file"""
        BYTES = 1024
        bytes_read = 0
        data = six.BytesIO(b"*" * BYTES)
        for chunk in utils.LimitingReader(data, BYTES):
            bytes_read += len(chunk)

        self.assertEqual(BYTES, bytes_read)

        bytes_read = 0
        data = six.BytesIO(b"*" * BYTES)
        reader = utils.LimitingReader(data, BYTES)
        byte = reader.read(1)
        while len(byte) != 0:
            bytes_read += 1
            byte = reader.read(1)

        self.assertEqual(BYTES, bytes_read)

    def test_limiting_reader_fails(self):
        """Ensure limiting reader class throws exceptions if limit exceeded"""
        BYTES = 1024

        def _consume_all_iter():
            # Only draining matters here; the limit is one byte too small.
            data = six.BytesIO(b"*" * BYTES)
            for _chunk in utils.LimitingReader(data, BYTES - 1):
                pass

        self.assertRaises(exc.RequestEntityTooLarge, _consume_all_iter)

        def _consume_all_read():
            data = six.BytesIO(b"*" * BYTES)
            reader = utils.LimitingReader(data, BYTES - 1)
            byte = reader.read(1)
            while len(byte) != 0:
                byte = reader.read(1)

        self.assertRaises(exc.RequestEntityTooLarge, _consume_all_read)

    def test_blob_iterator(self):
        """BlobIterator with an explicit chunk size yields every byte."""
        BYTES = 1024
        bytes_read = 0
        stream = [b'*'] * BYTES
        for chunk in utils.BlobIterator(stream, 64):
            bytes_read += len(chunk)

        self.assertEqual(BYTES, bytes_read)

    def test_blob_iterator_big_data(self):
        """BlobIterator with its default chunk size yields every byte."""
        BYTES = 1000000
        bytes_read = 0
        stream = [b'*'] * BYTES
        for chunk in utils.BlobIterator(stream):
            bytes_read += len(chunk)

        self.assertEqual(BYTES, bytes_read)
class TestKeyCert(base.BaseTestCase):
    """Tests for glare.common.utils.validate_key_cert."""

    @staticmethod
    def _get_test_key_cert():
        """Return (keyfile, certfile) paths inside the ``var`` directory
        located one level above this test module's directory."""
        var_dir = os.path.abspath(os.path.join(os.path.dirname(__file__),
                                               '../', 'var'))
        keyfile = os.path.join(var_dir, 'privatekey.key')
        certfile = os.path.join(var_dir, 'certificate.crt')
        return keyfile, certfile

    def test_validate_key_cert_key(self):
        """A matching key/cert pair validates without raising."""
        keyfile, certfile = self._get_test_key_cert()
        utils.validate_key_cert(keyfile, certfile)

    def test_validate_key_cert_no_private_key(self):
        """A missing key file raises RuntimeError."""
        with tempfile.NamedTemporaryFile('w+') as tmpf:
            self.assertRaises(RuntimeError,
                              utils.validate_key_cert,
                              "/not/a/file", tmpf.name)

    # NOTE(review): the two chmod(0) tests below rely on file permissions
    # being enforced; when run as root the files stay readable — confirm
    # they still raise for another reason in that environment.
    def test_validate_key_cert_cert_cant_read(self):
        """An unreadable certificate file raises RuntimeError."""
        with tempfile.NamedTemporaryFile('w+') as keyf:
            with tempfile.NamedTemporaryFile('w+') as certf:
                os.chmod(certf.name, 0)
                self.assertRaises(RuntimeError,
                                  utils.validate_key_cert,
                                  keyf.name, certf.name)

    def test_validate_key_cert_key_cant_read(self):
        """An unreadable key file raises RuntimeError."""
        with tempfile.NamedTemporaryFile('w+') as keyf:
            with tempfile.NamedTemporaryFile('w+') as certf:
                os.chmod(keyf.name, 0)
                self.assertRaises(RuntimeError,
                                  utils.validate_key_cert,
                                  keyf.name, certf.name)

    def test_validate_key_cert_key_crypto_error(self):
        """An OpenSSL verification failure is reported as RuntimeError."""
        keyfile, certfile = self._get_test_key_cert()
        with mock.patch('OpenSSL.crypto.verify', side_effect=crypto.Error):
            self.assertRaises(RuntimeError,
                              utils.validate_key_cert,
                              keyfile, certfile)