# Copyright 2016 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import tempfile
import mock
from OpenSSL import crypto
import six
from glare.common import exception as exc
from glare.common import utils
from glare.tests.unit import base
class TestUtils(base.BaseTestCase):
    """Unit tests for helper functions in glare.common.utils."""

    def test_validate_quotes(self):
        """Each well-formed sample makes validate_quotes return None."""
        accepted = (
            '"classic"',
            'This is a good string',
            '"comma after quotation mark should work",',
            ',"comma before quotation mark should work"',
            '"we have quotes \\" inside"',
        )
        for candidate in accepted:
            self.assertIsNone(utils.validate_quotes(candidate))

    def test_validate_quotes_negative(self):
        """Each malformed sample raises InvalidParameterValue."""
        rejected = (
            'not_comma"blabla"',
            '"No comma after quotation mark"Not_comma',
            '"The quote is not closed',
        )
        for candidate in rejected:
            self.assertRaises(exc.InvalidParameterValue,
                              utils.validate_quotes, candidate)

    def test_no_4bytes_params(self):
        """A function wrapped by no_4byte_params raises BadRequest when a
        4-byte unicode character appears in its arguments.
        """
        @utils.no_4byte_params
        def echo(*args, **kwargs):
            return args, kwargs

        four_byte_char = u'\U0001f62a'

        # Arguments without 4-byte unicode come back unchanged.
        positional, named = echo('val1', param='val2')
        self.assertEqual(('val1',), positional)
        self.assertEqual({'param': 'val2'}, named)

        # The offending character as a positional value ...
        self.assertRaises(exc.BadRequest, echo, four_byte_char)
        # ... as a keyword name ...
        self.assertRaises(exc.BadRequest, echo, **{four_byte_char: 'val1'})
        # ... and as a keyword value.
        self.assertRaises(exc.BadRequest, echo, **{'param': four_byte_char})
