mirror of
https://git.rtems.org/rtems-tools/
synced 2025-05-15 19:17:49 +08:00

- Uses a config INI file to map clients to servers. - Handles a number of requests to a single server's TFTP port (69) and multiplexes them to non-superuser ports or to different servers. - Supports running rtems-test on more than one hardware device using TFTP at once.
424 lines
15 KiB
Python
424 lines
15 KiB
Python
#
|
|
# Copyright 2019 Chris Johns (chris@contemporary.software)
|
|
# All rights reserved.
|
|
#
|
|
# Permission to use, copy, modify, and/or distribute this software for any
|
|
# purpose with or without fee is hereby granted, provided that the above
|
|
# copyright notice and this permission notice appear in all copies.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
|
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
|
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
|
|
#
|
|
# The TFTP proxy redirects a TFTP session to another host. If you have a
|
|
# farm of boards you can configure them to point to this proxy and it will
|
|
# redirect the requests to another machine that is testing it.
|
|
#
|
|
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import os
|
|
import socket
|
|
import sys
|
|
import time
|
|
import threading
|
|
|
|
try:
|
|
import socketserver
|
|
except:
|
|
import SocketServer as socketserver
|
|
|
|
from rtemstoolkit import configuration
|
|
from rtemstoolkit import error
|
|
from rtemstoolkit import log
|
|
from rtemstoolkit import version
|
|
|
|
import getmac
|
|
|
|
def host_port_split(ip_port):
    """Split a ``host[:port]`` string into its host and port parts.

    :param ip_port: the address string, for example ``10.0.0.2:69`` or a
                    bare host such as ``10.0.0.2``.
    :returns: a ``(host, port)`` tuple; ``port`` is ``0`` when the string
              has no port.
    :raises error.general: if the host is empty, there is more than one
                           ``:`` separator, or the port is not an integer
                           in the range 0..65535.
    """
    parts = ip_port.split(':')
    # str.split() always returns at least one element, so validate the upper
    # bound and the content rather than a minimum length.
    if len(parts) > 2 or len(parts[0]) == 0:
        raise error.general('invalid host:port: %s' % (ip_port))
    ip = parts[0]
    port = 0
    if len(parts) == 2:
        try:
            port = int(parts[1])
        except ValueError:
            raise error.general('invalid host:port: %s' % (ip_port))
        if port < 0 or port > 65535:
            raise error.general('invalid host:port: %s' % (ip_port))
    return ip, port
