diff --git a/secdep.py b/secdep.py index 48f10a6..31cabfc 100755 --- a/secdep.py +++ b/secdep.py @@ -680,7 +680,7 @@ def choose_from_list(listFromlistFunction,listName): elif listName == "azureSize": printFormat = "{}) {}\n\nRam: {}\nDisk: {}\nPrice: {}\n" printstring = "print(printFormat.format(count, item.name, item.ram, item.disk, item.price))" - elif listName == "awsImage" or listName == "azureImage" or listName == "awsRegion": + elif listName == "awsImage" or listName == "azureImage" or listName == "awsRegion" or listName == "aws_region": printFormat = "{}) {}" printstring = "print(printFormat.format(count, item))" elif listName == "gceImage": @@ -1227,8 +1227,8 @@ def list_all_nodes(provider, filterIn=None, awsRegion=None): elif provider == "aws": if SECDEP_AWS_ACCESS_KEY != "": print("Getting AWS nodes...") + awsLocations = ["ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-north-1", "eu-west-1", "eu-west-2", "eu-west-3", "sa-east-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"] if awsRegion is None: - awsLocations = ["ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-north-1", "eu-west-1", "eu-west-2", "eu-west-3", "sa-east-1", "us-east-1", "us-east-2", "us-west-1", "us-west-2"] for region in awsLocations: driver3 = get_driver(Provider.EC2)(SECDEP_AWS_ACCESS_KEY, SECDEP_AWS_SECRET_KEY, region=region) # make it so it tries all drivers @@ -1237,6 +1237,10 @@ def list_all_nodes(provider, filterIn=None, awsRegion=None): for node in awsNodes: nodes.append(node) else: + if awsRegion not in awsLocations: + print("Invalid region") + awsRegion = choose_from_list(listAWSregions(awsLocations),"aws_region") + assert awsRegion is not None, "You chose an invalid aws region so we can't continue unless you choose a corect one" driver3 = get_driver(Provider.EC2)(SECDEP_AWS_ACCESS_KEY, SECDEP_AWS_SECRET_KEY, region=awsRegion) awsNodes = driver3.list_nodes() if len(awsNodes) > 0: