#!/bin/python3
import sys,re
from pprint import pprint
sys.path.insert(0, '../../')
from fred import list2int

# Which half of the puzzle to run (1 or 2), and the input file it reads
# (opened relative to the current working directory).
part = 2
input_f = 'input'
#########################################
#                                       #
#              Part 1                   #
#                                       #
#########################################

def visit(next, visited, records):
    """Collect every node reachable from ``next`` by depth-first search.

    Args:
        next: Node to start from. It is only recorded in ``visited`` if the
            search reaches it again through an edge (a self-loop, or a
            neighbour that lists it back).
        visited: List of nodes already seen; newly reached nodes are
            appended to it in place, in depth-first discovery order.
        records: Mapping from each node to the list of its neighbours.

    Returns:
        The same ``visited`` list object, after extension.

    Raises:
        KeyError: If ``next`` or a reached node has no entry in ``records``.
    """
    # Iterative DFS with an explicit stack of neighbour iterators. Each
    # iterator remembers where the scan of that node's neighbours stopped,
    # so discovery order matches a recursive walk, but long chains no
    # longer hit Python's recursion limit.
    seen = set(visited)  # O(1) membership instead of scanning the list
    stack = [iter(records[next])]
    while stack:
        for node in stack[-1]:
            if node not in seen:
                seen.add(node)
                visited.append(node)
                stack.append(iter(records[node]))
                break  # descend into the newly found node first
        else:
            # Every neighbour of the node on top of the stack is done.
            stack.pop()
    return visited

if part == 1:
    # Build the adjacency map: a line such as "2 <-> 0, 3, 4" becomes
    # records[2] = list2int(['0', '3', '4']).
    records = {}
    with open(input_f) as file:
        for line in file:
            compact = line.rstrip().replace(' ', '')
            if '<->' not in compact:
                continue  # skip lines without an edge list
            fields = compact.split('<->')
            records[int(fields[0])] = list2int(fields[1].split(','))

    # Walk the group that program 0 belongs to and report it.
    start_node = 0
    visited = visit(start_node, [], records)

    print(visited)
    print(len(visited))


#########################################
#                                       #
#              Part 2                   #
#                                       #
#########################################
def visit(next, visited, records):
    """Collect every node reachable from ``next`` by depth-first search.

    Args:
        next: Node to start from. It is only recorded in ``visited`` if the
            search reaches it again through an edge (a self-loop, or a
            neighbour that lists it back).
        visited: List of nodes already seen; newly reached nodes are
            appended to it in place, in depth-first discovery order.
        records: Mapping from each node to the list of its neighbours.

    Returns:
        The same ``visited`` list object, after extension.

    Raises:
        KeyError: If ``next`` or a reached node has no entry in ``records``.
    """
    # Iterative DFS with an explicit stack of neighbour iterators. Each
    # iterator remembers where the scan of that node's neighbours stopped,
    # so discovery order matches a recursive walk, but long chains no
    # longer hit Python's recursion limit.
    seen = set(visited)  # O(1) membership instead of scanning the list
    stack = [iter(records[next])]
    while stack:
        for node in stack[-1]:
            if node not in seen:
                seen.add(node)
                visited.append(node)
                stack.append(iter(records[node]))
                break  # descend into the newly found node first
        else:
            # Every neighbour of the node on top of the stack is done.
            stack.pop()
    return visited

if part == 2:

    # Build the adjacency map: a line such as "2 <-> 0, 3, 4" becomes
    # records[2] = list2int(['0', '3', '4']).
    records = {}

    with open(input_f) as file:
        for line in file:
            l = line.rstrip().replace(' ', '')
            if '<->' in l:
                x = l.split('<->')
                records[int(x[0])] = list2int(x[1].split(','))

    # Split the programs into connected groups. Each program that is not
    # yet in any group seeds a new one via visit().
    groups = []
    # Every program already placed in some group; a set makes the
    # "already grouped?" test O(1) instead of rescanning every group.
    grouped = set()

    # sorted(records) equals range(len(records)) when the IDs are exactly
    # 0..n-1, and still covers every program if they are not.
    for i in sorted(records):
        if i not in grouped:
            group = visit(i, [], records)
            groups.append(group)
            grouped.update(group)

    pprint(groups)
    print(len(groups))