tester: Change to a simpler TFTP server

- Add a simpler TFTP to allow parallel test hardware

- Remove the imported tftpy server

Closes #4063
This commit is contained in:
Chris Johns 2020-08-26 13:38:54 +10:00
parent 37ad446d9d
commit eb3608133b
13 changed files with 766 additions and 2193 deletions

View File

@ -43,7 +43,7 @@ import sys
from rtemstoolkit import error from rtemstoolkit import error
from rtemstoolkit import reraise from rtemstoolkit import reraise
import tftpy import tftpserver
class tftp(object): class tftp(object):
'''RTEMS Testing TFTP base.''' '''RTEMS Testing TFTP base.'''
@ -88,7 +88,8 @@ class tftp(object):
def _stop(self): def _stop(self):
try: try:
if self.server: if self.server:
self.server.stop(now = True) self.server.stop()
self.finished = True
except: except:
pass pass
@ -101,6 +102,10 @@ class tftp(object):
def _timeout(self): def _timeout(self):
self._stop() self._stop()
while self.running or not self.finished:
self._unlock('_timeout')
time.sleep(0.1)
self._lock('_timeout')
if self.timeout is not None: if self.timeout is not None:
self.timeout() self.timeout()
@ -119,22 +124,21 @@ class tftp(object):
return None return None
def _listener(self): def _listener(self):
tftpy_log = logging.getLogger('tftpy') self._lock('_listener')
tftpy_log.setLevel(100) exe = self.exe
self.exe = None
self._unlock('_listener')
self.server = tftpserver.tftp_server(host = 'all',
port = self.port,
timeout = 1,
forced_file = exe,
sessions = 1)
try: try:
self.server = tftpy.TftpServer(tftproot = '.', self.server.start()
dyn_file_func = self._exe_handle) self.server.run()
except tftpy.TftpException as te: except:
raise error.general('tftp: %s' % (str(te))) self.server.stop()
if self.server is not None: raise
try:
self.server.listen('0.0.0.0', self.port, 0.5)
except tftpy.TftpException as te:
raise error.general('tftp: %s' % (str(te)))
except IOError as ie:
if ie.errno == errno.EACCES:
raise error.general('tftp: permissions error: check tftp server port')
raise error.general('tftp: io error: %s' % (str(ie)))
def _runner(self): def _runner(self):
self._lock('_runner') self._lock('_runner')
@ -146,9 +150,7 @@ class tftp(object):
except: except:
caught = sys.exc_info() caught = sys.exc_info()
self._lock('_runner') self._lock('_runner')
self._init()
self.running = False self.running = False
self.finished = True
self.caught = caught self.caught = caught
self._unlock('_runner') self._unlock('_runner')
@ -187,6 +189,7 @@ class tftp(object):
self._timeout() self._timeout()
caught = self.caught caught = self.caught
self.caught = None self.caught = None
self._init()
self._unlock('_open') self._unlock('_open')
if caught is not None: if caught is not None:
reraise.reraise(*caught) reraise.reraise(*caught)

699
tester/rt/tftpserver.py Normal file
View File

@ -0,0 +1,699 @@
# SPDX-License-Identifier: BSD-2-Clause
'''The TFTP Server handles a read only TFTP session.'''
# Copyright (C) 2020 Chris Johns (chrisj@rtems.org)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
from __future__ import print_function
import argparse
import os
import socket
import sys
import time
import threading
try:
import socketserver
except ImportError:
import SocketServer as socketserver
from rtemstoolkit import error
from rtemstoolkit import log
from rtemstoolkit import version
class tftp_session(object):
'''Handle the TFTP session packets initiated on the TFTP port (69).
'''
# pylint: disable=useless-object-inheritance
# pylint: disable=too-many-instance-attributes
opcodes = ['nul', 'RRQ', 'WRQ', 'DATA', 'ACK', 'ERROR', 'OACK']
OP_RRQ = 1
OP_WRQ = 2
OP_DATA = 3
OP_ACK = 4
OP_ERROR = 5
OP_OACK = 6
E_NOT_DEFINED = 0
E_FILE_NOT_FOUND = 1
E_ACCESS_VIOLATION = 2
E_DISK_FULL = 3
E_ILLEGAL_TFTP_OP = 4
E_UKNOWN_TID = 5
E_FILE_ALREADY_EXISTS = 6
E_NO_SUCH_USER = 7
E_NO_ERROR = 10
def __init__(self, host, port, base, forced_file, reader=None):
# pylint: disable=too-many-arguments
self.host = host
self.port = port
self.base = base
self.forced_file = forced_file
if reader is None:
self.data_reader = self._file_reader
else:
self.data_reader = reader
self.filein = None
self.resends_limit = 5
# These are here to shut pylint up
self.block = 0
self.block_size = 512
self.timeout = 0
self.resends = 0
self.finished = False
self.filename = None
self._reinit()
def _reinit(self):
'''Reinitialise all the class variables used by the protocol.'''
if self.filein is not None:
self.filein.close()
self.filein = None
self.block = 0
self.block_size = 512
self.timeout = 0
self.resends = 0
self.finished = False
self.filename = None
def _file_reader(self, command, **kwargs):
'''The default file reader if the user does not provide one.
The call returns a two element tuple where the first element
is an error code, and the second element is data if the error
code is 0 else it is an error message.
'''
# pylint: disable=too-many-return-statements
if command == 'open':
if 'filename' not in kwargs:
raise error.general('tftp-reader: invalid open: no filename')
filename = kwargs['filename']
try:
self.filein = open(filename, 'rb')
filesize = os.stat(filename).st_size
except FileNotFoundError:
return self.E_FILE_NOT_FOUND, 'file not found (%s)' % (
filename)
except PermissionError:
return self.E_ACCESS_VIOLATION, 'access violation'
except IOError as ioe:
return self.E_NOT_DEFINED, str(ioe)
return self.E_NO_ERROR, str(filesize)
if command == 'read':
if self.filein is None:
raise error.general('tftp-reader: read when not open')
if 'blksize' not in kwargs:
raise error.general('tftp-reader: invalid read: no blksize')
# pylint: disable=bare-except
try:
return self.E_NO_ERROR, self.filein.read(kwargs['blksize'])
except IOError as ioe:
return self.E_NOT_DEFINED, str(ioe)
except:
return self.E_NOT_DEFINED, 'unknown error'
if command == 'close':
if self.filein is not None:
self.filein.close()
self.filein = None
return self.E_NO_ERROR, "closed"
return self.E_NOT_DEFINED, 'invalid reader state'
@staticmethod
def _pack_bytes(data=None):
bdata = bytearray()
if data is not None:
if not isinstance(data, list):
data = [data]
for item in data:
if isinstance(item, int):
bdata.append(item >> 8)
bdata.append(item & 0xff)
elif isinstance(item, str):
bdata.extend(item.encode())
bdata.append(0)
else:
bdata.extend(item)
return bdata
def _response(self, opcode, data):
code = self.opcodes.index(opcode)
if code == 0 or code >= len(self.opcodes):
raise error.general('invalid opcode: ' + opcode)
bdata = self._pack_bytes([code, data])
#print(''.join(format(x, '02x') for x in bdata))
return bytes(bdata)
def _error_response(self, code, message):
if log.tracing:
log.trace('tftp: error: %s:%d: %d: %s' %
(self.host, self.port, code, message))
self.finished = True
return self._response('ERROR', self._pack_bytes([code, message, 0]))
def _data_response(self, block, data):
if len(data) < self.block_size:
self.finished = True
return self._response('DATA', self._pack_bytes([block, data]))
def _oack_response(self, data):
self.resends += 1
if self.resends >= self.resends_limit:
return self._error_response(self.E_NOT_DEFINED,
'resend limit reached')
return self._response('OACK', self._pack_bytes(data))
def _next_block(self, block):
# has the current block been acknowledged?
if block == self.block:
self.resends = 0
self.block += 1
err, data = self.data_reader('read', blksize=self.block_size)
if err != self.E_NO_ERROR:
return self._error_response(err, data)
# close if the length of data is less than the block size
if len(data) < self.block_size:
self.data_reader('close')
else:
self.resends += 1
if self.resends >= self.resends_limit:
return self._error_response(self.E_NOT_DEFINED,
'resend limit reached')
return self._data_response(self.block, data)
def _read_req(self, data):
# if the last block is not 0 something has gone wrong and
# TID match. Restart the session. It could be the client
# is a simple implementation that does not move the send
# port on each retry.
if self.block != 0:
self.data_reader('close')
self._reinit()
# Get the filename, mode and options
self.filename = self.get_option('filename', data)
if self.filename is None:
return self._error_response(self.E_NOT_DEFINED,
'filename not found in request')
if self.forced_file is not None:
self.filename = self.forced_file
# open the reader
err, message = self.data_reader('open', filename=self.filename)
if err != self.E_NO_ERROR:
return self._error_response(err, message)
# the no error on open message is the file size
try:
tsize = int(message)
except ValueError:
tsize = 0
mode = self.get_option('mode', data)
if mode is None:
return self._error_response(self.E_NOT_DEFINED,
'mode not found in request')
oack_data = self._pack_bytes()
value = self.get_option('timeout', data)
if value is not None:
oack_data += self._pack_bytes(['timeout', value])
self.timeout = int(value)
value = self.get_option('blksize', data)
if value is not None:
oack_data += self._pack_bytes(['blksize', value])
self.block_size = int(value)
else:
self.block_size = 512
value = self.get_option('tsize', data)
if value is not None and tsize > 0:
oack_data += self._pack_bytes(['tsize', str(tsize)])
# Send the options ack
return self._oack_response(oack_data)
def _write_req(self):
# WRQ is not supported
return self._error_response(self.E_ILLEGAL_TFTP_OP,
"writes not supported")
def _op_ack(self, data):
# send the next block of data
block = (data[2] << 8) | data[3]
return self._next_block(block)
def process(self, host, port, data):
'''Process the incoming client data sending a response. If the session
has finished return None.
'''
if host != self.host and port != self.port:
return self._error_response(self.E_UKNOWN_TID,
'unkown transfer ID')
if self.finished:
return None
opcode = (data[0] << 8) | data[1]
if opcode == self.OP_RRQ:
return self._read_req(data)
if opcode in [self.OP_WRQ, self.OP_DATA]:
return self._write_req()
if opcode == self.OP_ACK:
return self._op_ack(data)
return self._error_response(self.E_ILLEGAL_TFTP_OP,
"unknown or unsupported opcode")
def decode(self, host, port, data):
'''Decode the packet for diagnostic purposes.
'''
# pylint: disable=too-many-branches
out = ''
dlen = len(data)
if dlen > 2:
opcode = (data[0] << 8) | data[1]
if 0 < opcode < len(self.opcodes):
if opcode in [self.OP_RRQ, self.OP_WRQ]:
out += ' ' + self.opcodes[opcode] + ', '
i = 2
while data[i] != 0:
out += chr(data[i])
i += 1
while i < dlen - 1:
out += ', '
i += 1
while data[i] != 0:
out += chr(data[i])
i += 1
elif opcode == self.OP_DATA:
block = (data[2] << 8) | data[3]
out += ' ' + self.opcodes[opcode] + ', '
out += '#' + str(block) + ', '
if dlen > 4:
out += '%02x%02x..%02x%02x' % (data[4], data[5],
data[-2], data[-1])
else:
out += '%02x%02x%02x%02x' % (data[4], data[5], data[6],
data[6])
out += ',' + str(dlen - 4)
elif opcode == self.OP_ACK:
block = (data[2] << 8) | data[3]
out += ' ' + self.opcodes[opcode] + ' ' + str(block)
elif opcode == self.OP_ERROR:
out += 'E ' + self.opcodes[opcode] + ', '
out += str((data[2] << 8) | (data[3]))
out += ': ' + str(data[4:].decode())
i = 2
while data[i] != 0:
out += chr(data[i])
i += 1
elif opcode == self.OP_OACK:
out += ' ' + self.opcodes[opcode]
i = 1
while i < dlen - 1:
out += ', '
i += 1
while data[i] != 0:
out += chr(data[i])
i += 1
else:
out += 'E INV(%d)' % (opcode)
else:
out += 'E INVALID LENGTH'
return out[:2] + '[%s:%d] (%d) ' % (host, port, len(data)) + out[2:]
@staticmethod
def get_option(option, data):
'''Get the option from the TFTP packet.'''
dlen = len(data) - 1
opcode = (data[0] << 8) | data[1]
next_option = False
if opcode in [1, 2]:
count = 0
i = 2
while i < dlen:
value = ''
while data[i] != 0:
value += chr(data[i])
i += 1
i += 1
if option == 'filename' and count == 0:
return value
if option == 'mode' and count == 1:
return value
if value == option and (count % 1) == 0:
next_option = True
elif next_option:
return value
count += 1
return None
def get_timeout(self, default_timeout, timeout_guard):
'''Get the timeout. The timeout can be an option.'''
if self.timeout == 0:
return self.timeout + timeout_guard
return default_timeout
def get_block_size(self):
'''Get the block size. The block size can be an option.'''
return self.block_size
class udp_handler(socketserver.BaseRequestHandler):
'''TFTP UDP handler for a TFTP session.'''
def _notice(self, text):
if self.server.tftp.notices:
log.notice(text)
else:
log.trace(text)
def handle_session(self, index):
'''Handle the TFTP session data.'''
# pylint: disable=too-many-locals
# pylint: disable=broad-except
# pylint: disable=too-many-branches
client_ip = self.client_address[0]
client_port = self.client_address[1]
client = '%s:%i' % (client_ip, client_port)
self._notice('] tftp: %d: start: %s' % (index, client))
try:
session = tftp_session(client_ip, client_port,
self.server.tftp.base,
self.server.tftp.forced_file,
self.server.tftp.reader)
response = session.process(client_ip, client_port, self.request[0])
if response is not None:
if log.tracing and self.server.tftp.packet_trace:
log.trace(' > ' + session.decode(client_ip, client_port,
self.request[0]))
timeout = session.get_timeout(self.server.tftp.timeout, 1)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('', 0))
sock.settimeout(timeout)
while response is not None:
if log.tracing and self.server.tftp.packet_trace:
log.trace(
' < ' +
session.decode(client_ip, client_port, response))
sock.sendto(response, (client_ip, client_port))
if session.finished:
break
try:
data, address = sock.recvfrom(2 + 2 +
session.get_block_size())
if log.tracing and self.server.tftp.packet_trace:
log.trace(
' > ' +
session.decode(address[0], address[1], data))
except socket.error as serr:
if log.tracing:
log.trace('] tftp: %d: receive: %s: error: %s' \
% (index, client, serr))
return
except socket.gaierror as serr:
if log.tracing:
log.trace('] tftp: %d: receive: %s: error: %s' \
% (index, client, serr))
return
response = session.process(address[0], address[1], data)
except error.general as gerr:
self._notice('] tftp: %dd: error: %s' % (index, gerr))
except error.internal as ierr:
self._notice('] tftp: %d: error: %s' % (index, ierr))
except error.exit:
pass
except KeyboardInterrupt:
pass
except Exception as exp:
self._notice('] tftp: %d: error: %s: %s' % (index, type(exp), exp))
self._notice('] tftp: %d: end: %s' % (index, client))
def handle(self):
'''The UDP server handle method.'''
if self.server.tftp.sessions is None \
or self.server.tftp.session < self.server.tftp.sessions:
self.handle_session(self.server.tftp.next_session())
class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer):
'''UDP server. Default behaviour.'''
class tftp_server(object):
'''TFTP server runs a UDP server to handle TFTP sessions.'''
# pylint: disable=useless-object-inheritance
# pylint: disable=too-many-instance-attributes
def __init__(self,
host,
port,
timeout=10,
base=None,
forced_file=None,
sessions=None,
reader=None):
# pylint: disable=too-many-arguments
self.lock = threading.Lock()
self.notices = False
self.packet_trace = False
self.timeout = timeout
self.host = host
self.port = port
self.server = None
self.server_thread = None
if base is None:
base = os.getcwd()
self.base = base
self.forced_file = forced_file
if sessions is not None and not isinstance(sessions, int):
raise error.general('tftp session count is not a number')
self.sessions = sessions
self.session = 0
self.reader = reader
def __del__(self):
self.stop()
def _lock(self):
self.lock.acquire()
def _unlock(self):
self.lock.release()
def start(self):
'''Start the TFTP server. Returns once started.'''
# pylint: disable=attribute-defined-outside-init
if log.tracing:
log.trace('] tftp: server: %s:%i' % (self.host, self.port))
if self.host == 'all':
host = ''
else:
host = self.host
try:
self.server = udp_server((host, self.port), udp_handler)
except Exception as exp:
raise error.general('tftp server create: %s' % (exp))
# We cannot set tftp in __init__ because the object is created
# in a separate package.
self.server.tftp = self
self.server_thread = threading.Thread(target=self.server.serve_forever)
self.server_thread.daemon = True
self.server_thread.start()
def stop(self):
'''Stop the TFTP server and close the server port.'''
self._lock()
try:
if self.server is not None:
self.server.shutdown()
self.server.server_close()
self.server = None
finally:
self._unlock()
def run(self):
'''Run the TFTP server for the specified number of sessions.'''
running = True
while running:
period = 1
self._lock()
if self.server is None:
running = False
period = 0
elif self.sessions is not None:
if self.sessions == 0:
running = False
period = 0
else:
period = 0.25
self._unlock()
if period > 0:
time.sleep(period)
self.stop()
def get_session(self):
'''Return the session count.'''
count = 0
self._lock()
try:
count = self.session
finally:
self._unlock()
return count
def next_session(self):
'''Return the next session number.'''
count = 0
self._lock()
try:
self.session += 1
count = self.session
finally:
self._unlock()
return count
def enable_notices(self):
'''Call to enable notices. The server is quiet without this call.'''
self._lock()
self.notices = True
self._unlock()
def trace_packets(self):
'''Call to enable packet tracing as a diagnostic.'''
self._lock()
self.packet_trace = True
self._unlock()
def load_log(logfile):
'''Set the log file.'''
if logfile is None:
log.default = log.log(streams=['stdout'])
else:
log.default = log.log(streams=[logfile])
def run(args=sys.argv, command_path=None):
'''Run a TFTP server session.'''
# pylint: disable=dangerous-default-value
# pylint: disable=unused-argument
# pylint: disable=too-many-statements
ecode = 0
notice = None
server = None
# pylint: disable=bare-except
try:
description = 'A TFTP Server that supports a read only TFTP session.'
nice_cwd = os.path.relpath(os.getcwd())
if len(nice_cwd) > len(os.path.abspath(nice_cwd)):
nice_cwd = os.path.abspath(nice_cwd)
argsp = argparse.ArgumentParser(prog='rtems-tftp-server',
description=description)
argsp.add_argument('-l',
'--log',
help='log file.',
type=str,
default=None)
argsp.add_argument('-v',
'--trace',
help='enable trace logging for debugging.',
action='store_true',
default=False)
argsp.add_argument('--trace-packets',
help='enable trace logging of packets.',
action='store_true',
default=False)
argsp.add_argument(
'-B',
'--bind',
help='address to bind the server too (default: %(default)s).',
type=str,
default='all')
argsp.add_argument(
'-P',
'--port',
help='port to bind the server too (default: %(default)s).',
type=int,
default='69')
argsp.add_argument('-t', '--timeout',
help = 'timeout in seconds, client can override ' \
'(default: %(default)s).',
type = int, default = '10')
argsp.add_argument(
'-b',
'--base',
help='base path, not checked (default: %(default)s).',
type=str,
default=nice_cwd)
argsp.add_argument(
'-F',
'--force-file',
help='force the file to be downloaded overriding the client.',
type=str,
default=None)
argsp.add_argument('-s', '--sessions',
help = 'number of TFTP sessions to run before exiting ' \
'(default: forever.',
type = int, default = None)
argopts = argsp.parse_args(args[1:])
load_log(argopts.log)
log.notice('RTEMS Tools - TFTP Server, %s' % (version.string()))
log.output(log.info(args))
log.tracing = argopts.trace
server = tftp_server(argopts.bind, argopts.port, argopts.timeout,
argopts.base, argopts.force_file,
argopts.sessions)
server.enable_notices()
try:
server.start()
server.run()
finally:
server.stop()
except error.general as gerr:
notice = str(gerr)
ecode = 1
except error.internal as ierr:
notice = str(ierr)
ecode = 1
except error.exit:
pass
except KeyboardInterrupt:
notice = 'abort: user terminated'
ecode = 1
except SystemExit:
pass
except:
notice = 'abort: unknown error'
ecode = 1
if server is not None:
del server
if notice is not None:
log.stderr(notice)
sys.exit(ecode)
if __name__ == "__main__":
run()

View File

@ -1,21 +0,0 @@
The MIT License
Copyright (c) 2009 Michael P. Soulier
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -1,115 +0,0 @@
Copyright, Michael P. Soulier, 2010.
About Release 0.6.2:
====================
Maintenance release to fix a couple of reported issues.
About Release 0.6.1:
====================
Maintenance release to fix several reported problems, including a rollover
at 2^16 blocks, and some contributed work on dynamic file objects.
About Release 0.6.0:
====================
Maintenance update to fix several reported issues, including proper
retransmits on timeouts, and further expansion of unit tests.
About Release 0.5.1:
====================
Maintenance update to fix a bug in the server, overhaul the documentation for
the website, fix a typo in the unit tests, fix a failure to set default
blocksize, and a divide by zero error in speed calculations for very short
transfers.
Also, this release adds support for input/output in client as stdin/stdout
About Release 0.5.0:
====================
Complete rewrite of the state machine.
Now fully implements downloading and uploading.
About Release 0.4.6:
====================
Feature release to add the tsize option.
Thanks to Kuba Kończyk for the patch.
About Release 0.4.5:
====================
Bugfix release for compatability issues on Win32, among other small issues.
About Release 0.4.4:
====================
Bugfix release for poor tolerance of unsupported options in the server.
About Release 0.4.3:
====================
Bugfix release for an issue with the server's detection of the end of the file
during a download.
About Release 0.4.2:
====================
Bugfix release for some small installation issues with earlier Python
releases.
About Release 0.4.1:
====================
Bugfix release to fix the installation path, with some restructuring into a
tftpy package from the single module used previously.
About Release 0.4:
==================
This release adds a TftpServer class with a sample implementation in bin.
The server uses a single thread with multiple handlers and a select() loop to
handle multiple clients simultaneously.
Only downloads are supported at this time.
About Release 0.3:
==================
This release fixes a major RFC 1350 compliance problem with the remote TID.
About Release 0.2:
==================
This release adds variable block sizes, and general option support,
implementing RFCs 2347 and 2348. This is accessible in the TftpClient class
via the options dict, or in the sample client via the --blocksize option.
About Release 0.1:
==================
This is an initial release in the spirit of "release early, release often".
Currently the sample client works, supporting RFC 1350. The server is not yet
implemented, and RFC 2347 and 2348 support (variable block sizes) is underway,
planned for 0.2.
About Tftpy:
============
Purpose:
--------
Tftpy is a TFTP library for the Python programming language. It includes
client and server classes, with sample implementations. Hooks are included for
easy inclusion in a UI for populating progress indicators. It supports RFCs
1350, 2347, 2348 and the tsize option from RFC 2349.
Dependencies:
-------------
Python 2.3+, hopefully. Let me know if it fails to work.
Trifles:
--------
Home page: http://tftpy.sf.net/
Project page: http://sourceforge.net/projects/tftpy/
License is the MIT License
See COPYING in this distribution.
Limitations:
------------
- Only 'octet' mode is supported.
- The only options supported are blksize and tsize.
Author:
=======
Michael P. Soulier <msoulier@digitaltorque.ca>

View File