|
|
|
|
class tftp_session(object):
    """Track the state of one TFTP transfer as its packets pass the proxy.

    Every packet seen is recorded so the session can be printed. Request,
    option acknowledgement (OACK), DATA, ACK and ERROR packets are followed
    so the proxy can tell when the transfer has finished. Packet layouts
    follow RFC 1350 with the option extension of RFC 2347/2348.
    """

    # Opcode names indexed by the 16-bit TFTP opcode.
    opcodes = ['nul', 'RRQ', 'WRQ', 'DATA', 'ACK', 'ERROR', 'OACK']

    # Block size used until a different size is acknowledged (RFC 1350).
    default_block_size = 512

    def __init__(self):
        # (host, port, data) tuples in the order the packets were seen.
        self.packets = []
        # Block number of the last DATA or ACK packet.
        self.block = 0
        # Block size in effect for the transfer.
        self.block_size = self.default_block_size
        # Timeout in seconds requested by the client; 0 when none was asked.
        self.timeout = 0
        # Set once the final (short) DATA block has been seen.
        self.finished = True

    def __str__(self):
        return os.linesep.join([self.decode(p[0], p[1], p[2]) for p in self.packets])

    @staticmethod
    def _word(data, offset):
        """Return the big-endian 16-bit value at ``offset`` or ``None`` if
        the packet is too short to hold it."""
        if len(data) < offset + 2:
            return None
        # bytearray gives integer items for both bytes and Python 2 str.
        pair = bytearray(data[offset:offset + 2])
        return (pair[0] << 8) | pair[1]

    @staticmethod
    def _strings(data, offset):
        """Return the NUL terminated strings in ``data`` from ``offset``.

        A trailing string without its terminating NUL is returned as is, so
        a truncated packet is decoded rather than raising an exception.
        """
        fields = bytes(bytearray(data[offset:])).split(b'\0')
        if fields[-1] == b'':
            fields.pop()
        return [f.decode('latin-1') for f in fields]

    def _options(self, data):
        """Return the (name, value) option pairs of a request or OACK."""
        opcode = self._word(data, 0)
        if opcode == 1 or opcode == 2:
            # Skip the file name and transfer mode that precede the options.
            fields = self._strings(data, 2)[2:]
        elif opcode == 6:
            fields = self._strings(data, 2)
        else:
            return []
        return list(zip(fields[0::2], fields[1::2]))

    def _int_option(self, option, data):
        """Return ``option`` as a positive integer, or ``None`` if it is
        missing or malformed."""
        value = self.get_option(option, data)
        if value is None:
            return None
        try:
            number = int(value)
        except ValueError:
            return None
        if number <= 0:
            return None
        return number

    def data(self, host, port, data):
        """Record a packet and update the transfer state.

        :param host: the packet's source address.
        :param port: the packet's source port.
        :param data: the raw UDP payload.
        :returns: ``True`` when the session is over, which is either the ACK
                  that follows the final DATA block or an ERROR packet.
        """
        self.packets += [(host, port, data)]
        opcode = self._word(data, 0)
        if opcode == 1 or opcode == 2:
            self.block = 0
            self.finished = False
            value = self._int_option('timeout', data)
            if value is not None:
                self.timeout = value
            # A requested block size only applies once the server accepts it
            # in an OACK, and the server may accept a smaller value (RFC
            # 2348), so start with the default size.
            self.block_size = self.default_block_size
        elif opcode == 6:
            # No blksize in the OACK means the server declined the option.
            value = self._int_option('blksize', data)
            if value is not None:
                self.block_size = value
            else:
                self.block_size = self.default_block_size
        elif opcode == 3:
            block = self._word(data, 2)
            if block is not None:
                self.block = block
                # A DATA packet shorter than the block size is the last one.
                if len(data) - 4 < self.block_size:
                    self.finished = True
        elif opcode == 4:
            block = self._word(data, 2)
            if block is not None:
                self.block = block
            if self.finished:
                return True
        elif opcode == 5:
            # Any error terminates the transfer (RFC 1350).
            self.finished = True
            return True
        return False

    def decode(self, host, port, data):
        """Return a one-line, human-readable description of a packet.

        The text starts with a two-character status field: ``E `` for
        errors and two spaces otherwise. ``[host:port]`` and the decoded
        packet fields follow it.
        """
        opcode = self._word(data, 0)
        if opcode is None or len(data) <= 2:
            s = 'E INVALID LENGTH'
        elif opcode == 1 or opcode == 2 or opcode == 6:
            s = '  ' + ', '.join([self.opcodes[opcode]] + self._strings(data, 2))
        elif opcode == 3 or opcode == 4 or opcode == 5:
            number = self._word(data, 2)
            if number is None:
                s = 'E INVALID LENGTH'
            elif opcode == 3:
                payload = bytearray(data[4:])
                if len(payload) > 4:
                    sample = '%02x%02x..%02x%02x' % (payload[0], payload[1],
                                                     payload[-2], payload[-1])
                else:
                    sample = ''.join(['%02x' % (b) for b in payload])
                s = '  %s, #%d, %s,%d' % (self.opcodes[opcode], number,
                                         sample, len(payload))
            elif opcode == 4:
                s = '  %s %d' % (self.opcodes[opcode], number)
            else:
                # ERROR: error code then a NUL terminated message.
                message = ', '.join(self._strings(data, 4))
                s = 'E %s, %d, %s' % (self.opcodes[opcode], number, message)
        else:
            s = 'E INV(%d)' % (opcode)
        return s[:2] + '[%s:%d] ' % (host, port) + s[2:]

    def get_option(self, option, data):
        """Return the value of ``option`` in a request or OACK packet.

        Option names are matched case-insensitively, as RFC 2347 requires.
        The file name and mode of a request are never treated as options.

        :returns: the value string, or ``None`` if the packet is not a
                  request or OACK or does not carry the option.
        """
        option = option.lower()
        for name, value in self._options(data):
            if name.lower() == option:
                return value
        return None

    def get_timeout(self, default_timeout, timeout_guard):
        """Return the socket timeout, in seconds, for the session.

        :param default_timeout: used when the client did not request one.
        :param timeout_guard: added to a client-requested timeout,
                              presumably so the proxy outlasts the end
                              points' own retransmit timeout.
        """
        if self.timeout != 0:
            return self.timeout + timeout_guard
        return default_timeout

    def get_block_size(self):
        """Return the block size in effect for the transfer."""
        return self.block_size
|
|
|
|
class udp_handler(socketserver.BaseRequestHandler):
    """Relay one TFTP session between a client and its configured host.

    The server's ``proxy`` attribute maps the client to a ``host[:port]``
    target. The initial request is forwarded from a new UDP socket, and
    later packets are relayed between the two ends until the session
    finishes, a socket error occurs, or the socket times out.
    """

    def handle(self):
        """Proxy the session started by the request datagram."""
        client_ip = self.client_address[0]
        client_port = self.client_address[1]
        client = '%s:%i' % (client_ip, client_port)
        request = self.request[0]
        session = tftp_session()
        if session.data(client_ip, client_port, request):
            return
        proxy = self.server.proxy
        timeout = session.get_timeout(proxy.session_timeout, 1)
        host = proxy.get_host(client_ip)
        if host is None:
            mac = getmac.get_mac_address(ip = client_ip)
            log.trace(' . request: host not found: %s (%s)' % (client, mac))
            return
        session_count = proxy.get_session_count()
        log.notice(' ] %6d: session: %s -> %s: start' % (session_count,
                                                        client,
                                                        host))
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            completed = self._relay(sock, session, timeout, client, host, request)
        finally:
            # One socket is created per session; close it so a long-running
            # proxy does not leak file descriptors.
            sock.close()
        if completed:
            log.notice(' ] %6d: session: %s -> %s: end' % (session_count,
                                                          client,
                                                          host))

    def _relay(self, sock, session, timeout, client, host, request):
        """Forward packets between the client and host until the session ends.

        :returns: ``True`` if the session finished, or ``False`` if it was
                  abandoned because of a socket error or timeout (which is
                  logged).
        """
        client_ip = self.client_address[0]
        client_port = self.client_address[1]
        host_name, host_server_port = host_port_split(host)
        # The server answers from a new port (its transfer ID), which
        # replaces the configured port once the first reply arrives.
        host_port = host_server_port
        try:
            # Replies carry an IP address, so resolve a configured host name
            # once for the address comparisons below.
            host_ip = socket.gethostbyname(host_name)
            sock.settimeout(timeout)
            log.trace(' > ' + session.decode(client_ip,
                                             client_port,
                                             request))
            sock.sendto(request, (host_ip, host_port))
        except socket.error as se:
            log.notice(' ] session: %s -> %s: error: %s' % (client, host, se))
            return False
        finished = False
        while not finished:
            try:
                data, address = sock.recvfrom(16 * 1024)
            except socket.error as se:
                # socket.timeout and socket.gaierror are subclasses of
                # socket.error, so timeouts are reported here as well.
                log.notice(' ] session: %s -> %s: error: %s' % (client, host, se))
                return False
            finished = session.data(address[0], address[1], data)
            if address[0] == host_ip:
                if host_port == host_server_port:
                    host_port = address[1]
                if address[1] == host_port:
                    log.trace(' < ' + session.decode(address[0],
                                                     address[1],
                                                     data))
                    sock.sendto(data, (client_ip, client_port))
            elif address[0] == client_ip and address[1] == client_port:
                log.trace(' > ' + session.decode(address[0],
                                                 address[1],
                                                 data))
                sock.sendto(data, (host_ip, host_port))
        return True