class TestReaders(base.BaseTestCase):
    """Test various readers in glare.common.utils"""

    # Bytes cycled through by _create_generator; each chunk repeats one of
    # them so chunk boundaries stay recognisable in the reassembled data.
    _CHUNK_CHARS = b'abc'

    def test_cooperative_reader_iterator(self):
        """Ensure cooperative reader class accesses all bytes of file"""
        BYTES = 1024
        with tempfile.TemporaryFile('w+') as tmp_fd:
            tmp_fd.write('*' * BYTES)
            tmp_fd.seek(0)
            bytes_read = sum(len(chunk)
                             for chunk in utils.CooperativeReader(tmp_fd))
        self.assertEqual(BYTES, bytes_read)

    def test_cooperative_reader_explicit_read(self):
        """Read a file one unit at a time through read(1) and count reads."""
        BYTES = 1024
        bytes_read = 0
        with tempfile.TemporaryFile('w+') as tmp_fd:
            tmp_fd.write('*' * BYTES)
            tmp_fd.seek(0)
            reader = utils.CooperativeReader(tmp_fd)
            # An empty result from read() marks the end of the data.
            while reader.read(1):
                bytes_read += 1
        self.assertEqual(BYTES, bytes_read)

    def test_cooperative_reader_no_read_method(self):
        """Wrap a plain list (no read() method) and drain it via read()."""
        BYTES = 1024
        stream = [b'*'] * BYTES
        reader = utils.CooperativeReader(stream)
        bytes_read = 0
        # The test counts one byte per read() call, i.e. it expects each
        # call to hand back a single one-byte list element.
        while reader.read():
            bytes_read += 1
        self.assertEqual(BYTES, bytes_read)

        # some data may be left in the buffer
        reader = utils.CooperativeReader(stream)
        reader.buffer = 'some data'
        buffer_string = reader.read()
        self.assertEqual('some data', buffer_string)

    def test_cooperative_reader_no_read_method_buffer_size(self):
        """Requesting more than the (patched) buffer limit must fail."""
        # Decrease buffer size to 1000 bytes to test its overflow
        with mock.patch('glare.common.utils.MAX_COOP_READER_BUFFER_SIZE',
                        1000):
            BYTES = 1024
            stream = [b'*'] * BYTES
            reader = utils.CooperativeReader(stream)
            # Reading 1001 bytes to the buffer leads to 413 error
            self.assertRaises(exc.RequestEntityTooLarge, reader.read, 1001)

    def test_cooperative_reader_of_iterator(self):
        """Ensure cooperative reader supports iterator backends too"""
        data = b'abcdefgh'
        # Slice instead of iterating: iterating bytes yields ints on
        # Python 3, while a one-byte slice stays bytes on both 2 and 3.
        data_list = [data[i:i + 1] * 3 for i in range(len(data))]
        reader = utils.CooperativeReader(data_list)
        # Keep calling read(3) until it returns the empty-bytes sentinel.
        meat = b''.join(iter(lambda: reader.read(3), b''))
        self.assertEqual(b'aaabbbcccdddeeefffggghhh', meat)

    def test_cooperative_reader_of_iterator_stop_iteration_err(self):
        """Ensure reading from an empty iterable backend yields b''"""
        # The comprehension over an empty string produces an empty list, so
        # the backend is exhausted on the very first read.
        reader = utils.CooperativeReader([char * 3 for char in ''])
        meat = b''.join(iter(lambda: reader.read(3), b''))
        self.assertEqual(b'', meat)

    def _create_generator(self, chunk_size, max_iterations):
        """Yield ``max_iterations`` chunks of ``chunk_size`` bytes each.

        Chunk ``i`` consists of the byte ``_CHUNK_CHARS[i % 3]`` repeated
        ``chunk_size`` times.

        The generator simply ends after the last chunk. It must not raise
        StopIteration itself: since PEP 479 a StopIteration raised inside a
        generator body is turned into RuntimeError.
        """
        chars = self._CHUNK_CHARS
        for iteration in range(max_iterations):
            index = iteration % len(chars)
            yield chars[index:index + 1] * chunk_size

    def _test_reader_chunked(self, chunk_size, read_size, max_iterations=5):
        """Drain a generator-backed reader and verify size and content.

        :param chunk_size: size of every chunk the backend generator yields
        :param read_size: size passed to each ``reader.read`` call
        :param max_iterations: number of chunks the generator produces
        """
        generator = self._create_generator(chunk_size, max_iterations)
        reader = utils.CooperativeReader(generator)
        result = bytearray()
        while True:
            data = reader.read(read_size)
            if len(data) == 0:
                break
            # read() may return less than requested but never more.
            self.assertLessEqual(len(data), read_size)
            result += data
        # Build the expectation from the same parameters the generator got,
        # so it stays correct for any max_iterations.
        chars = self._CHUNK_CHARS
        expected = b''.join(
            chars[i % len(chars):i % len(chars) + 1] * chunk_size
            for i in range(max_iterations))
        self.assertEqual(expected, bytes(result))

    def test_cooperative_reader_preserves_size_chunk_less_then_read(self):
        """Backend chunks (43) smaller than the requested read size (101)."""
        self._test_reader_chunked(43, 101)

    def test_cooperative_reader_preserves_size_chunk_equals_read(self):
        """Backend chunks exactly as large as the requested read size."""
        self._test_reader_chunked(1024, 1024)

    def test_cooperative_reader_preserves_size_chunk_more_then_read(self):
        """Backend chunks much larger than the requested read size."""
        chunk_size = 16 * 1024 * 1024  # 16 Mb, as in remote http source
        read_size = 8 * 1024  # 8k, as in httplib
        self._test_reader_chunked(chunk_size, read_size)

    def test_limiting_reader(self):
        """Ensure limiting reader class accesses all bytes of file"""
        BYTES = 1024

        # Iteration interface, limit equal to the payload size.
        data = six.BytesIO(b"*" * BYTES)
        bytes_read = sum(len(chunk)
                         for chunk in utils.LimitingReader(data, BYTES))
        self.assertEqual(BYTES, bytes_read)

        # Explicit read(1) interface, same limit.
        bytes_read = 0
        data = six.BytesIO(b"*" * BYTES)
        reader = utils.LimitingReader(data, BYTES)
        while reader.read(1):
            bytes_read += 1
        self.assertEqual(BYTES, bytes_read)

    def test_limiting_reader_fails(self):
        """Ensure limiting reader class throws exceptions if limit exceeded"""
        BYTES = 1024

        def _consume_all_iter():
            # Limit is one byte short of the payload, so draining must fail.
            data = six.BytesIO(b"*" * BYTES)
            for _chunk in utils.LimitingReader(data, BYTES - 1):
                pass

        self.assertRaises(exc.RequestEntityTooLarge, _consume_all_iter)

        def _consume_all_read():
            data = six.BytesIO(b"*" * BYTES)
            reader = utils.LimitingReader(data, BYTES - 1)
            while reader.read(1):
                pass

        self.assertRaises(exc.RequestEntityTooLarge, _consume_all_read)

    def test_blob_iterator(self):
        """Chunks produced by BlobIterator add up to the full input size."""
        BYTES = 1024
        stream = [b'*'] * BYTES
        bytes_read = sum(len(chunk)
                         for chunk in utils.BlobIterator(stream, 64))
        self.assertEqual(BYTES, bytes_read)