@ -1,107 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements the TFTP Client functionality. Instantiate an
instance of the client, and then use its upload or download method. Logging is
performed via a standard logging object set in TftpShared."""
import types
import logging
from .TftpShared import *
from .TftpPacketTypes import *
from .TftpContexts import TftpContextClientDownload, TftpContextClientUpload
log = logging.getLogger('tftpy.TftpClient')
class TftpClient(TftpSession):
"""This class is an implementation of a tftp client. Once instantiated, a
download can be initiated via the download() method, or an upload via the
upload() method."""
def __init__(self, host, port=69, options={}, localip = ""):
TftpSession.__init__(self)
self.context = None
self.host = host
self.iport = port
self.filename = None
self.options = options
self.localip = localip
if 'blksize' in self.options:
size = self.options['blksize']
tftpassert(int == type(size), "blksize must be an int")
if size < MIN_BLKSIZE or size > MAX_BLKSIZE:
raise TftpException("Invalid blksize: %d" % size)
def download(self, filename, output, packethook=None, timeout=SOCK_TIMEOUT):
"""This method initiates a tftp download from the configured remote
host, requesting the filename passed. It writes the file to output,
which can be a file-like object or a path to a local file. If a
packethook is provided, it must be a function that takes a single
parameter, which will be a copy of each DAT packet received in the
form of a TftpPacketDAT object. The timeout parameter may be used to
override the default SOCK_TIMEOUT setting, which is the amount of time
that the client will wait for a receive packet to arrive.
Note: If output is a hyphen, stdout is used."""
# We're downloading.
log.debug("Creating download context with the following params:")
log.debug("host = %s, port = %s, filename = %s" % (self.host, self.iport, filename))
log.debug("options = %s, packethook = %s, timeout = %s" % (self.options, packethook, timeout))
self.context = TftpContextClientDownload(self.host,
self.iport,
filename,
output,
self.options,
packethook,
timeout,
localip = self.localip)
self.context.start()
# Download happens here
self.context.end()
metrics = self.context.metrics
log.info('')
log.info("Download complete.")
if metrics.duration == 0:
log.info("Duration too short, rate undetermined")
else:
log.info("Downloaded %.2f bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("%.2f bytes in resent data" % metrics.resent_bytes)
log.info("Received %d duplicate packets" % metrics.dupcount)
def upload(self, filename, input, packethook=None, timeout=SOCK_TIMEOUT):
"""This method initiates a tftp upload to the configured remote host,
uploading the filename passed. It reads the file from input, which
can be a file-like object or a path to a local file. If a packethook
is provided, it must be a function that takes a single parameter,
which will be a copy of each DAT packet sent in the form of a
TftpPacketDAT object. The timeout parameter may be used to override
the default SOCK_TIMEOUT setting, which is the amount of time that
the client will wait for a DAT packet to be ACKd by the server.
Note: If input is a hyphen, stdin is used."""
self.context = TftpContextClientUpload(self.host,
self.iport,
filename,
input,
self.options,
packethook,
timeout,
localip = self.localip)
self.context.start()
# Upload happens here
self.context.end()
metrics = self.context.metrics
log.info('')
log.info("Upload complete.")
if metrics.duration == 0:
log.info("Duration too short, rate undetermined")
else:
log.info("Uploaded %d bytes in %.2f seconds" % (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("%.2f bytes in resent data" % metrics.resent_bytes)
log.info("Resent %d packets" % metrics.dupcount)

View File

@ -1,429 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements all contexts for state handling during uploads and
downloads, the main interface to which being the TftpContext base class.
The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is returned instead."""
from .TftpShared import *
from .TftpPacketTypes import *
from .TftpPacketFactory import TftpPacketFactory
from .TftpStates import *
import socket
import time
import sys
import os
import logging
log = logging.getLogger('tftpy.TftpContext')
###############################################################################
# Utility classes
###############################################################################
class TftpMetrics(object):
"""A class representing metrics of the transfer."""
def __init__(self):
# Bytes transferred
self.bytes = 0
# Bytes re-sent
self.resent_bytes = 0
# Duplicate packets received
self.dups = {}
self.dupcount = 0
# Times
self.start_time = 0
self.end_time = 0
self.duration = 0
# Rates
self.bps = 0
self.kbps = 0
# Generic errors
self.errors = 0
def compute(self):
# Compute transfer time
self.duration = self.end_time - self.start_time
if self.duration == 0:
self.duration = 1
log.debug("TftpMetrics.compute: duration is %s", self.duration)
self.bps = (self.bytes * 8.0) / self.duration
self.kbps = self.bps / 1024.0
log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
for key in self.dups:
self.dupcount += self.dups[key]
def add_dup(self, pkt):
"""This method adds a dup for a packet to the metrics."""
log.debug("Recording a dup of %s", pkt)
s = str(pkt)
if s in self.dups:
self.dups[s] += 1
else:
self.dups[s] = 1
tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
###############################################################################
# Context classes
###############################################################################
class TftpContext(object):
"""The base class of the contexts."""
def __init__(self, host, port, timeout, localip = ""):
"""Constructor for the base context, setting shared instance
variables."""
self.file_to_transfer = None
self.fileobj = None
self.options = None
self.packethook = None
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
if localip != "":
self.sock.bind((localip, 0))
self.sock.settimeout(timeout)
self.timeout = timeout
self.state = None
self.next_block = 0
self.factory = TftpPacketFactory()
# Note, setting the host will also set self.address, as it's a property.
self.host = host
self.port = port
# The port associated with the TID
self.tidport = None
# Metrics
self.metrics = TftpMetrics()
# Fluag when the transfer is pending completion.
self.pending_complete = False
# Time when this context last received any traffic.
# FIXME: does this belong in metrics?
self.last_update = 0
# The last packet we sent, if applicable, to make resending easy.
self.last_pkt = None
# Count the number of retry attempts.
self.retry_count = 0
def getBlocksize(self):
"""Fetch the current blocksize for this session."""
return int(self.options.get('blksize', 512))
def __del__(self):
"""Simple destructor to try to call housekeeping in the end method if
not called explicitely. Leaking file descriptors is not a good
thing."""
self.end()
def checkTimeout(self, now):
"""Compare current time with last_update time, and raise an exception
if we're over the timeout time."""
log.debug("checking for timeout on session %s", self)
if now - self.last_update > self.timeout:
raise TftpTimeout("Timeout waiting for traffic")
def start(self):
raise NotImplementedError("Abstract method")
def end(self, close_fileobj=True):
"""Perform session cleanup, since the end method should always be
called explicitely by the calling code, this works better than the
destructor.
Set close_fileobj to False so fileobj can be returned open."""
log.debug("in TftpContext.end - closing socket")
self.sock.close()
if close_fileobj and self.fileobj is not None and not self.fileobj.closed:
log.debug("self.fileobj is open - closing")
self.fileobj.close()
def gethost(self):
"Simple getter method for use in a property."
return self.__host
def sethost(self, host):
"""Setter method that also sets the address property as a result
of the host that is set."""
self.__host = host
self.address = socket.gethostbyname(host)
host = property(gethost, sethost)
def setNextBlock(self, block):
if block >= 2 ** 16:
log.debug("Block number rollover to 0 again")
block = 0
self.__eblock = block
def getNextBlock(self):
return self.__eblock
next_block = property(getNextBlock, setNextBlock)
def cycle(self):
"""Here we wait for a response from the server after sending it
something, and dispatch appropriate action to that response."""
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout:
log.warning("Timeout waiting for traffic, retrying...")
raise TftpTimeout("Timed-out waiting for traffic")
# Ok, we've received a packet. Log it.
log.debug("Received %d bytes from %s:%s",
len(buffer), raddress, rport)
# And update our last updated time.
self.last_update = time.time()
# Decode it.
recvpkt = self.factory.parse(buffer)
# Check for known "connection".
if raddress != self.address:
log.warning("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.tidport and self.tidport != rport:
log.warning("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.tidport))
# If there is a packethook defined, call it. We unconditionally
# pass all packets, it's up to the client to screen out different
# kinds of packets. This way, the client is privy to things like
# negotiated options.
if self.packethook:
self.packethook(recvpkt)
# And handle it, possibly changing state.
self.state = self.state.handle(recvpkt, raddress, rport)
# If we didn't throw any exceptions here, reset the retry_count to
# zero.
self.retry_count = 0
class TftpContextServer(TftpContext):
"""The context for the server."""
def __init__(self,
host,
port,
timeout,
root,
dyn_file_func=None,
upload_open=None):
TftpContext.__init__(self,
host,
port,
timeout,
)
# At this point we have no idea if this is a download or an upload. We
# need to let the start state determine that.
self.state = TftpStateServerStart(self)
self.root = root
self.dyn_file_func = dyn_file_func
self.upload_open = upload_open
def __str__(self):
return "%s:%s %s" % (self.host, self.port, self.state)
def start(self, buffer):
"""Start the state cycle. Note that the server context receives an
initial packet in its start method. Also note that the server does not
loop on cycle(), as it expects the TftpServer object to manage
that."""
log.debug("In TftpContextServer.start")
self.metrics.start_time = time.time()
log.debug("Set metrics.start_time to %s", self.metrics.start_time)
# And update our last updated time.
self.last_update = time.time()
pkt = self.factory.parse(buffer)
log.debug("TftpContextServer.start() - factory returned a %s", pkt)
# Call handle once with the initial packet. This should put us into
# the download or the upload state.
self.state = self.state.handle(pkt,
self.host,
self.port)
def end(self):
"""Finish up the context."""
TftpContext.end(self)
self.metrics.end_time = time.time()
log.debug("Set metrics.end_time to %s", self.metrics.end_time)
self.metrics.compute()
class TftpContextClientUpload(TftpContext):
"""The upload context for the client during an upload.
Note: If input is a hyphen, then we will use stdin."""
def __init__(self,
host,
port,
filename,
input,
options,
packethook,
timeout,
localip = ""):
TftpContext.__init__(self,
host,
port,
timeout,
localip)
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
# If the input object has a read() function,
# assume it is file-like.
if hasattr(input, 'read'):
self.fileobj = input
elif input == '-':
self.fileobj = sys.stdin
else:
self.fileobj = open(input, "rb")
log.debug("TftpContextClientUpload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def __str__(self):
return "%s:%s %s" % (self.host, self.port, self.state)
def start(self):
log.info("Sending tftp upload request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
log.debug("Set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendWRQ method?
pkt = TftpPacketWRQ()
pkt.filename = self.file_to_transfer
pkt.mode = "octet" # FIXME - shouldn't hardcode this
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
self.last_pkt = pkt
# FIXME: should we centralize sendto operations so we can refactor all
# saving of the packet to the last_pkt field?
self.state = TftpStateSentWRQ(self)
while self.state:
try:
log.debug("State is %s" % self.state)
self.cycle()
except TftpTimeout as err:
log.error(str(err))
self.retry_count += 1
if self.retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries, giving up")
raise
else:
log.warning("resending last packet")
self.state.resendLast()
def end(self):
"""Finish up the context."""
TftpContext.end(self)
self.metrics.end_time = time.time()
log.debug("Set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()
class TftpContextClientDownload(TftpContext):
"""The download context for the client during a download.
Note: If output is a hyphen, then the output will be sent to stdout."""
def __init__(self,
host,
port,
filename,
output,
options,
packethook,
timeout,
localip = ""):
TftpContext.__init__(self,
host,
port,
timeout,
localip)
# FIXME: should we refactor setting of these params?
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
self.filelike_fileobj = False
# If the output object has a write() function,
# assume it is file-like.
if hasattr(output, 'write'):
self.fileobj = output
self.filelike_fileobj = True
# If the output filename is -, then use stdout
elif output == '-':
self.fileobj = sys.stdout
self.filelike_fileobj = True
else:
self.fileobj = open(output, "wb")
log.debug("TftpContextClientDownload.__init__()")
log.debug("file_to_transfer = %s, options = %s" %
(self.file_to_transfer, self.options))
def __str__(self):
return "%s:%s %s" % (self.host, self.port, self.state)
def start(self):
"""Initiate the download."""
log.info("Sending tftp download request to %s" % self.host)
log.info(" filename -> %s" % self.file_to_transfer)
log.info(" options -> %s" % self.options)
self.metrics.start_time = time.time()
log.debug("Set metrics.start_time to %s" % self.metrics.start_time)
# FIXME: put this in a sendRRQ method?
pkt = TftpPacketRRQ()
pkt.filename = self.file_to_transfer
pkt.mode = "octet" # FIXME - shouldn't hardcode this
pkt.options = self.options
self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
self.next_block = 1
self.last_pkt = pkt
self.state = TftpStateSentRRQ(self)
while self.state:
try:
log.debug("State is %s" % self.state)
self.cycle()
except TftpTimeout as err:
log.error(str(err))
self.retry_count += 1
if self.retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries, giving up")
raise
else:
log.warning("resending last packet")
self.state.resendLast()
except TftpFileNotFoundError as err:
# If we received file not found, then we should not save the open
# output file or we'll be left with a size zero file. Delete it,
# if it exists.
log.error("Received File not found error")
if self.fileobj is not None and not self.filelike_fileobj:
if os.path.exists(self.fileobj.name):
log.debug("unlinking output file of %s", self.fileobj.name)
os.unlink(self.fileobj.name)
raise
def end(self):
"""Finish up the context."""
TftpContext.end(self, not self.filelike_fileobj)
self.metrics.end_time = time.time()
log.debug("Set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()

View File

@ -1,47 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements the TftpPacketFactory class, which can take a binary
buffer, and return the appropriate TftpPacket object to represent it, via the
parse() method."""
from .TftpShared import *
from .TftpPacketTypes import *
import logging
log = logging.getLogger('tftpy.TftpPacketFactory')
class TftpPacketFactory(object):
"""This class generates TftpPacket objects. It is responsible for parsing
raw buffers off of the wire and returning objects representing them, via
the parse() method."""
def __init__(self):
self.classes = {
1: TftpPacketRRQ,
2: TftpPacketWRQ,
3: TftpPacketDAT,
4: TftpPacketACK,
5: TftpPacketERR,
6: TftpPacketOACK
}
def parse(self, buffer):
"""This method is used to parse an existing datagram into its
corresponding TftpPacket object. The buffer is the raw bytes off of
the network."""
log.debug("parsing a %d byte packet" % len(buffer))
(opcode,) = struct.unpack(str("!H"), buffer[:2])
log.debug("opcode is %d" % opcode)
packet = self.__create(opcode)
packet.buffer = buffer
return packet.decode()
def __create(self, opcode):
"""This method returns the appropriate class object corresponding to
the passed opcode."""
tftpassert(opcode in self.classes,
"Unsupported opcode: %d" % opcode)
packet = self.classes[opcode]()
return packet

View File

@ -1,494 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements the packet types of TFTP itself, and the
corresponding encode and decode methods for them."""
import struct
import sys
import logging
from .TftpShared import *
log = logging.getLogger('tftpy.TftpPacketTypes')
class TftpSession(object):
"""This class is the base class for the tftp client and server. Any shared
code should be in this class."""
# FIXME: do we need this anymore?
pass
class TftpPacketWithOptions(object):
"""This class exists to permit some TftpPacket subclasses to share code
regarding options handling. It does not inherit from TftpPacket, as the
goal is just to share code here, and not cause diamond inheritance."""
def __init__(self):
self.options = {}
# Always use unicode strings, except at the encode/decode barrier.
# Simpler to keep things clear.
def setoptions(self, options):
log.debug("in TftpPacketWithOptions.setoptions")
log.debug("options: %s", options)
myoptions = {}
for key in options:
newkey = key
if isinstance(key, bytes):
newkey = newkey.decode('ascii')
newval = options[key]
if isinstance(newval, bytes):
newval = newval.decode('ascii')
myoptions[newkey] = newval
log.debug("populated myoptions with %s = %s", newkey, myoptions[newkey])
log.debug("setting options hash to: %s", myoptions)
self._options = myoptions
def getoptions(self):
log.debug("in TftpPacketWithOptions.getoptions")
return self._options
# Set up getter and setter on options to ensure that they are the proper
# type. They should always be strings, but we don't need to force the
# client to necessarily enter strings if we can avoid it.
options = property(getoptions, setoptions)
def decode_options(self, buffer):
"""This method decodes the section of the buffer that contains an
unknown number of options. It returns a dictionary of option names and
values."""
fmt = b"!"
options = {}
log.debug("decode_options: buffer is: %s", repr(buffer))
log.debug("size of buffer is %d bytes", len(buffer))
if len(buffer) == 0:
log.debug("size of buffer is zero, returning empty hash")
return {}
# Count the nulls in the buffer. Each one terminates a string.
log.debug("about to iterate options buffer counting nulls")
length = 0
for i in range(len(buffer)):
if ord(buffer[i:i+1]) == 0:
log.debug("found a null at length %d", length)
if length > 0:
fmt += b"%dsx" % length
length = -1
else:
raise TftpException("Invalid options in buffer")
length += 1
log.debug("about to unpack, fmt is: %s", fmt)
mystruct = struct.unpack(fmt, buffer)
tftpassert(len(mystruct) % 2 == 0,
"packet with odd number of option/value pairs")
for i in range(0, len(mystruct), 2):
key = mystruct[i].decode('ascii')
val = mystruct[i+1].decode('ascii')
log.debug("setting option %s to %s", key, val)
log.debug("types are %s and %s", type(key), type(val))
options[key] = val
return options
class TftpPacket(object):
"""This class is the parent class of all tftp packet classes. It is an
abstract class, providing an interface, and should not be instantiated
directly."""
def __init__(self):
self.opcode = 0
self.buffer = None
def encode(self):
"""The encode method of a TftpPacket takes keyword arguments specific
to the type of packet, and packs an appropriate buffer in network-byte
order suitable for sending over the wire.
This is an abstract method."""
raise NotImplementedError("Abstract method")
def decode(self):
"""The decode method of a TftpPacket takes a buffer off of the wire in
network-byte order, and decodes it, populating internal properties as
appropriate. This can only be done once the first 2-byte opcode has
already been decoded, but the data section does include the entire
datagram.
This is an abstract method."""
raise NotImplementedError("Abstract method")
class TftpPacketInitial(TftpPacket, TftpPacketWithOptions):
"""This class is a common parent class for the RRQ and WRQ packets, as
they share quite a bit of code."""
def __init__(self):
TftpPacket.__init__(self)
TftpPacketWithOptions.__init__(self)
self.filename = None
self.mode = None
def encode(self):
"""Encode the packet's buffer from the instance variables."""
tftpassert(self.filename, "filename required in initial packet")
tftpassert(self.mode, "mode required in initial packet")
# Make sure filename and mode are bytestrings.
filename = self.filename
mode = self.mode
if not isinstance(filename, bytes):
filename = filename.encode('ascii')
if not isinstance(self.mode, bytes):
mode = mode.encode('ascii')
ptype = None
if self.opcode == 1: ptype = "RRQ"
else: ptype = "WRQ"
log.debug("Encoding %s packet, filename = %s, mode = %s",
ptype, filename, mode)
for key in self.options:
log.debug(" Option %s = %s", key, self.options[key])
fmt = b"!H"
fmt += b"%dsx" % len(filename)
if mode == b"octet":
fmt += b"5sx"
else:
raise AssertionError("Unsupported mode: %s" % mode)
# Add options. Note that the options list must be bytes.
options_list = []
if len(list(self.options.keys())) > 0:
log.debug("there are options to encode")
for key in self.options:
# Populate the option name
name = key
if not isinstance(name, bytes):
name = name.encode('ascii')
options_list.append(name)
fmt += b"%dsx" % len(name)
# Populate the option value
value = self.options[key]
# Work with all strings.
if isinstance(value, int):
value = str(value)
if not isinstance(value, bytes):
value = value.encode('ascii')
options_list.append(value)
fmt += b"%dsx" % len(value)
log.debug("fmt is %s", fmt)
log.debug("options_list is %s", options_list)
log.debug("size of struct is %d", struct.calcsize(fmt))
self.buffer = struct.pack(fmt,
self.opcode,
filename,
mode,
*options_list)
log.debug("buffer is %s", repr(self.buffer))
return self
def decode(self):
tftpassert(self.buffer, "Can't decode, buffer is empty")
# FIXME - this shares a lot of code with decode_options
nulls = 0
fmt = b""
nulls = length = tlength = 0
log.debug("in decode: about to iterate buffer counting nulls")
subbuf = self.buffer[2:]
for i in range(len(subbuf)):
if ord(subbuf[i:i+1]) == 0:
nulls += 1
log.debug("found a null at length %d, now have %d", length, nulls)
fmt += b"%dsx" % length
length = -1
# At 2 nulls, we want to mark that position for decoding.
if nulls == 2:
break
length += 1
tlength += 1
log.debug("hopefully found end of mode at length %d", tlength)
# length should now be the end of the mode.
tftpassert(nulls == 2, "malformed packet")
shortbuf = subbuf[:tlength+1]
log.debug("about to unpack buffer with fmt: %s", fmt)
log.debug("unpacking buffer: %s", repr(shortbuf))
mystruct = struct.unpack(fmt, shortbuf)
tftpassert(len(mystruct) == 2, "malformed packet")
self.filename = mystruct[0].decode('ascii')
self.mode = mystruct[1].decode('ascii').lower() # force lc - bug 17
log.debug("set filename to %s", self.filename)
log.debug("set mode to %s", self.mode)
self.options = self.decode_options(subbuf[tlength+1:])
log.debug("options dict is now %s", self.options)
return self
class TftpPacketRRQ(TftpPacketInitial):
"""
::
2 bytes string 1 byte string 1 byte
-----------------------------------------------
RRQ/ | 01/02 | Filename | 0 | Mode | 0 |
WRQ -----------------------------------------------
"""
def __init__(self):
TftpPacketInitial.__init__(self)
self.opcode = 1
def __str__(self):
s = 'RRQ packet: filename = %s' % self.filename
s += ' mode = %s' % self.mode
if self.options:
s += '\n options = %s' % self.options
return s
class TftpPacketWRQ(TftpPacketInitial):
"""
::
2 bytes string 1 byte string 1 byte
-----------------------------------------------
RRQ/ | 01/02 | Filename | 0 | Mode | 0 |
WRQ -----------------------------------------------
"""
def __init__(self):
TftpPacketInitial.__init__(self)
self.opcode = 2
def __str__(self):
s = 'WRQ packet: filename = %s' % self.filename
s += ' mode = %s' % self.mode
if self.options:
s += '\n options = %s' % self.options
return s
class TftpPacketDAT(TftpPacket):
"""
::
2 bytes 2 bytes n bytes
---------------------------------
DATA | 03 | Block # | Data |
---------------------------------
"""
def __init__(self):
TftpPacket.__init__(self)
self.opcode = 3
self.blocknumber = 0
self.data = None
def __str__(self):
s = 'DAT packet: block %s' % self.blocknumber
if self.data:
s += '\n data: %d bytes' % len(self.data)
return s
def encode(self):
"""Encode the DAT packet. This method populates self.buffer, and
returns self for easy method chaining."""
if len(self.data) == 0:
log.debug("Encoding an empty DAT packet")
data = self.data
if not isinstance(self.data, bytes):
data = self.data.encode('ascii')
fmt = b"!HH%ds" % len(data)
self.buffer = struct.pack(fmt,
self.opcode,
self.blocknumber,
data)
return self
def decode(self):
"""Decode self.buffer into instance variables. It returns self for
easy method chaining."""
# We know the first 2 bytes are the opcode. The second two are the
# block number.
(self.blocknumber,) = struct.unpack(str("!H"), self.buffer[2:4])
log.debug("decoding DAT packet, block number %d", self.blocknumber)
log.debug("should be %d bytes in the packet total", len(self.buffer))
# Everything else is data.
self.data = self.buffer[4:]
log.debug("found %d bytes of data", len(self.data))
return self
class TftpPacketACK(TftpPacket):
"""
::
2 bytes 2 bytes
-------------------
ACK | 04 | Block # |
--------------------
"""
def __init__(self):
TftpPacket.__init__(self)
self.opcode = 4
self.blocknumber = 0
def __str__(self):
return 'ACK packet: block %d' % self.blocknumber
def encode(self):
log.debug("encoding ACK: opcode = %d, block = %d",
self.opcode, self.blocknumber)
self.buffer = struct.pack(str("!HH"), self.opcode, self.blocknumber)
return self
def decode(self):
if len(self.buffer) > 4:
log.debug("detected TFTP ACK but request is too large, will truncate")
log.debug("buffer was: %s", repr(self.buffer))
self.buffer = self.buffer[0:4]
self.opcode, self.blocknumber = struct.unpack(str("!HH"), self.buffer)
log.debug("decoded ACK packet: opcode = %d, block = %d",
self.opcode, self.blocknumber)
return self
class TftpPacketERR(TftpPacket):
"""
::
2 bytes 2 bytes string 1 byte
----------------------------------------
ERROR | 05 | ErrorCode | ErrMsg | 0 |
----------------------------------------
Error Codes
Value Meaning
0 Not defined, see error message (if any).
1 File not found.
2 Access violation.
3 Disk full or allocation exceeded.
4 Illegal TFTP operation.
5 Unknown transfer ID.
6 File already exists.
7 No such user.
8 Failed to negotiate options
"""
def __init__(self):
TftpPacket.__init__(self)
self.opcode = 5
self.errorcode = 0
# FIXME: We don't encode the errmsg...
self.errmsg = None
# FIXME - integrate in TftpErrors references?
self.errmsgs = {
1: b"File not found",
2: b"Access violation",
3: b"Disk full or allocation exceeded",
4: b"Illegal TFTP operation",
5: b"Unknown transfer ID",
6: b"File already exists",
7: b"No such user",
8: b"Failed to negotiate options"
}
def __str__(self):
s = 'ERR packet: errorcode = %d' % self.errorcode
s += '\n msg = %s' % self.errmsgs.get(self.errorcode, '')
return s
def encode(self):
"""Encode the DAT packet based on instance variables, populating
self.buffer, returning self."""
fmt = b"!HH%dsx" % len(self.errmsgs[self.errorcode])
log.debug("encoding ERR packet with fmt %s", fmt)
self.buffer = struct.pack(fmt,
self.opcode,
self.errorcode,
self.errmsgs[self.errorcode])
return self
def decode(self):
"Decode self.buffer, populating instance variables and return self."
buflen = len(self.buffer)
tftpassert(buflen >= 4, "malformed ERR packet, too short")
log.debug("Decoding ERR packet, length %s bytes", buflen)
if buflen == 4:
log.debug("Allowing this affront to the RFC of a 4-byte packet")
fmt = b"!HH"
log.debug("Decoding ERR packet with fmt: %s", fmt)
self.opcode, self.errorcode = struct.unpack(fmt,
self.buffer)
else:
log.debug("Good ERR packet > 4 bytes")
fmt = b"!HH%dsx" % (len(self.buffer) - 5)
log.debug("Decoding ERR packet with fmt: %s", fmt)
self.opcode, self.errorcode, self.errmsg = struct.unpack(fmt,
self.buffer)
log.error("ERR packet - errorcode: %d, message: %s"
% (self.errorcode, self.errmsg))
return self
class TftpPacketOACK(TftpPacket, TftpPacketWithOptions):
"""
::
+-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
| opc | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 |
+-------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+
"""
def __init__(self):
TftpPacket.__init__(self)
TftpPacketWithOptions.__init__(self)
self.opcode = 6
def __str__(self):
return 'OACK packet:\n options = %s' % self.options
def encode(self):
fmt = b"!H" # opcode
options_list = []
log.debug("in TftpPacketOACK.encode")
for key in self.options:
value = self.options[key]
if isinstance(value, int):
value = str(value)
if not isinstance(key, bytes):
key = key.encode('ascii')
if not isinstance(value, bytes):
value = value.encode('ascii')
log.debug("looping on option key %s", key)
log.debug("value is %s", value)
fmt += b"%dsx" % len(key)
fmt += b"%dsx" % len(value)
options_list.append(key)
options_list.append(value)
self.buffer = struct.pack(fmt, self.opcode, *options_list)
return self
def decode(self):
self.options = self.decode_options(self.buffer[2:])
return self
def match_options(self, options):
"""This method takes a set of options, and tries to match them with
its own. It can accept some changes in those options from the server as
part of a negotiation. Changed or unchanged, it will return a dict of
the options so that the session can update itself to the negotiated
options."""
for name in self.options:
if name in options:
if name == 'blksize':
# We can accept anything between the min and max values.
size = int(self.options[name])
if size >= MIN_BLKSIZE and size <= MAX_BLKSIZE:
log.debug("negotiated blksize of %d bytes", size)
options['blksize'] = size
else:
raise TftpException("blksize %s option outside allowed range" % size)
elif name == 'tsize':
size = int(self.options[name])
if size < 0:
raise TftpException("Negative file sizes not supported")
else:
raise TftpException("Unsupported option: %s" % name)
return True

View File

@ -1,271 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements the TFTP Server functionality. Instantiate an
instance of the server, and then run the listen() method to listen for client
requests. Logging is performed via a standard logging object set in
TftpShared."""
import socket, os, time
import select
import threading
import logging
from errno import EINTR
from .TftpShared import *
from .TftpPacketTypes import *
from .TftpPacketFactory import TftpPacketFactory
from .TftpContexts import TftpContextServer
log = logging.getLogger('tftpy.TftpServer')
class TftpServer(TftpSession):
"""This class implements a tftp server object. Run the listen() method to
listen for client requests.
tftproot is the path to the tftproot directory to serve files from and/or
write them to.
dyn_file_func is a callable that takes a requested download
path that is not present on the file system and must return either a
file-like object to read from or None if the path should appear as not
found. This permits the serving of dynamic content.
upload_open is a callable that is triggered on every upload with the
requested destination path and server context. It must either return a
file-like object ready for writing or None if the path is invalid."""
def __init__(self,
tftproot='/tftpboot',
dyn_file_func=None,
upload_open=None):
self.listenip = None
self.listenport = None
self.sock = None
# FIXME: What about multiple roots?
self.root = os.path.abspath(tftproot)
self.dyn_file_func = dyn_file_func
self.upload_open = upload_open
# A dict of sessions, where each session is keyed by a string like
# ip:tid for the remote end.
self.sessions = {}
# A threading event to help threads synchronize with the server
# is_running state.
self.is_running = threading.Event()
self.shutdown_gracefully = False
self.shutdown_immediately = False
for name in 'dyn_file_func', 'upload_open':
attr = getattr(self, name)
if attr and not callable(attr):
raise TftpException("{} supplied, but it is not callable.".format(name))
if os.path.exists(self.root):
log.debug("tftproot %s does exist", self.root)
if not os.path.isdir(self.root):
raise TftpException("The tftproot must be a directory.")
else:
log.debug("tftproot %s is a directory" % self.root)
if os.access(self.root, os.R_OK):
log.debug("tftproot %s is readable" % self.root)
else:
raise TftpException("The tftproot must be readable")
if os.access(self.root, os.W_OK):
log.debug("tftproot %s is writable" % self.root)
else:
log.warning("The tftproot %s is not writable" % self.root)
else:
raise TftpException("The tftproot does not exist.")
def __del__(self):
if self.sock is not None:
try:
self.sock.close()
except:
pass
def listen(self, listenip="", listenport=DEF_TFTP_PORT,
timeout=SOCK_TIMEOUT):
"""Start a server listening on the supplied interface and port. This
defaults to INADDR_ANY (all interfaces) and UDP port 69. You can also
supply a different socket timeout value, if desired."""
tftp_factory = TftpPacketFactory()
# Don't use new 2.5 ternary operator yet
# listenip = listenip if listenip else '0.0.0.0'
if not listenip: listenip = '0.0.0.0'
log.info("Server requested on ip %s, port %s" % (listenip, listenport))
try:
# FIXME - sockets should be non-blocking
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.bind((listenip, listenport))
_, self.listenport = self.sock.getsockname()
except socket.error as err:
# Reraise it for now.
raise err
self.is_running.set()
log.info("Starting receive loop...")
while True:
log.debug("shutdown_immediately is %s" % self.shutdown_immediately)
log.debug("shutdown_gracefully is %s" % self.shutdown_gracefully)
if self.shutdown_immediately:
log.warning("Shutting down now. Session count: %d" %
len(self.sessions))
self.sock.close()
for key in self.sessions:
self.sessions[key].end()
self.sessions = []
break
elif self.shutdown_gracefully:
if not self.sessions:
log.warning("In graceful shutdown mode and all "
"sessions complete.")
self.sock.close()
break
# Build the inputlist array of sockets to select() on.
inputlist = []
inputlist.append(self.sock)
for key in self.sessions:
inputlist.append(self.sessions[key].sock)
# Block until some socket has input on it.
log.debug("Performing select on this inputlist: %s", inputlist)
try:
readyinput, readyoutput, readyspecial = \
select.select(inputlist, [], [], SOCK_TIMEOUT)
except select.error as err:
if err[0] == EINTR:
# Interrupted system call
log.debug("Interrupted syscall, retrying")
continue
else:
raise
deletion_list = []
# Handle the available data, if any. Maybe we timed-out.
for readysock in readyinput:
# Is the traffic on the main server socket? ie. new session?
if readysock == self.sock:
log.debug("Data ready on our main socket")
buffer, (raddress, rport) = self.sock.recvfrom(MAX_BLKSIZE)
log.debug("Read %d bytes", len(buffer))
if self.shutdown_gracefully:
log.warning("Discarding data on main port, "
"in graceful shutdown mode")
continue
# Forge a session key based on the client's IP and port,
# which should safely work through NAT.
key = "%s:%s" % (raddress, rport)
if not key in self.sessions:
log.debug("Creating new server context for "
"session key = %s" % key)
try:
self.sessions[key] = TftpContextServer(raddress,
rport,
timeout,
self.root,
self.dyn_file_func,
self.upload_open)
self.sessions[key].start(buffer)
except TftpException as err:
deletion_list.append(key)
log.error("Fatal exception thrown from "
"session %s: %s" % (key, str(err)))
except KeyboardInterrupt:
pass
except:
deletion_list.append(key)
log.error("Fatal exception thrown from "
"session %s: %s" % (key, str(err)))
else:
log.warning("received traffic on main socket for "
"existing session??")
log.info("Currently handling these sessions:")
for session_key, session in list(self.sessions.items()):
log.info(" %s" % session)
else:
# Must find the owner of this traffic.
for key in self.sessions:
if readysock == self.sessions[key].sock:
log.debug("Matched input to session key %s"
% key)
try:
self.sessions[key].cycle()
if self.sessions[key].state == None:
log.info("Successful transfer.")
deletion_list.append(key)
except TftpException as err:
deletion_list.append(key)
log.error("Fatal exception thrown from "
"session %s: %s"
% (key, str(err)))
# Break out of for loop since we found the correct
# session.
break
else:
log.error("Can't find the owner for this packet. "
"Discarding.")
log.debug("Looping on all sessions to check for timeouts")
now = time.time()
for key in self.sessions:
try:
self.sessions[key].checkTimeout(now)
except TftpTimeout as err:
log.error(str(err))
self.sessions[key].retry_count += 1
if self.sessions[key].retry_count >= TIMEOUT_RETRIES:
log.debug("hit max retries on %s, giving up" %
self.sessions[key])
deletion_list.append(key)
else:
log.debug("resending on session %s" % self.sessions[key])
self.sessions[key].state.resendLast()
log.debug("Iterating deletion list.")
for key in deletion_list:
log.info('')
log.info("Session %s complete" % key)
if key in self.sessions:
log.debug("Gathering up metrics from session before deleting")
self.sessions[key].end()
metrics = self.sessions[key].metrics
if metrics.duration == 0:
log.info("Duration too short, rate undetermined")
else:
log.info("Transferred %d bytes in %.2f seconds"
% (metrics.bytes, metrics.duration))
log.info("Average rate: %.2f kbps" % metrics.kbps)
log.info("%.2f bytes in resent data" % metrics.resent_bytes)
log.info("%d duplicate packets" % metrics.dupcount)
log.debug("Deleting session %s" % key)
del self.sessions[key]
log.debug("Session list is now %s" % self.sessions)
else:
log.warning(
"Strange, session %s is not on the deletion list" % key)
self.is_running.clear()
log.debug("server returning from while loop")
self.shutdown_gracefully = self.shutdown_immediately = False
def stop(self, now=False):
"""Stop the server gracefully. Do not take any new transfers,
but complete the existing ones. If force is True, drop everything
and stop. Note, immediately will not interrupt the select loop, it
will happen when the server returns on ready data, or a timeout.
ie. SOCK_TIMEOUT"""
if now:
self.shutdown_immediately = True
else:
self.shutdown_gracefully = True

View File

@ -1,52 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module holds all objects shared by all other modules in tftpy."""
MIN_BLKSIZE = 8
DEF_BLKSIZE = 512
MAX_BLKSIZE = 65536
SOCK_TIMEOUT = 5
MAX_DUPS = 20
TIMEOUT_RETRIES = 5
DEF_TFTP_PORT = 69
# A hook for deliberately introducing delay in testing.
DELAY_BLOCK = 0
def tftpassert(condition, msg):
"""This function is a simple utility that will check the condition
passed for a false state. If it finds one, it throws a TftpException
with the message passed. This just makes the code throughout cleaner
by refactoring."""
if not condition:
raise TftpException(msg)
class TftpErrors(object):
"""This class is a convenience for defining the common tftp error codes,
and making them more readable in the code."""
NotDefined = 0
FileNotFound = 1
AccessViolation = 2
DiskFull = 3
IllegalTftpOp = 4
UnknownTID = 5
FileAlreadyExists = 6
NoSuchUser = 7
FailedNegotiation = 8
class TftpException(Exception):
"""This class is the parent class of all exceptions regarding the handling
of the TFTP protocol."""
pass
class TftpTimeout(TftpException):
"""This class represents a timeout error waiting for a response from the
other end."""
pass
class TftpFileNotFoundError(TftpException):
"""This class represents an error condition where we received a file
not found error."""
pass

View File

@ -1,611 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""This module implements all state handling during uploads and downloads, the
main interface to which being the TftpState base class.
The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is returned instead."""
from .TftpShared import *
from .TftpPacketTypes import *
import os
import logging
log = logging.getLogger('tftpy.TftpStates')
###############################################################################
# State classes
###############################################################################
class TftpState(object):
"""The base class for the states."""
def __init__(self, context):
"""Constructor for setting up common instance variables. The involved
file object is required, since in tftp there's always a file
involved."""
self.context = context
def handle(self, pkt, raddress, rport):
"""An abstract method for handling a packet. It is expected to return
a TftpState object, either itself or a new state."""
raise NotImplementedError("Abstract method")
def handleOACK(self, pkt):
"""This method handles an OACK from the server, syncing any accepted
options."""
if len(pkt.options.keys()) > 0:
if pkt.match_options(self.context.options):
log.info("Successful negotiation of options")
# Set options to OACK options
self.context.options = pkt.options
for key in self.context.options:
log.info(" %s = %s" % (key, self.context.options[key]))
else:
log.error("Failed to negotiate options")
raise TftpException("Failed to negotiate options")
else:
raise TftpException("No options found in OACK")
def returnSupportedOptions(self, options):
"""This method takes a requested options list from a client, and
returns the ones that are supported."""
# We support the options blksize and tsize right now.
# FIXME - put this somewhere else?
accepted_options = {}
for option in options:
if option == 'blksize':
# Make sure it's valid.
if int(options[option]) > MAX_BLKSIZE:
log.info("Client requested blksize greater than %d "
"setting to maximum" % MAX_BLKSIZE)
accepted_options[option] = MAX_BLKSIZE
elif int(options[option]) < MIN_BLKSIZE:
log.info("Client requested blksize less than %d "
"setting to minimum" % MIN_BLKSIZE)
accepted_options[option] = MIN_BLKSIZE
else:
accepted_options[option] = options[option]
elif option == 'tsize':
log.debug("tsize option is set")
accepted_options['tsize'] = 0
else:
log.info("Dropping unsupported option '%s'" % option)
log.debug("Returning these accepted options: %s", accepted_options)
return accepted_options
def sendDAT(self):
"""This method sends the next DAT packet based on the data in the
context. It returns a boolean indicating whether the transfer is
finished."""
finished = False
blocknumber = self.context.next_block
# Test hook
if DELAY_BLOCK and DELAY_BLOCK == blocknumber:
import time
log.debug("Deliberately delaying 10 seconds...")
time.sleep(10)
dat = None
blksize = self.context.getBlocksize()
buffer = self.context.fileobj.read(blksize)
log.debug("Read %d bytes into buffer", len(buffer))
if len(buffer) < blksize:
log.info("Reached EOF on file %s"
% self.context.file_to_transfer)
finished = True
dat = TftpPacketDAT()
dat.data = buffer
dat.blocknumber = blocknumber
self.context.metrics.bytes += len(dat.data)
log.debug("Sending DAT packet %d", dat.blocknumber)
self.context.sock.sendto(dat.encode().buffer,
(self.context.host, self.context.tidport))
if self.context.packethook:
self.context.packethook(dat)
self.context.last_pkt = dat
return finished
def sendACK(self, blocknumber=None):
"""This method sends an ack packet to the block number specified. If
none is specified, it defaults to the next_block property in the
parent context."""
log.debug("In sendACK, passed blocknumber is %s", blocknumber)
if blocknumber is None:
blocknumber = self.context.next_block
log.info("Sending ack to block %d" % blocknumber)
ackpkt = TftpPacketACK()
ackpkt.blocknumber = blocknumber
self.context.sock.sendto(ackpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.context.last_pkt = ackpkt
def sendError(self, errorcode):
"""This method uses the socket passed, and uses the errorcode to
compose and send an error packet."""
log.debug("In sendError, being asked to send error %d", errorcode)
errpkt = TftpPacketERR()
errpkt.errorcode = errorcode
if self.context.tidport == None:
log.debug("Error packet received outside session. Discarding")
else:
self.context.sock.sendto(errpkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.context.last_pkt = errpkt
def sendOACK(self):
"""This method sends an OACK packet with the options from the current
context."""
log.debug("In sendOACK with options %s", self.context.options)
pkt = TftpPacketOACK()
pkt.options = self.context.options
self.context.sock.sendto(pkt.encode().buffer,
(self.context.host,
self.context.tidport))
self.context.last_pkt = pkt
def resendLast(self):
"Resend the last sent packet due to a timeout."
log.warning("Resending packet %s on sessions %s"
% (self.context.last_pkt, self))
self.context.metrics.resent_bytes += len(self.context.last_pkt.buffer)
self.context.metrics.add_dup(self.context.last_pkt)
sendto_port = self.context.tidport
if not sendto_port:
# If the tidport wasn't set, then the remote end hasn't even
# started talking to us yet. That's not good. Maybe it's not
# there.
sendto_port = self.context.port
self.context.sock.sendto(self.context.last_pkt.encode().buffer,
(self.context.host, sendto_port))
if self.context.packethook:
self.context.packethook(self.context.last_pkt)
def handleDat(self, pkt):
"""This method handles a DAT packet during a client download, or a
server upload."""
log.info("Handling DAT packet - block %d" % pkt.blocknumber)
log.debug("Expecting block %s", self.context.next_block)
if pkt.blocknumber == self.context.next_block:
log.debug("Good, received block %d in sequence", pkt.blocknumber)
self.sendACK()
self.context.next_block += 1
log.debug("Writing %d bytes to output file", len(pkt.data))
self.context.fileobj.write(pkt.data)
self.context.metrics.bytes += len(pkt.data)
# Check for end-of-file, any less than full data packet.
if len(pkt.data) < self.context.getBlocksize():
log.info("End of file detected")
return None
elif pkt.blocknumber < self.context.next_block:
if pkt.blocknumber == 0:
log.warning("There is no block zero!")
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("There is no block zero!")
log.warning("Dropping duplicate block %d" % pkt.blocknumber)
self.context.metrics.add_dup(pkt)
log.debug("ACKing block %d again, just in case", pkt.blocknumber)
self.sendACK(pkt.blocknumber)
else:
# FIXME: should we be more tolerant and just discard instead?
msg = "Whoa! Received future block %d but expected %d" \
% (pkt.blocknumber, self.context.next_block)
log.error(msg)
raise TftpException(msg)
# Default is to ack
return TftpStateExpectDAT(self.context)
class TftpServerState(TftpState):
"""The base class for server states."""
def __init__(self, context):
TftpState.__init__(self, context)
# This variable is used to store the absolute path to the file being
# managed.
self.full_path = None
def serverInitial(self, pkt, raddress, rport):
"""This method performs initial setup for a server context transfer,
put here to refactor code out of the TftpStateServerRecvRRQ and
TftpStateServerRecvWRQ classes, since their initial setup is
identical. The method returns a boolean, sendoack, to indicate whether
it is required to send an OACK to the client."""
options = pkt.options
sendoack = False
if not self.context.tidport:
self.context.tidport = rport
log.info("Setting tidport to %s" % rport)
log.debug("Setting default options, blksize")
self.context.options = { 'blksize': DEF_BLKSIZE }
if options:
log.debug("Options requested: %s", options)
supported_options = self.returnSupportedOptions(options)
self.context.options.update(supported_options)
sendoack = True
# FIXME - only octet mode is supported at this time.
if pkt.mode != 'octet':
#self.sendError(TftpErrors.IllegalTftpOp)
#raise TftpException("Only octet transfers are supported at this time.")
log.warning("Received non-octet mode request. I'll reply with binary data.")
# test host/port of client end
if self.context.host != raddress or self.context.port != rport:
self.sendError(TftpErrors.UnknownTID)
log.error("Expected traffic from %s:%s but received it "
"from %s:%s instead."
% (self.context.host,
self.context.port,
raddress,
rport))
# FIXME: increment an error count?
# Return same state, we're still waiting for valid traffic.
return self
log.debug("Requested filename is %s", pkt.filename)
# Build the filename on this server and ensure it is contained
# in the specified root directory.
#
# Filenames that begin with server root are accepted. It's
# assumed the client and server are tightly connected and this
# provides backwards compatibility.
#
# Filenames otherwise are relative to the server root. If they
# begin with a '/' strip it off as otherwise os.path.join will
# treat it as absolute (regardless of whether it is ntpath or
# posixpath module
if pkt.filename.startswith(self.context.root):
full_path = pkt.filename
else:
full_path = os.path.join(self.context.root, pkt.filename.lstrip('/'))
# Use abspath to eliminate any remaining relative elements
# (e.g. '..') and ensure that is still within the server's
# root directory
self.full_path = os.path.abspath(full_path)
log.debug("full_path is %s", full_path)
if self.full_path.startswith(self.context.root):
log.info("requested file is in the server root - good")
else:
log.warning("requested file is not within the server root - bad")
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("bad file path")
self.context.file_to_transfer = pkt.filename
return sendoack
class TftpStateServerRecvRRQ(TftpServerState):
"""This class represents the state of the TFTP server when it has just
received an RRQ packet."""
def handle(self, pkt, raddress, rport):
"Handle an initial RRQ packet as a server."
log.debug("In TftpStateServerRecvRRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.full_path
log.info("Opening file %s for reading" % path)
if os.path.exists(path):
# Note: Open in binary mode for win32 portability, since win32
# blows.
self.context.fileobj = open(path, "rb")
elif self.context.dyn_file_func:
log.debug("No such file %s but using dyn_file_func", path)
self.context.fileobj = \
self.context.dyn_file_func(self.context.file_to_transfer, raddress=raddress, rport=rport)
if self.context.fileobj is None:
log.debug("dyn_file_func returned 'None', treating as "
"FileNotFound")
self.sendError(TftpErrors.FileNotFound)
raise TftpException("File not found: %s" % path)
else:
log.warn("File not found: %s", path)
self.sendError(TftpErrors.FileNotFound)
raise TftpException("File not found: {}".format(path))
# Options negotiation.
if sendoack and 'tsize' in self.context.options:
# getting the file size for the tsize option. As we handle
# file-like objects and not only real files, we use this seeking
# method instead of asking the OS
self.context.fileobj.seek(0, os.SEEK_END)
tsize = str(self.context.fileobj.tell())
self.context.fileobj.seek(0, 0)
self.context.options['tsize'] = tsize
if sendoack:
# Note, next_block is 0 here since that's the proper
# acknowledgement to an OACK.
# FIXME: perhaps we do need a TftpStateExpectOACK class...
self.sendOACK()
# Note, self.context.next_block is already 0.
else:
self.context.next_block = 1
log.debug("No requested options, starting send...")
self.context.pending_complete = self.sendDAT()
# Note, we expect an ack regardless of whether we sent a DAT or an
# OACK.
return TftpStateExpectACK(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerRecvWRQ(TftpServerState):
"""This class represents the state of the TFTP server when it has just
received a WRQ packet."""
def make_subdirs(self):
"""The purpose of this method is to, if necessary, create all of the
subdirectories leading up to the file to the written."""
# Pull off everything below the root.
subpath = self.full_path[len(self.context.root):]
log.debug("make_subdirs: subpath is %s", subpath)
# Split on directory separators, but drop the last one, as it should
# be the filename.
dirs = subpath.split(os.sep)[:-1]
log.debug("dirs is %s", dirs)
current = self.context.root
for dir in dirs:
if dir:
current = os.path.join(current, dir)
if os.path.isdir(current):
log.debug("%s is already an existing directory", current)
else:
os.mkdir(current, 0o700)
def handle(self, pkt, raddress, rport):
"Handle an initial WRQ packet as a server."
log.debug("In TftpStateServerRecvWRQ.handle")
sendoack = self.serverInitial(pkt, raddress, rport)
path = self.full_path
if self.context.upload_open:
f = self.context.upload_open(path, self.context)
if f is None:
self.sendError(TftpErrors.AccessViolation)
raise TftpException("Dynamic path %s not permitted" % path)
else:
self.context.fileobj = f
else:
log.info("Opening file %s for writing" % path)
if os.path.exists(path):
# FIXME: correct behavior?
log.warning("File %s exists already, overwriting..." % (
self.context.file_to_transfer))
# FIXME: I think we should upload to a temp file and not overwrite
# the existing file until the file is successfully uploaded.
self.make_subdirs()
self.context.fileobj = open(path, "wb")
# Options negotiation.
if sendoack:
log.debug("Sending OACK to client")
self.sendOACK()
else:
log.debug("No requested options, expecting transfer to begin...")
self.sendACK()
# Whether we're sending an oack or not, we're expecting a DAT for
# block 1
self.context.next_block = 1
# We may have sent an OACK, but we're expecting a DAT as the response
# to either the OACK or an ACK, so lets unconditionally use the
# TftpStateExpectDAT state.
return TftpStateExpectDAT(self.context)
# Note, we don't have to check any other states in this method, that's
# up to the caller.
class TftpStateServerStart(TftpState):
"""The start state for the server. This is a transitory state since at
this point we don't know if we're handling an upload or a download. We
will commit to one of them once we interpret the initial packet."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
log.debug("In TftpStateServerStart.handle")
if isinstance(pkt, TftpPacketRRQ):
log.debug("Handling an RRQ packet")
return TftpStateServerRecvRRQ(self.context).handle(pkt,
raddress,
rport)
elif isinstance(pkt, TftpPacketWRQ):
log.debug("Handling a WRQ packet")
return TftpStateServerRecvWRQ(self.context).handle(pkt,
raddress,
rport)
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Invalid packet to begin up/download: %s" % pkt)
class TftpStateExpectACK(TftpState):
"""This class represents the state of the transfer when a DAT was just
sent, and we are waiting for an ACK from the server. This class is the
same one used by the client during the upload, and the server during the
download."""
def handle(self, pkt, raddress, rport):
"Handle a packet, hopefully an ACK since we just sent a DAT."
if isinstance(pkt, TftpPacketACK):
log.debug("Received ACK for packet %d" % pkt.blocknumber)
# Is this an ack to the one we just sent?
if self.context.next_block == pkt.blocknumber:
if self.context.pending_complete:
log.info("Received ACK to final DAT, we're done.")
return None
else:
log.debug("Good ACK, sending next DAT")
self.context.next_block += 1
log.debug("Incremented next_block to %d",
self.context.next_block)
self.context.pending_complete = self.sendDAT()
elif pkt.blocknumber < self.context.next_block:
log.warning("Received duplicate ACK for block %d"
% pkt.blocknumber)
self.context.metrics.add_dup(pkt)
else:
log.warning("Oooh, time warp. Received ACK to packet we "
"didn't send yet. Discarding.")
self.context.metrics.errors += 1
return self
elif isinstance(pkt, TftpPacketERR):
log.error("Received ERR packet from peer: %s" % str(pkt))
raise TftpException("Received ERR packet from peer: %s" % str(pkt))
else:
log.warning("Discarding unsupported packet: %s" % str(pkt))
return self
class TftpStateExpectDAT(TftpState):
"""Just sent an ACK packet. Waiting for DAT."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an ACK, which should be a DAT."""
if isinstance(pkt, TftpPacketDAT):
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(pkt, TftpPacketACK):
# Umm, we ACK, you don't.
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received ACK from peer when expecting DAT")
elif isinstance(pkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received WRQ from peer when expecting DAT")
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received ERR from peer: " + str(pkt))
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received unknown packet type from peer: " + str(pkt))
class TftpStateSentWRQ(TftpState):
"""Just sent an WRQ packet for an upload."""
def handle(self, pkt, raddress, rport):
"""Handle a packet we just received."""
if not self.context.tidport:
self.context.tidport = rport
log.debug("Set remote port for session to %s", rport)
# If we're going to successfully transfer the file, then we should see
# either an OACK for accepted options, or an ACK to ignore options.
if isinstance(pkt, TftpPacketOACK):
log.info("Received OACK from server")
try:
self.handleOACK(pkt)
except TftpException:
log.error("Failed to negotiate options")
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("Sending first DAT packet")
self.context.pending_complete = self.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
elif isinstance(pkt, TftpPacketACK):
log.info("Received ACK from server")
log.debug("Apparently the server ignored our options")
# The block number should be zero.
if pkt.blocknumber == 0:
log.debug("Ack blocknumber is zero as expected")
log.debug("Sending first DAT packet")
self.context.pending_complete = self.sendDAT()
log.debug("Changing state to TftpStateExpectACK")
return TftpStateExpectACK(self.context)
else:
log.warning("Discarding ACK to block %s" % pkt.blocknumber)
log.debug("Still waiting for valid response from server")
return self
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received ERR from server: %s" % pkt)
elif isinstance(pkt, TftpPacketRRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received RRQ from server while in upload")
elif isinstance(pkt, TftpPacketDAT):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received DAT from server while in upload")
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received unknown packet type from server: %s" % pkt)
# By default, no state change.
return self
class TftpStateSentRRQ(TftpState):
"""Just sent an RRQ packet."""
def handle(self, pkt, raddress, rport):
"""Handle the packet in response to an RRQ to the server."""
if not self.context.tidport:
self.context.tidport = rport
log.info("Set remote port for session to %s" % rport)
# Now check the packet type and dispatch it properly.
if isinstance(pkt, TftpPacketOACK):
log.info("Received OACK from server")
try:
self.handleOACK(pkt)
except TftpException as err:
log.error("Failed to negotiate options: %s" % str(err))
self.sendError(TftpErrors.FailedNegotiation)
raise
else:
log.debug("Sending ACK to OACK")
self.sendACK(blocknumber=0)
log.debug("Changing state to TftpStateExpectDAT")
return TftpStateExpectDAT(self.context)
elif isinstance(pkt, TftpPacketDAT):
# If there are any options set, then the server didn't honour any
# of them.
log.info("Received DAT from server")
if self.context.options:
log.info("Server ignored options, falling back to defaults")
self.context.options = { 'blksize': DEF_BLKSIZE }
return self.handleDat(pkt)
# Every other packet type is a problem.
elif isinstance(pkt, TftpPacketACK):
# Umm, we ACK, the server doesn't.
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received ACK from server while in download")
elif isinstance(pkt, TftpPacketWRQ):
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received WRQ from server while in download")
elif isinstance(pkt, TftpPacketERR):
self.sendError(TftpErrors.IllegalTftpOp)
log.debug("Received ERR packet: %s", pkt)
if pkt.errorcode == TftpErrors.FileNotFound:
raise TftpFileNotFoundError("File not found")
else:
raise TftpException("Received ERR from server: {}".format(pkt))
else:
self.sendError(TftpErrors.IllegalTftpOp)
raise TftpException("Received unknown packet type from server: %s" % pkt)
# By default, no state change.
return self

View File

@ -1,27 +0,0 @@
# vim: ts=4 sw=4 et ai:
# -*- coding: utf8 -*-
"""
This library implements the tftp protocol, based on rfc 1350.
http://www.faqs.org/rfcs/rfc1350.html
At the moment it implements only a client class, but will include a server,
with support for variable block sizes.
As a client of tftpy, this is the only module that you should need to import
directly. The TftpClient and TftpServer classes can be reached through it.
"""
import sys
# Make sure that this is at least Python 2.7
required_version = (2, 7)
if sys.version_info < required_version:
raise ImportError("Requires at least Python 2.7")
from .TftpShared import *
from . import TftpPacketTypes
from . import TftpPacketFactory
from .TftpClient import TftpClient
from .TftpServer import TftpServer
from . import TftpContexts
from . import TftpStates

45
tester/rtems-tftp-server Executable file
View File

@ -0,0 +1,45 @@
#! /usr/bin/env python
# SPDX-License-Identifier: BSD-2-Clause
'''A command line standalone TFTP Server. This is useful when testing
and setting up a TFTP target.'''
# Copyright (C) 2020 Chris Johns (chrisj@rtems.org)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# pylint: disable=invalid-name
from __future__ import print_function
import os
import sys
base = os.path.dirname(os.path.abspath(sys.argv[0]))
rtems = os.path.dirname(base)
sys.path = [rtems] + sys.path
try:
import rt.tftpserver
rt.tftpserver.run(sys.argv)
except ImportError:
print("Incorrect RTEMS Tools installation", file=sys.stderr)
sys.exit(1)