|
|
|
|
class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer):
    """Threaded UDP server; each request is handled in its own thread.

    Handler threads are daemon threads, so stopping the proxy or exiting the
    process does not wait for sessions still blocked on a socket timeout.
    With non-daemon threads, Python 3.7+ ``server_close()`` joins every
    handler thread before returning.
    """
    daemon_threads = True
|
|
|
|
class proxy_server(object):
    """TFTP proxy that listens for requests and relays them to hosts.

    The configuration's ``[default]`` section lists client sections in
    ``clients``. A client section may list further client sections in
    ``clients`` and may map its ``ip`` or ``mac`` addresses (not both) to a
    ``host`` target.
    """

    def __init__(self, config, host, port):
        # Guards the client map and the session counter, which are used by
        # the handler threads.
        self.lock = threading.Lock()
        # Socket timeout in seconds for sessions that do not request one.
        self.session_timeout = 10
        # Address ('all' binds every interface) and port to listen on.
        self.host = host
        self.port = port
        self.server = None
        self.server_thread = None
        # Client ip or mac address -> host target string.
        self.clients = { }
        self.config = configuration.configuration()
        self.session_counter = 0
        self._load(config)

    def __del__(self):
        self.stop()

    def _lock(self):
        self.lock.acquire()

    def _unlock(self):
        self.lock.release()

    def _load_client(self, client, depth = 0):
        """Load a client section and, recursively, the sections it lists.

        :raises error.general: on deep (probably circular) nesting, a
                               missing section, a section with both ``ip``
                               and ``mac``, a ``host`` without either, or
                               an address mapped to two different hosts.
        """
        if depth > 32:
            raise error.general('\'clients\' nesting too deep; circular?')
        if not self.config.has_section(client):
            raise error.general('client not found: %s' % (client))
        for c in self.config.comma_list(client, 'clients', err = False):
            self._load_client(c, depth + 1)
        host = self.config.get_item(client, 'host', err = False)
        if host is None:
            return
        ips = self.config.comma_list(client, 'ip', err = False)
        # Lower case the MACs so they compare equal to the addresses getmac
        # returns (assumed lower case colon separated; see getmac docs).
        macs = [mac.lower() for mac in self.config.comma_list(client, 'mac', err = False)]
        if len(ips) != 0 and len(macs) != 0:
            raise error.general('client has ip and mac: %s' % (client))
        if len(ips) != 0:
            keys = ips
        elif len(macs) != 0:
            keys = macs
        else:
            raise error.general('not client ip or mac: %s' % (client))
        for key in keys:
            # The same address mapped to two different hosts is ambiguous.
            if key in self.clients and self.clients[key] != host:
                raise error.general('repeated client: %s (%s)' % (client, key))
            self.clients[key] = host

    def _load(self, config):
        """Load the configuration file and every client it lists."""
        self.config.load(config)
        clients = self.config.comma_list('default', 'clients', err = False)
        if len(clients) == 0:
            raise error.general('\'clients\' entry not found in config [default]')
        for client in clients:
            self._load_client(client)

    def start(self):
        """Bind the UDP server and serve requests in a daemon thread.

        :raises error.general: if the server socket cannot be created.
        """
        log.notice('Proxy: %s:%i' % (self.host, self.port))
        if self.host == 'all':
            host = ''
        else:
            host = self.host
        try:
            server = udp_server((host, self.port), udp_handler)
        except Exception as e:
            raise error.general('proxy create: %s' % (e))
        server.proxy = self
        thread = threading.Thread(target = server.serve_forever)
        thread.daemon = True
        try:
            thread.start()
        except Exception:
            server.server_close()
            raise
        with self.lock:
            self.server = server
            self.server_thread = thread

    def stop(self):
        """Stop serving and close the server socket; safe to call twice."""
        with self.lock:
            server = self.server
            self.server = None
        if server is not None:
            # Shut down outside the lock: server_close() may wait for handler
            # threads, and those threads take the lock in get_host() and
            # get_session_count().
            server.shutdown()
            server.server_close()

    def run(self):
        """Block forever while the server thread handles requests."""
        while True:
            time.sleep(1)

    def get_host(self, client):
        """Return the host target for a client IP address, or ``None``.

        The IP is looked up first, then the client's MAC address.
        """
        with self.lock:
            host = self.clients.get(client)
        if host is None:
            # Resolve the MAC outside the lock; the lookup may be slow and
            # would otherwise serialise every new session.
            mac = getmac.get_mac_address(ip = client)
            if mac is not None:
                with self.lock:
                    host = self.clients.get(mac.lower())
        return host

    def get_session_count(self):
        """Increment and return the session counter."""
        with self.lock:
            self.session_counter += 1
            return self.session_counter
