package main

import (
	"flag"
	"fmt"
	"net"
	"strconv"
	"strings"

	log "github.com/sirupsen/logrus"
)

func startManifold(listenSocket *net.UDPConn, sendSockets []*net.UDPConn) error {
	var buffer [10000]byte
	for {
		length, _, err := listenSocket.ReadFromUDP(buffer[:])
		if err != nil {
			return err
		}
		for _, conn := range sendSockets {
			go conn.Write(buffer[:length])
		}
	}
}

func openReceiveSocket(listenPort int, multicastGroup string) (*net.UDPConn, error) {
	var err error
	var listenSocket *net.UDPConn
	var resolvedMulticastGroup *net.IPAddr
	if len(multicastGroup) > 0 {
		resolvedMulticastGroup, err = net.ResolveIPAddr("ip4", multicastGroup)
		if err != nil {
			return nil, err
		}
		listenSocket, err = net.ListenMulticastUDP("udp4", nil, &net.UDPAddr{
			IP: resolvedMulticastGroup.IP, Zone: resolvedMulticastGroup.Zone, Port: listenPort})
	} else {
		listenSocket, err = net.ListenUDP("udp4", &net.UDPAddr{
			IP: resolvedMulticastGroup.IP, Zone: resolvedMulticastGroup.Zone, Port: listenPort})
	}
	if err != nil {
		return nil, err
	}
	return listenSocket, nil
}

func printError(err error) {
	log.WithFields(log.Fields{
		"error": err.Error(),
	}).Error("Unhandled error")
}

func createSendConnections(sendAddressString string) ([]*net.UDPConn, error) {
	var connections []*net.UDPConn

	for _, rawString := range strings.Split(sendAddressString, ",") {
		parts := strings.Split(rawString, ":")
		if len(parts) != 2 {
			return nil, fmt.Errorf("%[1]s is an invalid address", rawString)
		}
		ip, err := net.ResolveIPAddr("ip4", parts[0])
		if err != nil {
			return nil, err
		}
		port, err := strconv.Atoi(parts[1])
		if err != nil {
			return nil, err
		}
		conn, err := net.DialUDP("udp", nil, &net.UDPAddr{IP: ip.IP, Port: port})
		if err != nil {
			return nil, err
		}
		connections = append(connections, conn)
	}

	return connections, nil
}

func main() {

	listenPort := flag.Int("listenPort", 4000, "Udp port to listen on")
	multicastAddress := flag.String("multicastGroup", "230.0.0.1", "Optional mutlicast group to subscribe to")
	sendAddresses := flag.String("sendAddresses", "localhost:4001", "Udp addresses to send out on (comma separated")

	flag.Parse()
	receiveSocket, err := openReceiveSocket(*listenPort, *multicastAddress)
	if err != nil {
		printError(err)
		return
	}
	sendSockets, err := createSendConnections(*sendAddresses)
	if err != nil {
		printError(err)
		return
	}
	err = startManifold(receiveSocket, sendSockets)
	if err != nil {
		printError(err)
		return
	}
}