class TestKeyCert(base.BaseTestCase):
    """Tests for utils.validate_key_cert."""

    def test_validate_key_cert_key(self):
        """The key/certificate pair stored under ../var raises nothing."""
        here = os.path.dirname(__file__)
        var_dir = os.path.abspath(os.path.join(here, '../', 'var'))
        key_path = os.path.join(var_dir, 'privatekey.key')
        cert_path = os.path.join(var_dir, 'certificate.crt')
        utils.validate_key_cert(key_path, cert_path)

    def test_validate_key_cert_no_private_key(self):
        """A non-existent key path is expected to raise RuntimeError."""
        with tempfile.NamedTemporaryFile('w+') as cert_tmp:
            with self.assertRaises(RuntimeError):
                utils.validate_key_cert("/not/a/file", cert_tmp.name)

    def test_validate_key_cert_cert_cant_read(self):
        """RuntimeError is expected after the cert file is chmod-ed to 0."""
        with tempfile.NamedTemporaryFile('w+') as key_tmp, \
                tempfile.NamedTemporaryFile('w+') as cert_tmp:
            os.chmod(cert_tmp.name, 0)
            with self.assertRaises(RuntimeError):
                utils.validate_key_cert(key_tmp.name, cert_tmp.name)

    def test_validate_key_cert_key_cant_read(self):
        """RuntimeError is expected after the key file is chmod-ed to 0."""
        with tempfile.NamedTemporaryFile('w+') as key_tmp, \
                tempfile.NamedTemporaryFile('w+') as cert_tmp:
            os.chmod(key_tmp.name, 0)
            with self.assertRaises(RuntimeError):
                utils.validate_key_cert(key_tmp.name, cert_tmp.name)

    def test_validate_key_cert_key_crypto_error(self):
        """RuntimeError is expected when OpenSSL.crypto.verify is patched to
        raise crypto.Error.
        """
        here = os.path.dirname(__file__)
        var_dir = os.path.abspath(os.path.join(here, '../', 'var'))
        key_path = os.path.join(var_dir, 'privatekey.key')
        cert_path = os.path.join(var_dir, 'certificate.crt')
        with mock.patch('OpenSSL.crypto.verify', side_effect=crypto.Error):
            with self.assertRaises(RuntimeError):
                utils.validate_key_cert(key_path, cert_path)