|
|
|
|
|
|
def load_log(logfile):
    """Install the default log, writing to ``logfile`` or to stdout when
    ``logfile`` is None."""
    streams = ['stdout'] if logfile is None else [logfile]
    log.default = log.log(streams = streams)
|
|
|
|
def run(args = sys.argv, command_path = None):
    """Command line entry point for the TFTP proxy.

    Parses ``args``, loads the configuration and runs the proxy until an
    error or an interrupt, then exits the process with ``sys.exit``.

    :param args: the command line, with the program name first.
    :param command_path: not used.
    """
    ec = 0
    notice = None
    proxy = None
    try:
        description = 'Proxy TFTP sessions from the host running this proxy '
        description += 'to hosts and ports defined in the configuration file. '
        description += 'The tool lets you create a farm of hardware and to run '
        description += 'more than one TFTP test session on a host or multiple '
        description += 'hosts at once. This proxy service is not considered secure '
        description += 'and is for use in a secure environment.'

        argsp = argparse.ArgumentParser(prog = 'rtems-tftp-proxy',
                                        description = description)
        argsp.add_argument('-l', '--log',
                           help = 'log file.',
                           type = str, default = None)
        argsp.add_argument('-v', '--trace',
                           help = 'enable trace logging for debugging.',
                           action = 'store_true', default = False)
        argsp.add_argument('-c', '--config',
                           help = 'proxy configuration (default: %(default)s).',
                           type = str, default = None)
        argsp.add_argument('-B', '--bind',
                           help = 'address to bind the proxy to (default: %(default)s).',
                           type = str, default = 'all')
        argsp.add_argument('-P', '--port',
                           help = 'port to bind the proxy to (default: %(default)s).',
                           type = int, default = 69)

        argopts = argsp.parse_args(args[1:])

        load_log(argopts.log)
        log.notice('RTEMS Tools - TFTP Proxy, %s' % (version.string()))
        log.output(log.info(args))
        log.tracing = argopts.trace

        if argopts.config is None:
            raise error.general('no config file, see -h')

        proxy = proxy_server(argopts.config, argopts.bind, argopts.port)

        try:
            proxy.start()
            proxy.run()
        except BaseException:
            # run() only returns by raising; stop the server before
            # propagating so the socket and thread are released.
            proxy.stop()
            raise

    except error.general as gerr:
        notice = str(gerr)
        ec = 1
    except error.internal as ierr:
        notice = str(ierr)
        ec = 1
    except error.exit:
        pass
    except KeyboardInterrupt:
        notice = 'abort: user terminated'
        ec = 1
    # Any other exception propagates with its traceback.
    if proxy is not None:
        proxy.stop()
    if notice is not None:
        log.stderr(notice)
    sys.exit(ec)
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: parse sys.argv and run the proxy.
    run()
|